Showing 12 changed files with 320 additions and 29 deletions.
@@ -2,4 +2,7 @@ venv
 __pycache__
 *.egg-info
 /dist
 /build
+
+*.prof
+*.png
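The two new ignore patterns presumably cover the profiler dumps (*.prof) and any plots (*.png) produced by the benchmark runs added in this commit.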
@@ -0,0 +1,2 @@
run_profiles:
	python src/benchmarks/profiles.py
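Running `make run_profiles` from the repository root launches the benchmark script below. Since Python puts a script's own directory on sys.path, the script's bare `from profiling import Profiling` import suggests the profiling module added in this commit sits alongside it, though the file paths are not shown in this diff.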
@@ -0,0 +1,31 @@
from pytfex.utils import set_seed, count_parameters
from pytfex.models import get_model, GPTMoEConfig, GPTBasicConfig
from profiling import Profiling
import torch


# Model configurations to benchmark against each other.
benchmarks = [
    GPTBasicConfig(num_layers=1, hdn_dim=1024),
    GPTMoEConfig(num_layers=1, num_experts=21, c=1, hdn_dim=512),
]

for config in benchmarks:
    print(config)
    set_seed(0)
    model = get_model(config)
    print(f'Number of parameters: {count_parameters(model)}')

    t1 = torch.randint(0, config.vcb_size, (config.batch_size, config.blk_size))

    model.eval()
    output_1 = model(t1)
    # with Profiling(model.layers[0].mlp) as p:
    #     output_2 = model(t1)

    with Profiling(model) as p:
        output_2 = model(t1)

    # The timing hooks must not change the computation.
    assert torch.allclose(output_1, output_2), 'different outputs'

    print(p)
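One caveat about the loop above: both forward passes run with autograd enabled, so the per-module timings include gradient-graph bookkeeping. A minimal variant, not part of the commit, that times pure inference instead:

import torch
from profiling import Profiling
from pytfex.models import get_model, GPTBasicConfig

config = GPTBasicConfig(num_layers=1, hdn_dim=1024)
model = get_model(config)
model.eval()
t1 = torch.randint(0, config.vcb_size, (config.batch_size, config.blk_size))

with torch.no_grad():  # skip autograd bookkeeping while timing
    output_1 = model(t1)
    with Profiling(model) as p:
        output_2 = model(t1)

assert torch.allclose(output_1, output_2), 'different outputs'
print(p)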
@@ -0,0 +1,117 @@
import torch
import time
import string
import random


alphabet = string.ascii_lowercase + string.digits
def uuid(length=4):
    return ''.join(random.choices(alphabet, k=length))


class Profiling(object):
    def __init__(self, model):
        if not isinstance(model, torch.nn.Module):
            raise TypeError("Not a valid model, please provide an 'nn.Module' instance.")

        self.model = model
        self.record = {
            'forward': [],
        }
        self.profiling_on = True
        self.origin_call = {}
        self.hook_done = False

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()

    def __str__(self):
        ret = ""
        ret += "\n================================= Profile =================================\n"
        ret += "\nFORWARD TIME:\n"

        ts = self.record['forward'][0][1]
        te = self.record['forward'][-1][1]
        ret += f"\nTotal time:\t{1000*(te - ts):.6f} ms\n\n"

        ret += '-------------------\n'
        # Deltas between consecutive recorded events.
        for i, ((name1, ts1, event1), (name2, ts2, event2)) in enumerate(zip(
            self.record['forward'],
            self.record['forward'][1:]
        )):
            ret += (
                f"event{i+1:3d}:\t{1000*(ts2 - ts1):10.6f} ms"
                f"\t({event1}:{name1} -> {event2}:{name2})\n"
            )

        ret += '-------------------\n'
        component_time = 0
        for name, ts1, ts2 in self.component_events:
            diff = ts2 - ts1
            ret += f"{1000*diff:0.6f} ms \t ({name}) \n"
            component_time += diff

        ret += '-------------------\n'
        ret += f"{1000*component_time:0.6f} ms \t (total-component-time) \n"
        ret += f"{1000*(te - ts - component_time):0.6f} ms \t (others) \n"

        return ret

    def start(self):
        if not self.hook_done:
            self.hook_done = True
            self.hook_modules(self.model)
        self.profiling_on = True
        return self

    @property
    def component_events(self):
        # Pair each component's 'start' and 'end' timestamps,
        # preserving first-seen order.
        comp_data = {}
        component_names = []
        for component_name, ts, event in self.record['forward']:
            if component_name not in comp_data:
                comp_data[component_name] = {}
                component_names.append(component_name)
            comp_data[component_name][event] = ts

        for component_name in component_names:
            yield (
                component_name,
                comp_data[component_name]['start'],
                comp_data[component_name]['end'],
            )

    def stop(self):
        self.profiling_on = False
        return self

    def hook_modules(self, module):
        # Hook each direct child; ModuleList entries are hooked individually.
        for name, layer in module.named_children():
            if isinstance(layer, torch.nn.ModuleList):
                for ind, sub_module in enumerate(layer):
                    self._hook_module(f'{name}-{ind}', sub_module)
            else:
                self._hook_module(name, layer)

    def _hook_module(self, name, layer):
        # A short random suffix keeps repeated layer names distinct.
        uid = uuid(length=4)
        name = name + '-' + uid
        def make_hook(event):
            def hook(layer, *args, **kwargs):
                t = time.time()
                if self.profiling_on:
                    self.record['forward'].append(
                        (name, t, event)
                    )

            return hook

        layer.register_forward_hook(
            make_hook('end')
        )
        layer.register_forward_pre_hook(
            make_hook('start')
        )
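Profiling only instruments one level of the module tree: each direct child returned by named_children(), with ModuleList entries hooked individually. A self-contained sketch of how it behaves on a plain torch model (the Toy class is illustrative, not from the repository):

import torch
from profiling import Profiling

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = Toy()
with Profiling(model) as p:
    model(torch.randn(4, 16))

# Reports total forward time, event-to-event deltas, and per-child
# durations under names like 'fc1-ab3k' (random suffix from uuid()).
print(p)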
@@ -0,0 +1,34 @@
from pytfex.models.moe import get_moe_gpt_config
from pytfex.models.basic import get_basic_gpt_config
from pytfex.transformer.make_model import init_from_yml_string
from dataclasses import dataclass


@dataclass
class GPTMoEConfig:
    model_type: str = 'gpt-moe'
    vcb_size: int = 65
    hdn_dim: int = 256
    blk_size: int = 256
    c: int = 2
    num_experts: int = 4
    batch_size: int = 32
    num_layers: int = 2


@dataclass
class GPTBasicConfig:
    model_type: str = 'gpt-basic'
    vcb_size: int = 65
    hdn_dim: int = 256
    blk_size: int = 256
    batch_size: int = 32
    num_layers: int = 2


def get_model(config):
    # Dispatch on model_type to the matching YAML template, then build.
    config_str = {
        'gpt-moe': get_moe_gpt_config,
        'gpt-basic': get_basic_gpt_config,
    }[config.model_type](config)
    return init_from_yml_string(config_str)
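As the benchmark script above shows, model variants are produced by overriding dataclass fields; get_model picks the template by model_type and defers construction to init_from_yml_string. For instance:

from pytfex.models import get_model, GPTBasicConfig, GPTMoEConfig

# An unrecognised model_type raises KeyError from the dispatch dict.
dense = get_model(GPTBasicConfig(num_layers=1, hdn_dim=128))
sparse = get_model(GPTMoEConfig(num_experts=8, c=1))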
@@ -0,0 +1,43 @@
def get_basic_gpt_config(config):
    return f"""
    type: 'GPT'
    params:
      dropout: 0.5
      hidden_dim: {config.hdn_dim}
      num_heads: 4
      embedder:
        type: 'MultiEmbedder'
        params:
          embedders:
            - type: 'TokenEmbedder'
              params:
                dictionary_size: {config.vcb_size}
                hidden_dim: {config.hdn_dim}
            - type: 'PositionEmbedder'
              params:
                num_positions: {config.blk_size}
                hidden_dim: {config.hdn_dim}
      head:
        type: 'ClassificationHead'
        params:
          hidden_dim: {config.hdn_dim}
          vocab_size: {config.vcb_size}
      layers:
        - num: {config.num_layers}
          type: 'TransformerLayer'
          params:
            hidden_dim: {config.hdn_dim}
            attn:
              type: 'Attention'
              params:
                hidden_dim: {config.hdn_dim}
                num_heads: 4
                dropout: 0.5
            mlp:
              type: 'MLP'
              params:
                hidden_dim: {config.hdn_dim}
                dropout: 0.5
    """
@@ -0,0 +1,49 @@
def get_moe_gpt_config(config):
    return f"""
    type: 'GPT'
    params:
      dropout: 0.5
      hidden_dim: {config.hdn_dim}
      num_heads: 4
      embedder:
        type: 'MultiEmbedder'
        params:
          embedders:
            - type: 'TokenEmbedder'
              params:
                dictionary_size: {config.vcb_size}
                hidden_dim: {config.hdn_dim}
            - type: 'PositionEmbedder'
              params:
                num_positions: {config.blk_size}
                hidden_dim: {config.hdn_dim}
      layers:
        - num: {config.num_layers}
          type: 'TransformerLayer'
          params:
            hidden_dim: {config.hdn_dim}
            attn:
              type: 'Attention'
              params:
                hidden_dim: {config.hdn_dim}
                num_heads: 4
                dropout: 0.5
            mlp:
              type: 'MoE'
              params:
                hidden_dim: {config.hdn_dim}
                c: {config.c}
                experts:
                  - num: {config.num_experts}
                    type: 'MLP'
                    params:
                      hidden_dim: {config.hdn_dim}
                      intermediate_dim: {4*config.hdn_dim}
                      dropout: 0.5
      head:
        type: 'ClassificationHead'
        params:
          hidden_dim: {config.hdn_dim}
          vocab_size: {config.vcb_size}
    """
(Diffs for the remaining changed files are not shown.)