Skip to content

Commit

Permalink
tiny rand cleanup
Browse files (browse the repository at this point in the history)
Remove the `had_counter` flag and combine the per-device seed-initialization and rng-counter-increment blocks into a single `if`/`else`.
  • Loading branch information
chenyuxyz committed Oct 19, 2024
1 parent f511ad9 commit 78f85b1
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions tinygrad/tensor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -469,23 +469,21 @@ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, cont
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
_device = device = Device.canonicalize(device)

# if shape has 0, return zero tensor
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)

# when using MOCKGPU and NV generate rand on CLANG
if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"

# generate per device seeds and rng counter if we haven't seen this device yet
if device not in Tensor._device_seeds:
# generate per device seeds and rng counter if we haven't seen this device yet
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
device=device, dtype=dtypes.uint64, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
had_counter = False
else: had_counter = True

# if shape has 0, return zero tensor
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)

# increment rng counter for devices
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
else:
# have seen this device, increment its rng counter
Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

# threefry random bits
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
Expand Down

0 comments on commit 78f85b1

Please sign in to comment.