From 2f4b3ab1c015f470e9a50e4f6949e712a198797a Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 24 Jan 2024 00:15:10 -0500
Subject: [PATCH] shard and to should preserve requires_grad (#3224)

dtypes are inferred from underlying lazydata, requires_grad needs to be
passed explicitly
---
 test/test_tensor.py | 26 +++++++++++++++++++++++++-
 tinygrad/tensor.py  |  4 ++--
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/test/test_tensor.py b/test/test_tensor.py
index f6d35e074628a..a44b35bd17f9d 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -3,8 +3,12 @@
 import unittest, copy
 import mmap
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.helpers import temp
+from tinygrad.helpers import temp, CI
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
+from hypothesis import given, settings, strategies as strat
+
+settings.register_profile("my_profile", max_examples=200, deadline=None)
+settings.load_profile("my_profile")
 
 x_init = np.random.randn(1,3).astype(np.float32)
 U_init = np.random.randn(3,3).astype(np.float32)
@@ -315,6 +319,26 @@ def test_item_to_tensor_to_item(self):
       assert type(reshaped_item) == type(a), a
       np.testing.assert_allclose(reshaped_item, a), a
 
+@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
+class TestMoveTensor(unittest.TestCase):
+  d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
+  @given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
+         strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
+  def test_to_preserves(self, src, dest, dtype, requires_grad):
+    s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
+    t = s.to(dest)
+    np.testing.assert_equal(s.numpy(), t.numpy())
+    assert s.dtype == t.dtype
+    assert s.requires_grad == t.requires_grad
+
+  @given(strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
+  def test_shard_preserves(self, dtype, requires_grad):
+    s = Tensor([1, 2, 3], dtype=dtype, requires_grad=requires_grad)
+    t = s.shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"))
+    np.testing.assert_equal(s.numpy(), t.numpy())
+    assert s.dtype == t.dtype
+    assert s.requires_grad == t.requires_grad
+
 class TestZeroShapeTensor(unittest.TestCase):
   def test_shape_stride(self):
     t = Tensor.rand(3, 2, 0)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index a6225ec4d4ecb..208104a026f66 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -159,7 +159,7 @@ def numpy(self) -> np.ndarray:
   def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
     if device is None or device == self.device: return self
     if not isinstance(device, str): return self.shard(device)
-    ret = Tensor(self.lazydata, device)
+    ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
     if self.grad: ret.grad = self.grad.to(device)
     return ret
 
@@ -173,7 +173,7 @@ def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
     canonical_devices = tuple(Device.canonicalize(x) for x in devices)
     if axis is not None and axis < 0: axis += len(self.shape)
-    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices)
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
   def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
     self.lazydata = self.shard(devices, axis).lazydata
     return self
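
Below is a minimal usage sketch, not part of the patch itself, showing the behavior the change guarantees; it mirrors the new TestMoveTensor tests above. It assumes two devices of the default backend, f"{Device.DEFAULT}:0" and f"{Device.DEFAULT}:1", are reachable; the names d0, d1, s, t, and m are illustrative only.

# Illustrative sketch only (not part of the patch); assumes two usable devices of the default backend.
from tinygrad import Tensor, Device

d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"

# Without this fix, Tensor(self.lazydata, device) dropped requires_grad:
# the dtype can be inferred from the underlying lazydata, but requires_grad cannot.
s = Tensor([1.0, 2.0, 3.0], device=d0, requires_grad=True)

t = s.to(d1)           # copy to another device
m = s.shard((d0, d1))  # spread across devices as a MultiLazyBuffer

assert t.requires_grad == s.requires_grad and t.dtype == s.dtype
assert m.requires_grad == s.requires_grad and m.dtype == s.dtype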