diff --git a/README.md b/README.md
index 42bc714..2c85e6b 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,17 @@
 This repo is a collection of PyTorch implementations of Transformer architectures with simple flexible config for ease of experimentation.
 The goal is learning and experimentation.

+## Tests:
+
+Tests can be run using `pytest` from the root directory. There are also online Colabs that should be used to test any new architecture added to the repo on Shakespeare character prediction.
+
+1. [basic transformer](https://colab.research.google.com/drive/1cNjbbiDqeHyjGFyMnuag9RykDKL2XKLp)
+2. [MoE transformer](https://colab.research.google.com/drive/193oYMnTx8FdJDMj_NOgyOng6j9nQc7K_)
+
+In addition, each architecture and layer should be benchmarked for speed using:
+
+1. [Transformer-benchmarks](https://colab.research.google.com/drive/1hb9V6ne42awHTxKvcI0vct1SNEO7rock)
+
 ## Resources:
diff --git a/src/pytfex/transformer/make_model.py b/src/pytfex/transformer/make_model.py
index a5c42cd..6b375f8 100644
--- a/src/pytfex/transformer/make_model.py
+++ b/src/pytfex/transformer/make_model.py
@@ -5,6 +5,7 @@
 from pytfex.transformer.layer import TransformerLayer
 from pytfex.transformer.attention import Attention
 from pytfex.transformer.mlp import MLP
+from pytfex.transformer.moe import MoE
 from pytfex.transformer.gpt import GPT
 from pytfex.transformer.heads import ClassificationHead, InversePatch
 from pytfex.transformer.embedders import TokenEmbedder, PositionEmbedder, \
@@ -16,6 +17,7 @@ class TransformerObjectRegistry:
         'TransformerLayer': TransformerLayer,
         'Attention': Attention,
         'MLP': MLP,
+        'MoE': MoE,
         'GPT': GPT,
         'ClassificationHead': ClassificationHead,
         'InversePatch': InversePatch,
@@ -23,7 +25,8 @@ class TransformerObjectRegistry:
         'PositionEmbedder': PositionEmbedder,
         'MultiEmbedder': MultiEmbedder,
         'PatchEmbedder': PatchEmbedder,
-        'LinearEmbedder': LinearEmbedder
+        'LinearEmbedder': LinearEmbedder,
+
     }

     def register(name):
diff --git a/src/pytfex/transformer/mlp.py b/src/pytfex/transformer/mlp.py
index f651d56..e9b8d00 100644
--- a/src/pytfex/transformer/mlp.py
+++ b/src/pytfex/transformer/mlp.py
@@ -5,10 +5,15 @@ class MLP(torch.nn.Module):
     def __init__(
             self,
             hidden_dim: int,
+            intermediate_dim: int=None,
             dropout: float=0.5,
         ):
         super(MLP, self).__init__()
         self.hidden_dim = hidden_dim
+        if intermediate_dim is None:
+            intermediate_dim = hidden_dim * 4
+        self.intermediate_dim = intermediate_dim
+
         self.dropout = torch.tensor(
             dropout,
             dtype=torch.float32
@@ -19,10 +24,10 @@ def __init__(
         self.linear1 = torch.nn.Linear(
             self.hidden_dim,
-            4 * self.hidden_dim
+            self.intermediate_dim
         )
         self.linear2 = torch.nn.Linear(
-            4 * self.hidden_dim,
+            self.intermediate_dim,
             self.hidden_dim
         )
diff --git a/src/pytfex/transformer/moe.py b/src/pytfex/transformer/moe.py
new file mode 100644
index 0000000..861e064
--- /dev/null
+++ b/src/pytfex/transformer/moe.py
@@ -0,0 +1,67 @@
+from typing import List
+import torch
+
+
+class MoE(torch.nn.Module):
+    def __init__(
+            self,
+            hidden_dim: int,
+            experts: List,
+            c: int = 2,
+        ):
+        """Mixture of Experts - expert-choice routing layer
+
+        See https://arxiv.org/pdf/2202.09368.pdf for more details.
+
+        Args:
+            hidden_dim (int): hidden dimension
+            experts (List): List of experts. Each expert is a torch.nn.Module.
+            c (int, optional): Capacity factor. It denotes on average how many
+                experts are utilized by a token. Defaults to 2.
+        """
+        super(MoE, self).__init__()
+        self.hidden_dim = hidden_dim
+        self.c = c
+        self.experts = torch.nn.ModuleList(experts)
+        self.num_experts = len(self.experts)
+        self.gate = torch.nn.Linear(
+            hidden_dim,
+            self.num_experts,
+            bias=False
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim)
+
+        Returns:
+            torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim)
+        """
+        b, l, *_ = x.shape
+        k = self._compute_k(l)
+        S = torch.softmax(self.gate(x), dim=-1)
+        S = S.transpose(1, 2)  # (batch_size, num_experts, tokens)
+        G, I = torch.topk(S, k, dim=-1)
+        # I - (batch_size, num_experts, top_k_tokens) - indices
+        # G - (batch_size, num_experts, top_k_tokens) - weights
+        new_x = torch.zeros_like(x)
+        for i, expert in enumerate(self.experts):
+            indices = I[:, i]
+            scores = G[:, i]
+            batch_indices = (torch
+                .arange(b)
+                .view(-1, 1)
+                .expand_as(indices)
+            )
+            # (batch_size, top_k_tokens, hidden_dim) - tokens for expert i
+            ex = x[batch_indices, indices]
+            ex_pred = scores[:, :, None] * expert(ex)
+            new_x[batch_indices, indices] += ex_pred
+        return new_x
+
+    def _compute_k(self, l: int) -> int:
+        k = int((l * self.c) / self.num_experts)
+        k = min(max(k, 1), l)
+        return k
\ No newline at end of file
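A minimal usage sketch of the new `MoE` layer, mirroring the `test_MoE_MLP` unit test added below; the `k` arithmetic follows `_compute_k`:

```python
import torch

from pytfex.transformer.mlp import MLP
from pytfex.transformer.moe import MoE

# Four MLP experts behind an expert-choice router.
moe = MoE(
    hidden_dim=12,
    experts=[MLP(hidden_dim=12, dropout=0.5) for _ in range(4)],
    c=2,
)

x = torch.randn((2, 10, 12))  # (batch_size, seq_len, hidden_dim)
out = moe(x)
assert out.shape == x.shape   # routing preserves the input shape

# Each expert selects k = (seq_len * c) / num_experts = (10 * 2) / 4 = 5
# tokens, so on average every token is processed by c = 2 experts.
assert moe._compute_k(10) == 5
```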
diff --git a/src/pytfex/transformer/tests/test_layers.py b/src/pytfex/transformer/tests/test_layers.py
index 90760b4..4ad0642 100644
--- a/src/pytfex/transformer/tests/test_layers.py
+++ b/src/pytfex/transformer/tests/test_layers.py
@@ -1,5 +1,6 @@
 from pytfex.transformer.attention import Attention
 from pytfex.transformer.mlp import MLP
+from pytfex.transformer.moe import MoE

 import torch

@@ -22,4 +23,20 @@ def test_MLP():
     )
     t1 = torch.zeros((1, 10, 12))
     t2 = mlp(t1)
-    assert t2.shape == (1, 10, 12)
\ No newline at end of file
+    assert t2.shape == (1, 10, 12)
+
+
+def test_MoE_MLP():
+    moe = MoE(
+        hidden_dim=12,
+        c=2,
+        experts=[
+            MLP(
+                hidden_dim=12,
+                dropout=0.5
+            ) for _ in range(4)
+        ]
+    )
+    t1 = torch.randn((2, 10, 12))
+    t2 = moe(t1)
+    assert t2.shape == (2, 10, 12)
\ No newline at end of file
diff --git a/src/pytfex/utils.py b/src/pytfex/utils.py
index cbbb2af..0cfa306 100644
--- a/src/pytfex/utils.py
+++ b/src/pytfex/utils.py
@@ -21,4 +21,8 @@ def _parse_string_to_tuple(x):
     """For parsing config strings of the form: '1,2' to a tuple: (1,2)"""
     if not isinstance(x, str):
         return x
-    return tuple(map(int, x.split(',')))
\ No newline at end of file
+    return tuple(map(int, x.split(',')))
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
\ No newline at end of file
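The new `count_parameters` helper counts only trainable parameters. A quick sanity check with a stand-in module:

```python
import torch

from pytfex.utils import count_parameters

layer = torch.nn.Linear(10, 10)     # 10 * 10 weights + 10 biases
assert count_parameters(layer) == 110

layer.weight.requires_grad = False  # frozen weights are excluded
assert count_parameters(layer) == 10
```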
diff --git a/src/tests/conftest.py b/src/tests/conftest.py
index bc390d9..307f530 100644
--- a/src/tests/conftest.py
+++ b/src/tests/conftest.py
@@ -4,6 +4,7 @@
 from pytfex.transformer.make_model import init_from_yml_string
 from pytfex.utils import set_seed
 from tests.basic_model import get_basic_gpt_config
+from tests.moe_model import get_moe_gpt_config

 import torch

@@ -11,25 +12,26 @@

 @pytest.fixture(params=[
-    (6, 3, 32, 'gpt-basic')
+    (256, 6, 3, 32, None, None, 'gpt-basic'),  # (hdn_dim, length, num_digits, batch_size, _, _, model_type)
+    (256, 6, 3, 32, 2, 4, 'gpt-moe')  # (hdn_dim, length, num_digits, batch_size, c, num_experts, model_type)
 ])
 def training_setup(request):
     set_seed(0)
-    length, num_digits, batch_size, model_type = request.param
+    hdn_dim, length, num_digits, batch_size, c, num_experts, model_type = request.param
     ds = SortDataset(split='train', length=length, num_digits=num_digits)
     dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
     blk_size = ds.get_block_size()
     vcb_size = ds.get_vocab_size()
-    hdn_dim = 256
     config = {
-        'gpt-basic': get_basic_gpt_config(vcb_size, hdn_dim, blk_size)
-    }[model_type]
+        'gpt-basic': lambda: get_basic_gpt_config(vcb_size, hdn_dim, blk_size),
+        'gpt-moe': lambda: get_moe_gpt_config(vcb_size, hdn_dim, blk_size, c, num_experts)
+    }[model_type]()
     model = init_from_yml_string(config)

     def val_fn(model):
-        ds = SortDataset(split='test', length=6, num_digits=3)
+        ds = SortDataset(split='test', length=length, num_digits=num_digits)
         dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
         total = 0
         sum_acc = 0
@@ -45,4 +47,4 @@ def val_fn(model):
         acc = sum_acc / total
         return acc

-    return dl, model, val_fn
\ No newline at end of file
+    return dl, model, val_fn, model_type
\ No newline at end of file
diff --git a/src/tests/moe_model.py b/src/tests/moe_model.py
new file mode 100644
index 0000000..f029764
--- /dev/null
+++ b/src/tests/moe_model.py
@@ -0,0 +1,55 @@
+def get_moe_gpt_config(
+        vcb_size,
+        hdn_dim,
+        blk_size,
+        c,
+        num_experts
+    ):
+
+    return f"""
+    type: 'GPT'
+    params:
+        dropout: 0.5
+        hidden_dim: {hdn_dim}
+        num_heads: 4
+        embedder:
+            type: 'MultiEmbedder'
+            params:
+                embedders:
+                    - type: 'TokenEmbedder'
+                      params:
+                        dictionary_size: {vcb_size}
+                        hidden_dim: {hdn_dim}
+                    - type: 'PositionEmbedder'
+                      params:
+                        num_positions: {blk_size}
+                        hidden_dim: {hdn_dim}
+        layers:
+            - num: 2
+              type: 'TransformerLayer'
+              params:
+                hidden_dim: {hdn_dim}
+                attn:
+                    type: 'Attention'
+                    params:
+                        hidden_dim: {hdn_dim}
+                        num_heads: 4
+                        dropout: 0.5
+                mlp:
+                    type: 'MoE'
+                    params:
+                        hidden_dim: {hdn_dim}
+                        c: {c}
+                        experts:
+                            - num: {num_experts}
+                              type: 'MLP'
+                              params:
+                                hidden_dim: {hdn_dim}
+                                intermediate_dim: {hdn_dim}
+                                dropout: 0.5
+        head:
+            type: 'ClassificationHead'
+            params:
+                hidden_dim: {hdn_dim}
+                vocab_size: {vcb_size}
+    """
diff --git a/src/tests/test_model_init.py b/src/tests/test_model_init.py
new file mode 100644
index 0000000..86577ce
--- /dev/null
+++ b/src/tests/test_model_init.py
@@ -0,0 +1,21 @@
+from pytfex.utils import set_seed
+from tests.basic_model import get_basic_gpt_config
+from tests.moe_model import get_moe_gpt_config
+from pytfex.transformer.make_model import init_from_yml_string
+
+import pytest
+
+
+@pytest.mark.parametrize('vcb_size,hdn_dim,blk_size,c,num_experts,model_type', [
+    (32, 12, 11, None, None, 'gpt-basic'),
+    (32, 12, 11, 2, 4, 'gpt-moe')
+])
+def test_model_init(vcb_size, hdn_dim, blk_size, c, num_experts, model_type):
+    set_seed(0)
+    config = {
+        'gpt-basic': get_basic_gpt_config(vcb_size, hdn_dim, blk_size),
+        'gpt-moe': get_moe_gpt_config(vcb_size, hdn_dim, blk_size, c, num_experts)
+    }[model_type]
+    model = init_from_yml_string(config)
+    print(model)
+    assert model
\ No newline at end of file
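As `test_model_init.py` shows, the MoE config is instantiated straight from its YAML string. A minimal sketch of the same flow outside pytest, assuming it is run from `src/` so that the `tests` and `pytfex` packages are importable:

```python
from pytfex.transformer.make_model import init_from_yml_string
from pytfex.utils import count_parameters, set_seed
from tests.moe_model import get_moe_gpt_config

set_seed(0)
config = get_moe_gpt_config(vcb_size=32, hdn_dim=12, blk_size=11, c=2, num_experts=4)
model = init_from_yml_string(config)

# Expert-choice routing activates roughly c experts per token, so extra
# experts add parameters without a matching rise in per-token compute.
print(count_parameters(model))
```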
diff --git a/src/tests/test_train.py b/src/tests/test_model_train.py
similarity index 80%
rename from src/tests/test_train.py
rename to src/tests/test_model_train.py
index f43b923..a1bcce7 100644
--- a/src/tests/test_train.py
+++ b/src/tests/test_model_train.py
@@ -1,4 +1,4 @@
-from pytfex.utils import set_seed
+from pytfex.utils import set_seed, count_parameters

 import pytest
 import torch

@@ -7,13 +7,15 @@
 @pytest.mark.skip(reason="Slow running/intermittent test")
 def test_train(training_setup):
     set_seed(0)
-    dl, model, val_fn = training_setup
+    dl, model, val_fn, model_type = training_setup
     opt = torch.optim.Adam(model.parameters(), lr=1e-4)
     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
     loss_fn = torch.nn.CrossEntropyLoss()

     acc = val_fn(model)
     print('\n')
+    print(f'-- model-type : {model_type} --')
+    print(f'-- # params : {count_parameters(model)} --')
     print('epoch_|_loss_____|_acc______')
     print(f'    -1| None     | {acc:0.5}')
     for epoch in range(5):
@@ -32,6 +34,6 @@ def test_train(training_setup):
         acc = val_fn(model)
         print(f'{epoch:>6}| {loss.item():<8.5} | {acc:0.5}')

-    assert loss.item() < 0.1
+    assert loss.item() < 0.15
     acc = val_fn(model)
     assert acc > 0.95
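One pre-existing quirk this diff leaves untouched: `test_model_train.py` calls `clip_grad_norm_` once during setup, before any gradients exist, where it cannot bound the updates; clipping belongs between `backward()` and `step()`. A self-contained sketch of that placement, using a stand-in linear model rather than the repo's GPT:

```python
import torch

model = torch.nn.Linear(16, 16)  # stand-in for the fixture's GPT model
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

x = torch.randn(8, 16)
y = torch.randint(0, 16, (8,))

opt.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
# Clip between backward() and step() so the applied gradients are bounded.
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()
```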