Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 16, 2024
1 parent 65692b6 commit 191cdbe
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'slot_attention',
packages = find_packages(),
version = '1.2.0',
version = '1.2.1',
license='MIT',
description = 'Implementation of Slot Attention in Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
17 changes: 10 additions & 7 deletions slot_attention/multi_head_slot_attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import torch
from torch import einsum, nn
from torch.nn import init
Expand Down Expand Up @@ -54,7 +56,11 @@ def __init__(
nn.Linear(hidden_dim, dim)
)

def forward(self, inputs, num_slots = None):
def forward(
self,
inputs,
num_slots: int | None = None
):
b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
n_s = num_slots if num_slots is not None else self.num_slots

Expand All @@ -77,9 +83,9 @@ def forward(self, inputs, num_slots = None):
q = self.split_heads(q)

dots = torch.einsum('... i d, ... j d -> ... i j', q, k) * self.scale
attn = dots.softmax(dim = -2) + self.eps

attn = attn / attn.sum(dim = -1, keepdim = True)
attn = dots.softmax(dim = -2)
attn = F.normalize(attn, p = 1, dim = -1, eps = self.eps)

updates = einsum('... j d, ... i j -> ... i d', v, attn)
updates = self.merge_heads(updates)
Expand All @@ -88,10 +94,7 @@ def forward(self, inputs, num_slots = None):
updates, packed_shape = pack([updates], '* d')
slots_prev, _ = pack([slots_prev], '* d')

slots = self.gru(
updates,
slots_prev
)
slots = self.gru(updates, slots_prev)

slots, = unpack(slots, packed_shape, '* d')
slots = slots + self.mlp(self.norm_pre_ff(slots))
Expand Down

0 comments on commit 191cdbe

Please sign in to comment.