From 54b08248956464f4361127715738194a1d0d92d5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 16 Feb 2022 21:12:59 -0800 Subject: [PATCH] support rectangular images --- README.md | 19 +++++++++++++++++++ mlp_mixer_pytorch/mlp_mixer_pytorch.py | 7 +++++-- setup.py | 2 +- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c0ad5a9..174eaed 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,25 @@ img = torch.randn(1, 3, 256, 256) pred = model(img) # (1, 1000) ``` +Rectangular image + +```python +import torch +from mlp_mixer_pytorch import MLPMixer + +model = MLPMixer( + image_size = (256, 128), + channels = 3, + patch_size = 16, + dim = 512, + depth = 12, + num_classes = 1000 +) + +img = torch.randn(1, 3, 256, 128) +pred = model(img) # (1, 1000) +``` + ## Citations ```bibtex diff --git a/mlp_mixer_pytorch/mlp_mixer_pytorch.py b/mlp_mixer_pytorch/mlp_mixer_pytorch.py index df9f680..2c6ec6f 100644 --- a/mlp_mixer_pytorch/mlp_mixer_pytorch.py +++ b/mlp_mixer_pytorch/mlp_mixer_pytorch.py @@ -2,6 +2,8 @@ from functools import partial from einops.layers.torch import Rearrange, Reduce +pair = lambda x: x if isinstance(x, tuple) else (x, x) + class PreNormResidual(nn.Module): def __init__(self, dim, fn): super().__init__() @@ -22,8 +24,9 @@ def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear): ) def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.): - assert (image_size % patch_size) == 0, 'image must be divisible by patch size' - num_patches = (image_size // patch_size) ** 2 + image_h, image_w = pair(image_size) + assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size' + num_patches = (image_h // patch_size) * (image_w // patch_size) chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear return nn.Sequential( diff --git a/setup.py b/setup.py index 4de558f..4c83ce0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'mlp-mixer-pytorch', packages = find_packages(exclude=[]), - version = '0.1.0', + version = '0.1.1', license='MIT', description = 'MLP Mixer - Pytorch', author = 'Phil Wang',