From 9e7a87f9e03a3a2afa6399a7795fea9319f5b92c Mon Sep 17 00:00:00 2001
From: chenyu
Date: Fri, 4 Oct 2024 22:07:54 +0800
Subject: [PATCH] dropout contiguous after >= p

make it a bool buffer
---
 tinygrad/tensor.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e12462631995..86432564c7d2 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -454,7 +454,7 @@ def _threefry_random_bits(key0, key1, counts0, counts1):
     return counts0.cat(counts1)
 
   @staticmethod
-  def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
+  def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
 
@@ -507,7 +507,7 @@ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kw
 
     if getenv("MOCKGPU") and _device: out = out.to(_device)
     out.requires_grad = kwargs.get("requires_grad")
-    return out.contiguous()
+    return out.contiguous() if contiguous else out
 
   # ***** creation helper functions *****
 
@@ -671,13 +671,15 @@ def rand_like(self, **kwargs) -> Tensor:
     """
     dtype = kwargs.pop("dtype", self.dtype)
     device = kwargs.pop("device", self.device)
+    contiguous = kwargs.pop("contiguous", True)
     if isinstance(self.device, tuple):
       assert isinstance(self.lazydata, MultiLazyBuffer)
       if self.lazydata.axis is not None:
-        rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype).lazydata) for lb in self.lazydata.lbs]
+        rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata) \
+          for lb in self.lazydata.lbs]
         return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
-      return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
-    return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs)
+      return Tensor.rand(*self.shape, dtype=dtype, contiguous=contiguous, **kwargs).shard(self.device)
+    return Tensor.rand(*self.shape, device=device, dtype=dtype, contiguous=contiguous, **kwargs)
 
   # ***** rng hlops *****
 
@@ -3101,7 +3103,7 @@ def dropout(self, p=0.5) -> Tensor:
     ```
     """
     if not Tensor.training or p == 0: return self
-    return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p))
+    return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
 
   def one_hot(self, num_classes:int=-1) -> Tensor:
     """
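
A minimal usage sketch of the behavior this patch touches (not part of the commit; it assumes the tinygrad Python API around this revision, and the tensor shape and `p` value are purely illustrative):

```python
# Illustrative only: shows dropout gated by Tensor.training, and the new
# contiguous flag on rand_like that this patch introduces.
from tinygrad import Tensor

x = Tensor.ones(4, 4)

Tensor.training = True
y = x.dropout(p=0.5)             # roughly half the elements zeroed, survivors scaled by 1/(1-p)
print(y.numpy())

Tensor.training = False
print(x.dropout(p=0.5).numpy())  # outside training dropout is a no-op and returns self

# Per this patch, dropout builds its mask from rand_like(..., contiguous=False)
# and only realizes the post-comparison bool tensor via .contiguous().
mask = Tensor.rand_like(x, contiguous=False) >= 0.5
print(mask.contiguous().numpy())
```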