shard and to should preserve requires_grad (tinygrad#3224)
dtypes are inferred from the underlying lazydata, but requires_grad needs to be passed explicitly
chenyuxyz authored Jan 24, 2024
1 parent 23b084e commit 2f4b3ab
Showing 2 changed files with 27 additions and 3 deletions.
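
In short, moving a tensor with to or sharding it with shard now keeps its requires_grad flag instead of silently resetting it to the constructor default of None. A minimal sketch of the behavior the new tests cover (not part of the diff; device strings assume a second instance of the default device is available):

from tinygrad import Tensor, Device

d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
a = Tensor([1, 2, 3], device=d0, requires_grad=True)
assert a.to(d1).requires_grad is True           # was None before this commit
assert a.shard((d0, d1)).requires_grad is True  # was None before this commit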
26 changes: 25 additions & 1 deletion test/test_tensor.py
@@ -3,8 +3,12 @@
 import unittest, copy
 import mmap
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.helpers import temp
+from tinygrad.helpers import temp, CI
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
+from hypothesis import given, settings, strategies as strat
+
+settings.register_profile("my_profile", max_examples=200, deadline=None)
+settings.load_profile("my_profile")
 
 x_init = np.random.randn(1,3).astype(np.float32)
 U_init = np.random.randn(3,3).astype(np.float32)
@@ -315,6 +319,26 @@ def test_item_to_tensor_to_item(self):
     assert type(reshaped_item) == type(a), a
     np.testing.assert_allclose(reshaped_item, a), a
 
+@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
+class TestMoveTensor(unittest.TestCase):
+  d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
+  @given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
+         strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
+  def test_to_preserves(self, src, dest, dtype, requires_grad):
+    s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
+    t = s.to(dest)
+    np.testing.assert_equal(s.numpy(), t.numpy())
+    assert s.dtype == t.dtype
+    assert s.requires_grad == t.requires_grad
+
+  @given(strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
+  def test_shard_preserves(self, dtype, requires_grad):
+    s = Tensor([1, 2, 3], dtype=dtype, requires_grad=requires_grad)
+    t = s.shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"))
+    np.testing.assert_equal(s.numpy(), t.numpy())
+    assert s.dtype == t.dtype
+    assert s.requires_grad == t.requires_grad
+
 class TestZeroShapeTensor(unittest.TestCase):
   def test_shape_stride(self):
     t = Tensor.rand(3, 2, 0)
4 changes: 2 additions & 2 deletions tinygrad/tensor.py
@@ -159,7 +159,7 @@ def numpy(self) -> np.ndarray:
   def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
     if device is None or device == self.device: return self
     if not isinstance(device, str): return self.shard(device)
-    ret = Tensor(self.lazydata, device)
+    ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
     if self.grad: ret.grad = self.grad.to(device)
     return ret
 
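
For context, the line just below the changed one already moved an existing .grad along with the data; only the requires_grad flag itself was being dropped. A small sketch of the combined behavior after this change (illustrative only; assumes a second instance of the default device exists):

from tinygrad import Tensor, Device

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
x.sum().backward()               # populates x.grad on the source device
y = x.to(f"{Device.DEFAULT}:1")
assert y.requires_grad is True   # preserved by this commit
assert y.grad is not None        # the grad tensor is moved to the destination device too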

@@ -173,7 +173,7 @@ def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
     canonical_devices = tuple(Device.canonicalize(x) for x in devices)
     if axis is not None and axis < 0: axis += len(self.shape)
-    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices)
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
 
   def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
     self.lazydata = self.shard(devices, axis).lazydata
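
Note that shard_ only swaps the sharded lazydata into the same Tensor object, so its requires_grad was never lost; the fix matters for the new Tensor returned by shard. A quick illustrative sketch (assumes two usable instances of the default device; the element count is chosen to split evenly across them):

from tinygrad import Tensor, Device

devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
w = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
ws = w.shard(devices, axis=0)    # new Tensor backed by a MultiLazyBuffer
assert ws.requires_grad is True  # preserved by this commit
w.shard_(devices, axis=0)        # in-place variant: same object, lazydata replaced
assert w.requires_grad is True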
