embedding shape (tinygrad#6930)
conwayz authored Oct 8, 2024
1 parent 20a9683 commit 08bfa86
Showing 2 changed files with 11 additions and 2 deletions.
9 changes: 9 additions & 0 deletions test/test_nn.py
@@ -468,6 +468,15 @@ def test_embedding_one_kernel(self):
     self.assertEqual(1, len([item for item in schedule if item.ast.op is UOps.SINK]), "second run realizes embedding only")
     run_schedule(schedule)
 
+  def test_embedding_shape(self):
+    vocab_size, embed_size = 10, 16
+    layer = Embedding(vocab_size, embed_size)
+    for rank in range(5):
+      shp = (1,) * rank
+      a = Tensor([3]).reshape(shp)
+      result = layer(a)
+      self.assertEqual(result.shape, shp + (embed_size,))
+
   def test_load_state_dict(self):
     layer = Conv2d(3, 5, kernel_size=3)
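The new test drives the layer with index tensors of rank 0 through 4 and asserts that the output shape is always idx.shape + (embed_size,). A minimal sketch of the behavior being tested, assuming tinygrad is installed (the values here are illustrative):

    from tinygrad import Tensor
    from tinygrad.nn import Embedding

    layer = Embedding(10, 16)                              # vocab_size=10, embed_size=16
    assert layer(Tensor([3]).reshape(())).shape == (16,)   # rank-0 index: () + (16,)
    assert layer(Tensor([[3]])).shape == (1, 1, 16)        # rank-2 index: (1, 1) + (16,)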
4 changes: 2 additions & 2 deletions tinygrad/nn/__init__.py
@@ -313,10 +313,10 @@ def __init__(self, vocab_size:int, embed_size:int):
 
   def __call__(self, idx:Tensor) -> Tensor:
     if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
-    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+    arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
-    return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
+    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
 
 class LSTMCell:
   """
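Context on the change above (a sketch, not part of the commit): the old arange_shp of (1, 1, self.vocab_sz, 1) and the sum over axis 2 hard-coded a rank-2 index, so the layer assumed a (batch, seq_len) input. With arange reshaped to (vocab_sz, 1), the weight to (vocab_sz, embed_sz), and the reduction taken over axis -2, everything broadcasts against idx.shape + (vocab_sz, embed_sz) and the output comes out as idx.shape + (embed_sz,) for an index of any rank. Roughly what the updated __call__ computes, spelled out step by step, assuming tinygrad is installed:

    from tinygrad import Tensor
    from tinygrad.nn import Embedding

    vocab_sz, embed_sz = 10, 16
    emb = Embedding(vocab_sz, embed_sz)

    idx = Tensor([[1, 2], [3, 4]])                         # index tensor of any rank
    big_shp = idx.shape + (vocab_sz, embed_sz)             # (2, 2, 10, 16)
    arange = Tensor.arange(vocab_sz).reshape(vocab_sz, 1).expand(big_shp)
    onehot = arange == idx.reshape(idx.shape + (1, 1)).expand(big_shp)   # one-hot mask per index
    out = onehot.mul(emb.weight.reshape(vocab_sz, embed_sz).expand(big_shp)).sum(-2)
    assert out.shape == idx.shape + (embed_sz,)            # (2, 2, 16)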
