main_train.py
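"""Trains a Transformer on a Korean-to-English dataset and, if training runs
to completion, uploads the resulting checkpoint to Weights & Biases as a
model artifact.

Example usage (the positional argument is your wandb entity; flags are optional):
    python main_train.py my-wandb-entity --ver overfit_small --fast_dev_run
"""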
import os
import torch
import wandb
import argparse
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from cleanformer.models import Transformer
from cleanformer.fetchers import fetch_tokenizer, fetch_config
from cleanformer.paths import ROOT_DIR
from cleanformer.datamodules import Kor2EngDataModule, Kor2EngSmallDataModule


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("entity", type=str)
    parser.add_argument("--ver", type=str, default="overfit_small")
    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
    parser.add_argument("--log_every_n_steps", type=int, default=1)
    parser.add_argument("--fast_dev_run", action="store_true", default=False)
    parser.add_argument("--overfit_batches", type=int, default=0)
    parser.add_argument("--check_val_every_n_epoch", type=int, default=5)
    args = parser.parse_args()
    config = fetch_config()['train'][args.ver]
    config.update(vars(args))
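    # config now holds both the version-specific hyperparameters from fetch_config()
    # (e.g. hidden_size, ffn_size, max_length, heads, depth, dropout, lr, max_epochs,
    # tokenizer, data) and the command-line arguments merged in via vars(args)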
    # --- fetch a pre-trained tokenizer from wandb --- #
    tokenizer = fetch_tokenizer(config['entity'], config['tokenizer'])
    # --- instantiate the transformer to train --- #
    transformer = Transformer(config['hidden_size'],
                              config['ffn_size'],
                              tokenizer.get_vocab_size(),  # vocab_size
                              config['max_length'],
                              tokenizer.pad_token_id,  # noqa
                              config['heads'],
                              config['depth'],
                              config['dropout'],
                              config['lr'])
    # --- choose the data (either the full version or a smaller version) --- #
    if config['data'] == Kor2EngDataModule.name:
        datamodule = Kor2EngDataModule(config, tokenizer)
    elif config['data'] == Kor2EngSmallDataModule.name:
        datamodule = Kor2EngSmallDataModule(config, tokenizer)
    else:
        raise ValueError(f"Invalid data: {config['data']}")
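    # (each datamodule class is assumed to expose a `name` class attribute
    # that identifies it in the `data` field of the config)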
    # --- start wandb context --- #
    with wandb.init(entity=config['entity'], project="cleanformer", config=config) as run:
        # --- prepare a logger (wandb) and a trainer to use --- #
        logger = WandbLogger(log_model=False)
        lr_monitor = LearningRateMonitor(logging_interval='epoch')
        trainer = Trainer(fast_dev_run=config['fast_dev_run'],
                          check_val_every_n_epoch=config['check_val_every_n_epoch'],
                          overfit_batches=config['overfit_batches'],
                          max_epochs=config['max_epochs'],
                          log_every_n_steps=config['log_every_n_steps'],
                          gpus=torch.cuda.device_count(),
                          callbacks=[lr_monitor],
                          enable_checkpointing=False,
                          logger=logger)
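        # NOTE: `gpus=` is the pytorch-lightning 1.x Trainer argument; in
        # pytorch-lightning >= 2.0 it is replaced by accelerator="gpu" / devices=...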
        # --- start training --- #
        trainer.fit(model=transformer, datamodule=datamodule)
        # --- upload the model to wandb only if training ran to completion --- #
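        # (with the pytorch-lightning 1.x versions this script targets,
        #  current_epoch ends at max_epochs - 1 after an uninterrupted fit,
        #  so this check filters out runs that were cut short)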
        if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
            ckpt_path = ROOT_DIR / "transformer.ckpt"
            trainer.save_checkpoint(str(ckpt_path))
            artifact = wandb.Artifact(name="transformer", type="model", metadata=config)
            artifact.add_file(str(ckpt_path))
            run.log_artifact(artifact, aliases=["latest", config['ver']])
            os.remove(str(ckpt_path))  # clean up the local checkpoint once the upload is done
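        # To reuse the uploaded checkpoint elsewhere, a minimal sketch
        # (use_artifact/download are standard wandb APIs; entity/project as above):
        #   run = wandb.init(entity=..., project="cleanformer")
        #   artifact = run.use_artifact("transformer:latest")
        #   ckpt_dir = artifact.download()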


if __name__ == '__main__':
    main()