add adaptive slot wrapper
lucidrains committed Aug 20, 2024
1 parent 1bdcc37 commit e66a1c4
Showing 6 changed files with 135 additions and 10 deletions.
49 changes: 45 additions & 4 deletions README.md
@@ -34,15 +34,56 @@ After training, the network is reported to be able to generalize to slightly dif
slot_attn(inputs, num_slots = 8) # (2, 8, 512)
```

To use the <a href="https://arxiv.org/abs/2406.09196">adaptive slot</a> method, which generates a differentiable one-hot mask indicating whether each slot should be kept, just do the following

```python
import torch
from slot_attention import MultiHeadSlotAttention, AdaptiveSlotWrapper

# define slot attention

slot_attn = MultiHeadSlotAttention(
    dim = 512,
    num_slots = 5,
    iters = 3,
)

# wrap the slot attention

adaptive_slots = AdaptiveSlotWrapper(
    slot_attn,
    temperature = 0.5   # gumbel softmax temperature
)

inputs = torch.randn(2, 1024, 512)

slots, keep_slots = adaptive_slots(inputs) # (2, 5, 512), (2, 5)

# the auxiliary loss in the paper for minimizing the number of slots used for a scene would simply be

keep_aux_loss = keep_slots.sum() # add this to your main loss with some weight
```
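
As a rough sketch of how that auxiliary term might be folded into a training step (the stand-in `main_loss` and the `1e-2` weight below are placeholders for illustration, not values prescribed by this repo or the paper):

```python
# continuing from the example above - `main_loss` stands in for your task objective (e.g. reconstruction)

main_loss = torch.zeros(())

aux_weight = 1e-2   # hypothetical weight, tune for your setup

total_loss = main_loss + aux_weight * keep_aux_loss
total_loss.backward()

# gradients reach both the keep-slot predictor and the wrapped slot attention
```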

## Citation

```bibtex
@misc{locatello2020objectcentric,
    title = {Object-Centric Learning with Slot Attention},
    author = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
    year = {2020},
    eprint = {2006.15055},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

```bibtex
@article{Fan2024AdaptiveSA,
    title = {Adaptive Slot Attention: Object Discovery with Dynamic Slot Number},
    author = {Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang},
    journal = {ArXiv},
    year = {2024},
    volume = {abs/2406.09196},
    url = {https://api.semanticscholar.org/CorpusID:270440447}
}
```
12 changes: 6 additions & 6 deletions setup.py
@@ -3,7 +3,7 @@
setup(
    name = 'slot_attention',
    packages = find_packages(),
    version = '1.2.2',
    version = '1.4.0',
    license='MIT',
    description = 'Implementation of Slot Attention in Pytorch',
    long_description_content_type = 'text/markdown',
@@ -16,10 +16,10 @@
        'torch'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
2 changes: 2 additions & 0 deletions slot_attention/__init__.py
@@ -2,3 +2,5 @@
from slot_attention.slot_attention_experimental import SlotAttentionExperimental

from slot_attention.multi_head_slot_attention import MultiHeadSlotAttention

from slot_attention.adaptive_slot_wrapper import AdaptiveSlotWrapper
79 changes: 79 additions & 0 deletions slot_attention/adaptive_slot_wrapper.py
@@ -0,0 +1,79 @@
from __future__ import annotations

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from slot_attention.slot_attention import SlotAttention
from slot_attention.multi_head_slot_attention import MultiHeadSlotAttention

# functions

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    noise = torch.rand_like(t)
    return -log(-log(noise))

def gumbel_softmax(logits, temperature = 1.):
    dtype, size = logits.dtype, logits.shape[-1]

    assert temperature > 0

    scaled_logits = logits / temperature

    # gumbel sampling and derive one hot

    noised_logits = scaled_logits + gumbel_noise(scaled_logits)

    indices = noised_logits.argmax(dim = -1)

    hard_one_hot = F.one_hot(indices, size).type(dtype)

    # get soft for gradients

    soft = scaled_logits.softmax(dim = -1)

    # straight through

    hard_one_hot = hard_one_hot + soft - soft.detach()

    # return indices and one hot

    return hard_one_hot, indices

# wrapper

class AdaptiveSlotWrapper(Module):
    def __init__(
        self,
        slot_attn: SlotAttention | MultiHeadSlotAttention,
        temperature = 1.
    ):
        super().__init__()

        self.slot_attn = slot_attn
        dim = slot_attn.dim

        self.temperature = temperature
        self.pred_keep_slot = nn.Linear(dim, 2, bias = False)

    def forward(
        self,
        x,
        **slot_kwargs
    ):

        slots = self.slot_attn(x, **slot_kwargs)

        keep_slot_logits = self.pred_keep_slot(slots)

        keep_slots, _ = gumbel_softmax(keep_slot_logits, temperature = self.temperature)

        # just use last column for "keep" mask

        keep_slots = keep_slots[..., -1] # Float["batch num_slots"] of {0., 1.}

        return slots, keep_slots
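
A quick sanity check of the straight-through estimator above (a sketch, assuming the package is importable as shown in the README; the dimensions and temperature are illustrative): the forward value of `keep_slots` is a hard 0/1 mask, while gradients still flow into the keep-slot predictor through the soft relaxation.

```python
import torch
from slot_attention import MultiHeadSlotAttention, AdaptiveSlotWrapper

slot_attn = MultiHeadSlotAttention(dim = 64, num_slots = 4)
adaptive = AdaptiveSlotWrapper(slot_attn, temperature = 0.5)

x = torch.randn(1, 16, 64)
slots, keep_slots = adaptive(x)

# forward values are exactly 0. or 1.
assert set(keep_slots.unique().tolist()) <= {0., 1.}

# gradients still reach the keep-slot predictor via the soft relaxation
keep_slots.sum().backward()
assert adaptive.pred_keep_slot.weight.grad is not None
```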
2 changes: 2 additions & 0 deletions slot_attention/multi_head_slot_attention.py
@@ -4,6 +4,7 @@
from torch import einsum, nn
from torch.nn import init
from torch.nn import Module
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
@@ -20,6 +21,7 @@ def __init__(
        hidden_dim = 128
    ):
        super().__init__()
        self.dim = dim
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
1 change: 1 addition & 0 deletions slot_attention/slot_attention.py
@@ -5,6 +5,7 @@
class SlotAttention(nn.Module):
    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
        super().__init__()
        self.dim = dim
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
