Add benchmarking scripts #7

Merged · 1 commit · Jan 6, 2024
5 changes: 4 additions & 1 deletion .gitignore
@@ -2,4 +2,7 @@ venv
 __pycache__
 *.egg-info
 /dist
-/build
+/build
+
+*.prof
+*.png
2 changes: 2 additions & 0 deletions Makefile
@@ -0,0 +1,2 @@
run_profiles:
	python src/benchmarks/profiles.py
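
Assuming the repository root as the working directory (and the package importable from there), running make run_profiles executes the new benchmark script below.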
31 changes: 31 additions & 0 deletions src/benchmarks/profiles.py
@@ -0,0 +1,31 @@
from pytfex.utils import set_seed, count_parameters
from pytfex.models import get_model, GPTMoEConfig, GPTBasicConfig
from profiling import Profiling
import torch


benchmarks = [
    GPTBasicConfig(num_layers=1, hdn_dim=1024),
    GPTMoEConfig(num_layers=1, num_experts=21, c=1, hdn_dim=512),
]

for config in benchmarks:
    print(config)
    set_seed(0)
    model = get_model(config)
    print(f'Number of parameters: {count_parameters(model)}')

    t1 = torch.randint(0, config.vcb_size, (config.batch_size, config.blk_size))

    model.eval()
    output_1 = model(t1)
    # with Profiling(model.layers[0].mlp) as p:
    #     output_2 = model(t1)

    with Profiling(model) as p:
        output_2 = model(t1)

    # the timing hooks must not change the forward result
    assert torch.allclose(output_1, output_2), 'different outputs'

    print(p)
117 changes: 117 additions & 0 deletions src/benchmarks/profiling.py
@@ -0,0 +1,117 @@
import torch
import time
import string
import random

alphabet = string.ascii_lowercase + string.digits

def uuid(length=4):
    return ''.join(random.choices(alphabet, k=length))


class Profiling(object):
    def __init__(self, model):
        if not isinstance(model, torch.nn.Module):
            raise TypeError("Not a valid model, please provide a 'nn.Module' instance.")

        self.model = model
        self.record = {
            'forward': [],
        }
        self.profiling_on = True
        self.origin_call = {}
        self.hook_done = False

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()

    def __str__(self):
        ret = ""
        ret += "\n================================= Profile =================================\n"
        ret += "\nFORWARD TIME:\n"

        ts = self.record['forward'][0][1]
        te = self.record['forward'][-1][1]
        ret += f"\nTotal time:\t{1000*(te - ts):.6f} ms\n\n"

        ret += ('-------------------\n')
        for i, ((name1, ts1, event1), (name2, ts2, event2)) in enumerate(zip(
                    self.record['forward'],
                    self.record['forward'][1:]
                )):
            ret += (
                f"event{i+1:3d}:\t{1000*(ts2 - ts1):10.6f} ms"
                f"\t({event1}:{name1} -> {event2}:{name2})\n"
            )

        ret += ('-------------------\n')
        component_time = 0
        for name, ts1, ts2 in self.component_events:
            diff = ts2 - ts1
            ret += (f"{1000*(diff):0.6f} ms \t ({name}) \n")
            component_time += diff

        ret += ('-------------------\n')
        ret += (f"{1000*(component_time):0.6f} ms \t (total-component-time) \n")
        ret += (f"{1000*(te - ts - component_time):0.6f} ms \t (others) \n")

        return ret

    def start(self):
        # register the hooks once; re-entering only switches profiling back on
        if self.hook_done is False:
            self.hook_done = True
            self.hook_modules(self.model, self.model.__class__.__name__)
        self.profiling_on = True
        return self

    @property
    def component_events(self):
        # pair up each component's 'start' and 'end' timestamps
        comp_data = {}
        component_names = []
        for component_name, ts, event in self.record['forward']:
            if component_name not in comp_data:
                comp_data[component_name] = {}
                component_names.append(component_name)
            comp_data[component_name][event] = ts

        for component_name in component_names:
            yield (
                component_name,
                comp_data[component_name]['start'],
                comp_data[component_name]['end'],
            )

    def stop(self):
        self.profiling_on = False
        return self

    def hook_modules(self, module, name):
        for name, layer in module.named_children():
            if isinstance(layer, torch.nn.ModuleList):
                # hook each entry of the list individually
                for ind, sub_sub_module in enumerate(layer):
                    self._hook_module(f'{name}-{ind}', sub_sub_module)
            else:
                self._hook_module(name, layer)

    def _hook_module(self, name, layer):
        # suffix with a short uid so repeated child names stay distinct
        uid = uuid(length=4)
        name = name + '-' + uid
        def make_hook(event):
            def hook(layer, *args, **kwargs):
                t = time.time()
                if self.profiling_on:
                    self.record['forward'].append(
                        (name, t, event)
                    )
            return hook

        layer.register_forward_hook(
            make_hook('end')
        )
        layer.register_forward_pre_hook(
            make_hook('start')
        )
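
As a quick reference, here is a minimal, self-contained sketch of how this context manager behaves on a toy module (TinyNet and its shapes are illustrative, not part of the PR):

import torch
from profiling import Profiling  # assumes src/benchmarks is on the path

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])
        self.head = torch.nn.Linear(8, 2)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.head(x)

model = TinyNet()
with Profiling(model) as p:
    model(torch.randn(4, 8))
print(p)  # per-child start/end timings collected by the forward hooks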

34 changes: 34 additions & 0 deletions src/pytfex/models/__init__.py
@@ -0,0 +1,34 @@
from pytfex.models.moe import get_moe_gpt_config
from pytfex.models.basic import get_basic_gpt_config
from pytfex.transformer.make_model import init_from_yml_string
from dataclasses import dataclass


@dataclass
class GPTMoEConfig:
    model_type: str = 'gpt-moe'
    vcb_size: int = 65
    hdn_dim: int = 256
    blk_size: int = 256
    c: int = 2
    num_experts: int = 4
    batch_size: int = 32
    num_layers: int = 2


@dataclass
class GPTBasicConfig:
    model_type: str = 'gpt-basic'
    vcb_size: int = 65
    hdn_dim: int = 256
    blk_size: int = 256
    batch_size: int = 32
    num_layers: int = 2


def get_model(config):
    config_str = {
        'gpt-moe': get_moe_gpt_config,
        'gpt-basic': get_basic_gpt_config,
    }[config.model_type](config)
    return init_from_yml_string(config_str)
43 changes: 43 additions & 0 deletions src/pytfex/models/basic.py
@@ -0,0 +1,43 @@
def get_basic_gpt_config(config):

    return f"""
    type: 'GPT'
    params:
        dropout: 0.5
        hidden_dim: {config.hdn_dim}
        num_heads: 4
        embedder:
            type: 'MultiEmbedder'
            params:
                embedders:
                    - type: 'TokenEmbedder'
                      params:
                        dictionary_size: {config.vcb_size}
                        hidden_dim: {config.hdn_dim}
                    - type: 'PositionEmbedder'
                      params:
                        num_positions: {config.blk_size}
                        hidden_dim: {config.hdn_dim}
        head:
            type: 'ClassificationHead'
            params:
                hidden_dim: {config.hdn_dim}
                vocab_size: {config.vcb_size}
        layers:
            - num: {config.num_layers}
              type: 'TransformerLayer'
              params:
                hidden_dim: {config.hdn_dim}
                attn:
                    type: 'Attention'
                    params:
                        hidden_dim: {config.hdn_dim}
                        num_heads: 4
                        dropout: 0.5
                mlp:
                    type: 'MLP'
                    params:
                        hidden_dim: {config.hdn_dim}
                        dropout: 0.5
    """
49 changes: 49 additions & 0 deletions src/pytfex/models/moe.py
@@ -0,0 +1,49 @@
def get_moe_gpt_config(config):
    return f"""
    type: 'GPT'
    params:
        dropout: 0.5
        hidden_dim: {config.hdn_dim}
        num_heads: 4
        embedder:
            type: 'MultiEmbedder'
            params:
                embedders:
                    - type: 'TokenEmbedder'
                      params:
                        dictionary_size: {config.vcb_size}
                        hidden_dim: {config.hdn_dim}
                    - type: 'PositionEmbedder'
                      params:
                        num_positions: {config.blk_size}
                        hidden_dim: {config.hdn_dim}
        layers:
            - num: {config.num_layers}
              type: 'TransformerLayer'
              params:
                hidden_dim: {config.hdn_dim}
                attn:
                    type: 'Attention'
                    params:
                        hidden_dim: {config.hdn_dim}
                        num_heads: 4
                        dropout: 0.5
                mlp:
                    type: 'MoE'
                    params:
                        hidden_dim: {config.hdn_dim}
                        c: {config.c}
                        experts:
                            - num: {config.num_experts}
                              type: 'MLP'
                              params:
                                hidden_dim: {config.hdn_dim}
                                intermediate_dim: {4*config.hdn_dim}
                                dropout: 0.5
        head:
            type: 'ClassificationHead'
            params:
                hidden_dim: {config.hdn_dim}
                vocab_size: {config.vcb_size}
    """
1 change: 1 addition & 0 deletions src/pytfex/transformer/attention.py
@@ -8,6 +8,7 @@ def __init__(
         dropout: float=0.5,
     ) -> None:
         super(Attention, self).__init__()
+        assert hidden_dim % num_heads == 0, f"num_heads must divide hidden_dim, {hidden_dim=}, {num_heads=}"
         self.hidden_dim = hidden_dim
         self.num_heads = num_heads
         self.dropout = torch.tensor(
1 change: 0 additions & 1 deletion src/pytfex/transformer/gpt.py
@@ -14,7 +14,6 @@ def __init__(
         head: torch.nn.Module=None,
     ):
         super(GPT, self).__init__()
-        assert hidden_dim % num_heads == 0, "num_heads must divide hidden_dim"
         self.hidden_dim = hidden_dim
         self.num_heads = num_heads
         self.dropout = dropout
2 changes: 1 addition & 1 deletion src/pytfex/transformer/moe.py
@@ -41,7 +41,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         b, l, *_ = x.shape
         k = self._compute_k(l)
-        S = torch.softmax(self.gate(x), dim=-1)
+        S = torch.sigmoid(self.gate(x))
         S = S.transpose(1, 2)  # (batch_size, num_experts, tokens)
         G, I = torch.topk(S, k, dim=-1)
         # I - (batch_size, num_experts, top_k_tokens) - indices
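
For context, this hunk swaps the expert gate from a softmax over experts to independent per-expert sigmoid scores. A minimal sketch of the surrounding shapes (the tensor sizes are illustrative and follow the comments in the diff):

import torch

b, l, num_experts, k = 2, 6, 4, 3         # batch, tokens, experts, top-k tokens per expert
logits = torch.randn(b, l, num_experts)   # stand-in for self.gate(x)
S = torch.sigmoid(logits)                 # scores are no longer normalized across experts
S = S.transpose(1, 2)                     # (batch_size, num_experts, tokens)
G, I = torch.topk(S, k, dim=-1)           # G: gate values, I: token indices per expert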
50 changes: 30 additions & 20 deletions src/tests/conftest.py
@@ -1,38 +1,48 @@
 from torch.utils.data.dataloader import DataLoader
 from tests.dataset import SortDataset
 
-from pytfex.transformer.make_model import init_from_yml_string
 from pytfex.utils import set_seed
-from tests.basic_model import get_basic_gpt_config
-from tests.moe_model import get_moe_gpt_config
-
-import torch
+from pytfex.models import (
+    get_model,
+    GPTMoEConfig,
+    GPTBasicConfig,
+)
 
+import torch
 import pytest
 
 
 @pytest.fixture(params=[
-    (256, 6, 3, 32, None, None, 'gpt-basic'), # (hdn_dim, length, num_digits, batch_size, _, _, model_type)
-    (256, 6, 3, 32, 2, 4, 'gpt-moe') # (hdn_dim, length, num_digits, batch_size, k, num_experts, model_type)
+    # (config, length)
+    (GPTBasicConfig(
+        model_type='gpt-basic',
+        vcb_size=3,
+        hdn_dim=256,
+        blk_size=12,
+        batch_size=32,
+    ), 6),
+    (GPTMoEConfig(
+        model_type='gpt-moe',
+        vcb_size=3,
+        hdn_dim=256,
+        blk_size=12,
+        c=2,
+        num_experts=4,
+        batch_size=32,
+    ), 6)
 ])
 def training_setup(request):
     set_seed(0)
 
-    hdn_dim, length, num_digits, batch_size, c, num_experts, model_type = request.param
+    config, length = request.param
+    num_digits = config.vcb_size
     ds = SortDataset(split='train', length=length, num_digits=num_digits)
-    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
-    blk_size = ds.get_block_size()
-    vcb_size = ds.get_vocab_size()
-
-    config = {
-        'gpt-basic': lambda: get_basic_gpt_config(vcb_size, hdn_dim, blk_size),
-        'gpt-moe': lambda: get_moe_gpt_config(vcb_size, hdn_dim, blk_size, c, num_experts)
-    }[model_type]()
-    model = init_from_yml_string(config)
+    dl = DataLoader(ds, batch_size=config.batch_size, shuffle=True, num_workers=0)
+    model = get_model(config)
 
     def val_fn(model):
         ds = SortDataset(split='test', length=length, num_digits=num_digits)
-        dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
+        dl = DataLoader(ds, batch_size=config.batch_size, shuffle=False, num_workers=0)
         total = 0
         sum_acc = 0
         for x, y_true in dl:
@@ -47,4 +57,4 @@ def val_fn(model):
         acc = sum_acc / total
         return acc
 
-    return dl, model, val_fn, model_type
+    return dl, model, val_fn, config.model_type
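
For orientation, a minimal sketch of how a test might consume this fixture (the test name, training step, and assertion are illustrative, not part of the PR):

def test_sort_task(training_setup):
    dl, model, val_fn, model_type = training_setup
    # a real test would run a few training steps over dl here
    acc = val_fn(model)
    assert 0.0 <= acc <= 1.0, f'{model_type}: accuracy out of range'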