Skip to content

Commit

Permalink
support rectangular images
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 17, 2022
1 parent 2151af6 commit 54b0824
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
```

Rectangular image

```python
import torch
from mlp_mixer_pytorch import MLPMixer

model = MLPMixer(
image_size = (256, 128),
channels = 3,
patch_size = 16,
dim = 512,
depth = 12,
num_classes = 1000
)

img = torch.randn(1, 3, 256, 128)
pred = model(img) # (1, 1000)
```

## Citations

```bibtex
Expand Down
7 changes: 5 additions & 2 deletions mlp_mixer_pytorch/mlp_mixer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import partial
from einops.layers.torch import Rearrange, Reduce

pair = lambda x: x if isinstance(x, tuple) else (x, x)

class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
Expand All @@ -22,8 +24,9 @@ def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
)

def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
num_patches = (image_size // patch_size) ** 2
image_h, image_w = pair(image_size)
assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
num_patches = (image_h // patch_size) * (image_w // patch_size)
chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear

return nn.Sequential(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'mlp-mixer-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.0',
version = '0.1.1',
license='MIT',
description = 'MLP Mixer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 54b0824

Please sign in to comment.