diff --git a/test/test_nn.py b/test/test_nn.py
index d598e26c07a6..c50d3a441e6c 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -468,6 +468,15 @@ def test_embedding_one_kernel(self):
     self.assertEqual(1, len([item for item in schedule if item.ast.op is UOps.SINK]), "second run realizes embedding only")
     run_schedule(schedule)
 
+  def test_embedding_shape(self):
+    vocab_size, embed_size = 10, 16
+    layer = Embedding(vocab_size, embed_size)
+    for rank in range(5):
+      shp = (1,) * rank
+      a = Tensor([3]).reshape(shp)
+      result = layer(a)
+      self.assertEqual(result.shape, shp + (embed_size,))
+
   def test_load_state_dict(self):
     layer = Conv2d(3, 5, kernel_size=3)
 
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index b488682ca973..309d90bc5bd2 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -313,10 +313,10 @@ def __init__(self, vocab_size:int, embed_size:int):
 
   def __call__(self, idx:Tensor) -> Tensor:
     if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
-    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+    arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
-    return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
+    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
 
 class LSTMCell:
   """
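
Why this change works: the old code hard-coded the arange/weight shapes to four dimensions and reduced over axis 2, which only matches a 2-D idx. The patch keeps arange and weight at their natural 2-D shapes, lets expand/broadcast handle whatever rank idx has, and reduces over axis -2 (the vocab axis counted from the end), so the lookup is rank-agnostic. Below is a minimal sketch of the same broadcast one-hot lookup in plain numpy (not tinygrad code, names are illustrative) to show why sum over -2 gives idx.shape + (embed_sz,) for any rank:

import numpy as np

vocab_sz, embed_sz = 10, 4
weight = np.random.randn(vocab_sz, embed_sz)            # (vocab_sz, embed_sz)

def embed(idx: np.ndarray) -> np.ndarray:
    arange = np.arange(vocab_sz).reshape(vocab_sz, 1)   # (vocab_sz, 1)
    # idx gains two trailing singleton dims; broadcasting against arange and
    # weight yields idx.shape + (vocab_sz, embed_sz)
    onehot = (arange == idx.reshape(idx.shape + (1, 1)))
    # reduce over the vocab axis, which is always second-from-last
    return (onehot * weight).sum(axis=-2)               # idx.shape + (embed_sz,)

for rank in range(5):
    idx = np.full((1,) * rank, 3)
    assert embed(idx).shape == idx.shape + (embed_sz,)

This mirrors the added test_embedding_shape test: indices of rank 0 through 4 all produce an output whose shape is the index shape with the embedding dimension appended.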