Skip to content

Commit

Permalink
add dropout and diagram
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 5, 2021
1 parent 0f658a9 commit 285b3bc
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
<img src="./images/mlp-mixer.png" width="500px"></img>

## MLP Mixer - Pytorch

An All-MLP solution for Vision, from Google AI, in Pytorch.
Expand Down
Binary file added mlp-mixer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 7 additions & 5 deletions mlp_mixer_pytorch/mlp_mixer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ def __init__(self, dim, fn):
def forward(self, x):
    """Pre-norm residual: normalize x, apply the wrapped module, add the input back.

    Args:
        x: input tensor; ``self.fn`` must preserve its shape so the
           residual addition is valid.

    Returns:
        ``self.fn(self.norm(x)) + x``.
    """
    # NOTE(review): self.norm is presumably nn.LayerNorm(dim) given the class
    # name PreNormResidual — __init__ body is not visible here, confirm there.
    return self.fn(self.norm(x)) + x

def FeedForward(dim, expansion_factor = 4, dropout = 0.):
    """Build a position-wise feed-forward block: expand, GELU, project back.

    The hidden layer has ``dim * expansion_factor`` features. Dropout is
    applied after the activation and again after the output projection;
    the default rate of ``0.`` makes both layers no-ops.

    Args:
        dim: input and output feature dimension.
        expansion_factor: width multiplier for the hidden layer.
        dropout: dropout probability used by both ``nn.Dropout`` layers.

    Returns:
        an ``nn.Sequential`` mapping ``(..., dim)`` tensors to ``(..., dim)``.
    """
    return nn.Sequential(
        nn.Linear(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )

def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4):
def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
num_patches = (image_size // patch_size) ** 2

Expand All @@ -27,10 +29,10 @@ def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_facto
*[nn.Sequential(
PreNormResidual(dim, nn.Sequential(
Rearrange('b n c -> b c n'),
FeedForward(num_patches, expansion_factor),
FeedForward(num_patches, expansion_factor, dropout),
Rearrange('b c n -> b n c'),
)),
PreNormResidual(dim, FeedForward(dim, expansion_factor))
PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout))
) for _ in range(depth)],
nn.LayerNorm(dim),
Rearrange('b n c -> b c n'),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'mlp-mixer-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'MLP Mixer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 285b3bc

Please sign in to comment.