
Commit ebddf2f: improvise for #15
lucidrains committed Sep 13, 2024
1 parent 54b0824
Showing 3 changed files with 48 additions and 3 deletions.
1 change: 1 addition & 0 deletions mlp_mixer_pytorch/__init__.py
@@ -1,2 +1,3 @@
 from mlp_mixer_pytorch.mlp_mixer_pytorch import MLPMixer
+from mlp_mixer_pytorch.mlp_mixer_3d_pytorch import MLPMixer3D
 from mlp_mixer_pytorch.permutator import Permutator
44 changes: 44 additions & 0 deletions mlp_mixer_pytorch/mlp_mixer_3d_pytorch.py
@@ -0,0 +1,44 @@
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

pair = lambda x: x if isinstance(x, tuple) else (x, x)

class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(inner_dim, dim),
        nn.Dropout(dropout)
    )

def MLPMixer3D(*, image_size, time_size, channels, patch_size, time_patch_size, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
    image_h, image_w = pair(image_size)
    assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
    assert (time_size % time_patch_size) == 0, 'time dimension must be divisible by time patch size'

    num_patches = (image_h // patch_size) * (image_w // patch_size) * (time_size // time_patch_size)
    chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear

    return nn.Sequential(
        # fold the video into flattened spatiotemporal patch tokens
        Rearrange('b c (t pt) (h p1) (w p2) -> b (h w t) (p1 p2 pt c)', p1 = patch_size, p2 = patch_size, pt = time_patch_size),
        nn.Linear((time_patch_size * patch_size ** 2) * channels, dim),
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),  # token mixing across patches
            PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))      # channel mixing within each patch
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        Reduce('b n c -> b c', 'mean'),  # global average pool over tokens
        nn.Linear(dim, num_classes)
    )
6 changes: 3 additions & 3 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'mlp-mixer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.2.0',
   license='MIT',
   description = 'MLP Mixer - Pytorch',
   author = 'Phil Wang',
@@ -15,8 +15,8 @@
     'image recognition'
   ],
   install_requires=[
-    'einops>=0.3',
-    'torch>=1.6'
+    'einops>=0.8',
+    'torch>=2.0'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
