From e6c71f1b26d6a6a13c98a8c7228585ee692d3e38 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sun, 21 Jan 2024 21:03:50 -0500
Subject: [PATCH] fix device of Tensor.arange inside Tensor.one_hot (#3199)

it should have the same device as self
---
NOTE(review): line breaks and hunk indentation were reconstructed from a
whitespace-collapsed copy of this patch (two-space indents assumed; runs of
spaces inside a line may also have been collapsed). Confirm with
`git apply --check` before applying.
NOTE(review): the added test only realizes the one-hot tensor; it does not
assert that the result lives on the same device as the input.

 test/test_tensor.py | 7 +++++++
 tinygrad/tensor.py  | 3 ++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/test/test_tensor.py b/test/test_tensor.py
index 615f5d0bffbbe..f6d35e074628a 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -437,5 +437,12 @@ def test_reduce_default(self):
     np.testing.assert_equal(Tensor([]).sum().numpy(), 0)
     np.testing.assert_equal(Tensor([]).mean().numpy(), 0)
 
+class TestTensorCreationDevice(unittest.TestCase):
+  # test auxiliary tensors are created on the same device
+  def test_one_hot(self):
+    y = Tensor([1, 2, 3]).to("CPU")
+    x = y.one_hot(10)
+    x.realize()
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 67e20f1683103..a6225ec4d4ecb 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -901,7 +901,8 @@ def dropout(self, p=0.5) -> Tensor:
     if not Tensor.training or p == 0: return self
     return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
 
-  def one_hot(self, num_classes:int, **kwargs) -> Tensor: return Tensor.where(self[..., None] == Tensor.arange(num_classes), 1, 0, **kwargs)
+  def one_hot(self, num_classes:int) -> Tensor:
+    return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
 
   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
     # NOTE: it works if key, value have symbolic shape