Skip to content

Commit

Permalink
remove had_counter from rand
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Dec 11, 2024
1 parent 8f4299f commit b063cd9
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,21 +495,18 @@ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, cont
# when using MOCKGPU and NV generate rand on CLANG
if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"

# if shape has 0, return zero tensor
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
num = ceildiv(numel * dtype.itemsize, 4)

# generate per device seeds and rng counter if we haven't seen this device yet
if device not in Tensor._device_seeds:
Tensor._device_seeds[device] = Tensor(
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
had_counter = False
else: had_counter = True

# if shape has 0, return zero tensor
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
num = ceildiv(numel * dtype.itemsize, 4)

# increment rng counter for devices
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

# threefry random bits
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
Expand Down

0 comments on commit b063cd9

Please sign in to comment.