fix device of Tensor.arange inside Tensor.one_hot (tinygrad#3199)
The auxiliary Tensor.arange should be created on the same device as self.
chenyuxyz authored Jan 22, 2024
1 parent f7d1c42 commit e6c71f1
Showing 2 changed files with 9 additions and 1 deletion.
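
For context, a minimal sketch of the failure this commit addresses, assuming a build where Device.DEFAULT is a GPU backend and so differs from the tensor's device:

from tinygrad.tensor import Tensor

y = Tensor([1, 2, 3]).to("CPU")   # indices live on CPU
x = y.one_hot(10)                 # before this fix, Tensor.arange(10) was created on Device.DEFAULT
x.realize()                       # so the comparison could mix devices and fail at realize time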
7 changes: 7 additions & 0 deletions test/test_tensor.py
@@ -437,5 +437,12 @@ def test_reduce_default(self):
     np.testing.assert_equal(Tensor([]).sum().numpy(), 0)
     np.testing.assert_equal(Tensor([]).mean().numpy(), 0)
 
+class TestTensorCreationDevice(unittest.TestCase):
+  # test auxiliary tensors are created on the same device
+  def test_one_hot(self):
+    y = Tensor([1, 2, 3]).to("CPU")
+    x = y.one_hot(10)
+    x.realize()
+
 if __name__ == '__main__':
   unittest.main()
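
To exercise just the new case, the file can be run filtered to this test, for example with python3 -m pytest test/test_tensor.py -k test_one_hot (assuming pytest is installed).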
3 changes: 2 additions & 1 deletion tinygrad/tensor.py
@@ -901,7 +901,8 @@ def dropout(self, p=0.5) -> Tensor:
     if not Tensor.training or p == 0: return self
     return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
 
-  def one_hot(self, num_classes:int, **kwargs) -> Tensor: return Tensor.where(self[..., None] == Tensor.arange(num_classes), 1, 0, **kwargs)
+  def one_hot(self, num_classes:int) -> Tensor:
+    return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
 
   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
     # NOTE: it works if key, value have symbolic shape
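With the fix, the auxiliary arange inherits self.device, so the result stays on the caller's device. A usage sketch, assuming a CPU backend is available:

from tinygrad.tensor import Tensor

y = Tensor([1, 2, 3]).to("CPU")
x = y.one_hot(10)        # arange(num_classes) is now created on y.device
print(x.device)          # CPU
print(x.numpy().shape)   # (3, 10), one-hot rows for indices 1, 2, 3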
