Add MoE layers (#3)
* Add MoE layers

* Add count_parameters util

* Add experiment links in README.md
mauicv authored Dec 20, 2023
1 parent d77d384 commit ecd9bc7
Showing 10 changed files with 203 additions and 15 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -4,6 +4,17 @@

This repo is a collection of PyTorch implementations of Transformer architectures with a simple, flexible config system. The goal is learning and experimentation.

## Tests:

Tests can be run using `pytest` from the root directory. There are also online Colab notebooks that should test any new architecture added to the repo on Shakespeare character prediction:

1. [basic transformer](https://colab.research.google.com/drive/1cNjbbiDqeHyjGFyMnuag9RykDKL2XKLp)
2. [MoE transformer](https://colab.research.google.com/drive/193oYMnTx8FdJDMj_NOgyOng6j9nQc7K_)

Each architecture and layer should also be benchmarked for speed using:

1. [Transformer-benchmarks](https://colab.research.google.com/drive/1hb9V6ne42awHTxKvcI0vct1SNEO7rock)


## Resources:

5 changes: 4 additions & 1 deletion src/pytfex/transformer/make_model.py
@@ -5,6 +5,7 @@
from pytfex.transformer.layer import TransformerLayer
from pytfex.transformer.attention import Attention
from pytfex.transformer.mlp import MLP
from pytfex.transformer.moe import MoE
from pytfex.transformer.gpt import GPT
from pytfex.transformer.heads import ClassificationHead, InversePatch
from pytfex.transformer.embedders import TokenEmbedder, PositionEmbedder, \
@@ -16,14 +17,16 @@ class TransformerObjectRegistry:
'TransformerLayer': TransformerLayer,
'Attention': Attention,
'MLP': MLP,
'MoE': MoE,
'GPT': GPT,
'ClassificationHead': ClassificationHead,
'InversePatch': InversePatch,
'TokenEmbedder': TokenEmbedder,
'PositionEmbedder': PositionEmbedder,
'MultiEmbedder': MultiEmbedder,
'PatchEmbedder': PatchEmbedder,
'LinearEmbedder': LinearEmbedder
'LinearEmbedder': LinearEmbedder,

}

def register(name):
9 changes: 7 additions & 2 deletions src/pytfex/transformer/mlp.py
@@ -5,10 +5,15 @@ class MLP(torch.nn.Module):
def __init__(
self,
hidden_dim: int,
intermediate_dim: int=None,
dropout: float=0.5,
):
super(MLP, self).__init__()
self.hidden_dim = hidden_dim
if intermediate_dim is None:
intermediate_dim = hidden_dim * 4
self.intermediate_dim = intermediate_dim

self.dropout = torch.tensor(
dropout,
dtype=torch.float32
@@ -19,10 +24,10 @@ def __init__(

self.linear1 = torch.nn.Linear(
self.hidden_dim,
4 * self.hidden_dim
self.intermediate_dim
)
self.linear2 = torch.nn.Linear(
4 * self.hidden_dim,
self.intermediate_dim,
self.hidden_dim
)

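A quick sketch (not part of the commit) of the new `intermediate_dim` argument; left unset, it falls back to the conventional 4x expansion:

```python
from pytfex.transformer.mlp import MLP

wide = MLP(hidden_dim=12)                         # linear1: 12 -> 48, linear2: 48 -> 12
narrow = MLP(hidden_dim=12, intermediate_dim=12)  # 12 -> 12 -> 12, as the MoE experts use
```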
67 changes: 67 additions & 0 deletions src/pytfex/transformer/moe.py
@@ -0,0 +1,67 @@
from typing import List
import torch


class MoE(torch.nn.Module):
def __init__(
self,
hidden_dim: int,
experts: List,
c: int = 2,
    ):
        """Mixture of Experts - expert-choice routing layer.

        See https://arxiv.org/pdf/2202.09368.pdf for more details.

        Args:
            hidden_dim (int): hidden dimension
            experts (List): List of experts. Each expert is a torch.nn.Module.
            c (int, optional): Capacity factor. Denotes on average how many
                experts are utilized by a token. Defaults to 2.
        """
super(MoE, self).__init__()
self.hidden_dim = hidden_dim
self.c = c
self.experts = torch.nn.ModuleList(experts)
self.num_experts = len(self.experts)
self.gate = torch.nn.Linear(
hidden_dim,
self.num_experts,
bias=False
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim)
Returns:
torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim)
"""
b, l, *_ = x.shape
k = self._compute_k(l)
S = torch.softmax(self.gate(x), dim=-1)
S = S.transpose(1, 2) # (batch_size, num_experts, tokens)
G, I = torch.topk(S, k, dim=-1)
# I - (batch_size, num_experts, top_k_tokens) - indices
# G - (batch_size, num_experts, top_k_tokens) - weights
new_x = torch.zeros_like(x)
for i, expert in enumerate(self.experts):
indices = I[:, i]
scores = G[:, i]
batch_indices = (torch
.arange(b)
.view(-1, 1)
.expand_as(indices)
)
# (batch_size, top_k_tokens, hidden_dim) - tokens for expert i
ex = x[batch_indices, indices]
ex_pred = scores[:, :, None] * expert(ex)
new_x[batch_indices, indices] += ex_pred
        return new_x

def _compute_k(self, l: int) -> int:
k = int((l * self.c) / self.num_experts)
k = min(max(k, 1), l)
return k
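A minimal usage sketch (not part of the commit) for the new layer, mirroring `test_MoE_MLP` below, with the capacity arithmetic of `_compute_k` worked through:

```python
import torch
from pytfex.transformer.mlp import MLP
from pytfex.transformer.moe import MoE

moe = MoE(
    hidden_dim=12,
    experts=[MLP(hidden_dim=12, dropout=0.5) for _ in range(4)],
    c=2,
)

x = torch.randn(2, 10, 12)  # (batch_size, seq_len, hidden_dim)
k = moe._compute_k(10)      # k = (10 * 2) / 4 = 5 tokens routed to each expert
out = moe(x)                # out.shape == (2, 10, 12)
```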
19 changes: 18 additions & 1 deletion src/pytfex/transformer/tests/test_layers.py
@@ -1,5 +1,6 @@
from pytfex.transformer.attention import Attention
from pytfex.transformer.mlp import MLP
from pytfex.transformer.moe import MoE
import torch


@@ -22,4 +23,20 @@ def test_MLP():
)
t1 = torch.zeros((1, 10, 12))
t2 = mlp(t1)
assert t2.shape == (1, 10, 12)
assert t2.shape == (1, 10, 12)


def test_MoE_MLP():
mlp = MoE(
hidden_dim=12,
c=2,
experts=[
MLP(
hidden_dim=12,
dropout=0.5
) for _ in range(4)
]
)
t1 = torch.randn((2, 10, 12))
t2 = mlp(t1)
assert t2.shape == (2, 10, 12)
6 changes: 5 additions & 1 deletion src/pytfex/utils.py
@@ -21,4 +21,8 @@ def _parse_string_to_tuple(x):
"""For parsing config strings of the form: '1,2' to a tuple: (1,2)"""
if not isinstance(x, str):
return x
return tuple(map(int, x.split(',')))
return tuple(map(int, x.split(',')))


def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
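A small sanity check (not part of the commit) of the new utility:

```python
import torch
from pytfex.utils import count_parameters

layer = torch.nn.Linear(12, 48)
print(count_parameters(layer))  # 12 * 48 weights + 48 biases = 624
```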
16 changes: 9 additions & 7 deletions src/tests/conftest.py
@@ -4,32 +4,34 @@
from pytfex.transformer.make_model import init_from_yml_string
from pytfex.utils import set_seed
from tests.basic_model import get_basic_gpt_config
from tests.moe_model import get_moe_gpt_config

import torch

import pytest


@pytest.fixture(params=[
(6, 3, 32, 'gpt-basic')
(256, 6, 3, 32, None, None, 'gpt-basic'), # (hdn_dim, length, num_digits, batch_size, _, _, model_type)
    (256, 6, 3, 32, 2, 4, 'gpt-moe')  # (hdn_dim, length, num_digits, batch_size, c, num_experts, model_type)
])
def training_setup(request):
set_seed(0)

length, num_digits, batch_size, model_type = request.param
hdn_dim, length, num_digits, batch_size, c, num_experts, model_type = request.param
ds = SortDataset(split='train', length=length, num_digits=num_digits)
dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
blk_size = ds.get_block_size()
vcb_size = ds.get_vocab_size()
hdn_dim = 256

config = {
'gpt-basic': get_basic_gpt_config(vcb_size, hdn_dim, blk_size)
}[model_type]
'gpt-basic': lambda: get_basic_gpt_config(vcb_size, hdn_dim, blk_size),
'gpt-moe': lambda: get_moe_gpt_config(vcb_size, hdn_dim, blk_size, c, num_experts)
}[model_type]()
model = init_from_yml_string(config)

def val_fn(model):
ds = SortDataset(split='test', length=6, num_digits=3)
ds = SortDataset(split='test', length=length, num_digits=num_digits)
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
total = 0
sum_acc = 0
@@ -45,4 +47,4 @@ def val_fn(model):
acc = sum_acc / total
return acc

return dl, model, val_fn
return dl, model, val_fn, model_type
56 changes: 56 additions & 0 deletions src/tests/moe_model.py
@@ -0,0 +1,56 @@
def get_moe_gpt_config(
    vcb_size,
    hdn_dim,
    blk_size,
    c,
    num_experts
):

return f"""
type: 'GPT'
params:
dropout: 0.5
hidden_dim: {hdn_dim}
num_heads: 4
embedder:
type: 'MultiEmbedder'
params:
embedders:
- type: 'TokenEmbedder'
params:
dictionary_size: {vcb_size}
hidden_dim: {hdn_dim}
- type: 'PositionEmbedder'
params:
num_positions: {blk_size}
hidden_dim: {hdn_dim}
layers:
- num: 2
type: 'TransformerLayer'
params:
hidden_dim: {hdn_dim}
attn:
type: 'Attention'
params:
hidden_dim: {hdn_dim}
num_heads: 4
dropout: 0.5
mlp:
type: 'MoE'
params:
hidden_dim: {hdn_dim}
c: {c}
experts:
- num: {num_experts}
type: 'MLP'
params:
hidden_dim: {hdn_dim}
intermediate_dim: {hdn_dim}
dropout: 0.5
head:
type: 'ClassificationHead'
params:
hidden_dim: {hdn_dim}
vocab_size: {vcb_size}
"""
21 changes: 21 additions & 0 deletions src/tests/test_model_init.py
@@ -0,0 +1,21 @@
from pytfex.utils import set_seed
from tests.basic_model import get_basic_gpt_config
from tests.moe_model import get_moe_gpt_config
from pytfex.transformer.make_model import init_from_yml_string

import pytest


@pytest.mark.parametrize('vcb_size,hdn_dim,blk_size,c,num_experts,model_type', [
(32, 12, 11, None, None, 'gpt-basic'),
(32, 12, 11, 2, 4, 'gpt-moe')
])
def test_init(vcb_size, hdn_dim, blk_size, c, num_experts, model_type):
set_seed(0)
config = {
'gpt-basic': get_basic_gpt_config(vcb_size, hdn_dim, blk_size),
        'gpt-moe': get_moe_gpt_config(vcb_size, hdn_dim, blk_size, c, num_experts)
}[model_type]
model = init_from_yml_string(config)
print(model)
assert model
8 changes: 5 additions & 3 deletions src/tests/test_train.py → src/tests/test_model_train.py
@@ -1,4 +1,4 @@
from pytfex.utils import set_seed
from pytfex.utils import set_seed, count_parameters

import pytest
import torch
@@ -7,13 +7,15 @@
@pytest.mark.skip(reason="Slow running/intermittent test")
def test_train(training_setup):
set_seed(0)
dl, model, val_fn = training_setup
dl, model, val_fn, model_type = training_setup
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
loss_fn = torch.nn.CrossEntropyLoss()
acc = val_fn(model)

print('\n')
print(f'-- model-type : {model_type} --')
print(f'-- # params : {count_parameters(model)} --')
print('epoch_|_loss_____|_acc______')
print(f' -1| None | {acc:0.5}')
for epoch in range(5):
@@ -32,6 +34,6 @@ def test_train(training_setup):
acc = val_fn(model)
print(f'{epoch:>6}| {loss.item():<8.5} | {acc:0.5}')

assert loss.item() < 0.1
assert loss.item() < 0.15
acc = val_fn(model)
assert acc > 0.95
