Skip to content

Commit

Permalink
dropout contiguous after >= p (tinygrad#6892)
Browse files Browse the repository at this point in the history
make it a bool buffer
  • Loading branch information
chenyuxyz authored Oct 6, 2024
1 parent 9eb6eef commit 999e378
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def _threefry_random_bits(key, counts0, counts1):
return counts0.cat(counts1)

@staticmethod
def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
Expand Down Expand Up @@ -509,7 +509,7 @@ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kw
if getenv("MOCKGPU") and _device: out = out.to(_device)

out.requires_grad = kwargs.get("requires_grad")
return out.contiguous()
return out.contiguous() if contiguous else out

# ***** creation helper functions *****

Expand Down Expand Up @@ -673,13 +673,15 @@ def rand_like(self, **kwargs) -> Tensor:
"""
dtype = kwargs.pop("dtype", self.dtype)
device = kwargs.pop("device", self.device)
contiguous = kwargs.pop("contiguous", True)
if isinstance(self.device, tuple):
assert isinstance(self.lazydata, MultiLazyBuffer)
if self.lazydata.axis is not None:
rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype).lazydata) for lb in self.lazydata.lbs]
rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata) \
for lb in self.lazydata.lbs]
return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs)
return Tensor.rand(*self.shape, dtype=dtype, contiguous=contiguous, **kwargs).shard(self.device)
return Tensor.rand(*self.shape, device=device, dtype=dtype, contiguous=contiguous, **kwargs)

# ***** rng hlops *****

Expand Down Expand Up @@ -3131,7 +3133,7 @@ def dropout(self, p=0.5) -> Tensor:
```
"""
if not Tensor.training or p == 0: return self
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p))
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)

def one_hot(self, num_classes:int=-1) -> Tensor:
"""
Expand Down

0 comments on commit 999e378

Please sign in to comment.