Skip to content

Commit

Permalink
seed in tensor (tinygrad#6869)
Browse files Browse the repository at this point in the history
  • Loading branch information
wozeparrot authored Oct 6, 2024
1 parent f9e32f2 commit 9eb6eef
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 16 deletions.
14 changes: 8 additions & 6 deletions test/test_gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,26 @@ def test_gc(self):
(a*b).mean().backward()
assert (tensors_allocated() > 0)
del a,b
assert (tensors_allocated() == 1) # one for Tensor._device_rng_counters
assert (tensors_allocated() == 2) # one for Tensor._device_rng_counters, and one for Tensor._device_seeds
Tensor.manual_seed(0)

def test_gc_complex(self):
Tensor.manual_seed(0)
a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
b = Tensor.rand(4, 4, requires_grad=True)
assert (tensors_allocated() == 4)
(a*b).mean().backward()
assert (tensors_allocated() == 5)
(a*b).mean().backward()
assert (tensors_allocated() == 6)
del b
assert (tensors_allocated() == 3)
assert (tensors_allocated() == 4)
b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())
assert (tensors_allocated() == 5)
assert (tensors_allocated() == 6)
del b
assert (tensors_allocated() == 3)
assert (tensors_allocated() == 4)
Tensor.manual_seed(0)

def test_schedule_gc(self):
init = bufs_allocated()
Expand Down
8 changes: 8 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,14 @@ def f(a, b):
assert len(res3) == 10, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"

def test_jit_random_after_unrealized_random(self):
@TinyJit
def f(): return Tensor.rand()
Tensor.manual_seed(1234)
Tensor.rand()
res = [f().numpy() for _ in range(3)]
assert res[1] != res[2]

def test_jit_realization_and_sampling(self):
w = Tensor.eye(5)

Expand Down
33 changes: 29 additions & 4 deletions test/test_randomness.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_rand_float16(self):
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

@unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
def test_threefly_against_reference(self):
def test_threefry_against_reference(self):
Tensor.manual_seed(1337)

# reference generated using
Expand All @@ -92,11 +92,11 @@ def test_threefly_against_reference(self):

counts = Tensor.arange(20, dtype=dtypes.uint32)
counts0, counts1 = counts.chunk(2)
r = Tensor._threefry_random_bits(1337, 0, counts0, counts1).numpy()
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()

np.testing.assert_allclose(jr, r)

def test_threefly_against_reference_full(self):
def test_threefry_against_reference_full(self):
Tensor.manual_seed(1337)

# reference generated using
Expand All @@ -118,7 +118,7 @@ def test_threefly_against_reference_full(self):
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
def test_threefly_tensors_cnt(self):
def test_threefry_tensors_cnt(self):
Tensor.manual_seed(1337)

Tensor.rand(20).realize()
Expand All @@ -136,6 +136,31 @@ def test_threefly_tensors_cnt(self):
assert len(Tensor._device_rng_counters) == 0
assert len(Tensor._device_seeds) == 0

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
def test_threefry_same_kernels(self):
Tensor.manual_seed(0)

Tensor.rand(1).realize()

s = Tensor.rand(20).schedule()
s2 = Tensor.rand(20).schedule()

assert len(s) == len(s2), f"{len(s)} != {len(s2)}"
for x,y in zip(s, s2):
if not (x.ast == y.ast):
print(f"{x.ast} != {y.ast}")

Tensor.rand(1, device=f"{Device.DEFAULT}:1").realize()

s3 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()
s4 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()

assert len(s3) == len(s4), f"{len(s3)} != {len(s4)}"
assert len(s2) == len(s4), f"{len(s)} != {len(s3)}"
for x,y in zip(s3, s4):
if not (x.ast == y.ast):
print(f"{x.ast} != {y.ast}")

@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
def test_rand_bfloat16(self):
N = 128
Expand Down
13 changes: 7 additions & 6 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
return r

_seed: int = int(time.time())
_device_seeds: Dict[str, int] = {}
_device_seeds: Dict[str, Tensor] = {}
_device_rng_counters: Dict[str, Tensor] = {}
@staticmethod
def manual_seed(seed=0):
Expand All @@ -447,9 +447,8 @@ def manual_seed(seed=0):
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}

@staticmethod
def _threefry_random_bits(key0, key1, counts0, counts1):
def _threefry_random_bits(key, counts0, counts1):
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
key = (Tensor([key0], device=x.device, dtype=dtypes.uint64, requires_grad=False) << 32) | key1
x = F.Threefry.apply(*x._broadcasted(key))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
return counts0.cat(counts1)
Expand Down Expand Up @@ -478,7 +477,9 @@ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kw

# generate per device seeds and rng counter if we haven't seen this device yet
if device not in Tensor._device_seeds:
Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
device=device, dtype=dtypes.uint64, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
had_counter = False
else: had_counter = True
Expand All @@ -487,12 +488,12 @@ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kw
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)

# increment rng counter for devices
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

# threefry random bits
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
counts1 = counts0 + ceildiv(num, 2)
bits = Tensor._threefry_random_bits(Tensor._seed, Tensor._device_seeds[device], counts0, counts1)[:num]
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]

# bitcast to uint with same number of bits
_, nmant = dtypes.finfo(dtype)
Expand Down

0 comments on commit 9eb6eef

Please sign in to comment.