train.py
import argparse
import glob
import os

from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from trainer import Trainer
import utils.logging


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/train_nansy.yaml")
    parser.add_argument('-g', '--gpus', type=str,
                        help="GPUs to use, e.g. '0,' or '0,1' (forwarded to pytorch_lightning.Trainer)")
    parser.add_argument('-p', '--resume_checkpoint_path', type=str, default=None,
                        help="path of checkpoint for resuming")
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    conf = OmegaConf.load(args.config)

    # Keep each seed's outputs in its own log directory.
    conf.logging.log_dir = os.path.join(conf.logging.log_dir, str(conf.logging.seed))
    os.makedirs(conf.logging.log_dir, exist_ok=True)

    # Snapshot the source files matching the config's glob patterns, for reproducibility.
    save_file_dir = os.path.join(conf.logging.log_dir, 'code')
    os.makedirs(save_file_dir, exist_ok=True)
    savefiles = []
    for reg in conf.logging.save_files:
        savefiles += glob.glob(reg)
    utils.logging.save_files(save_file_dir, savefiles)

    checkpoint_dir = os.path.join(conf.logging.log_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir, **conf.pl.checkpoint.callback)

    tensorboard_dir = os.path.join(conf.logging.log_dir, 'tensorboard')
    os.makedirs(tensorboard_dir, exist_ok=True)
    logger = TensorBoardLogger(tensorboard_dir)
    logger.log_hyperparams(conf)

    # NOTE: gpus, weights_save_path, and resume_from_checkpoint follow the
    # PyTorch Lightning 1.x Trainer API; they were removed in Lightning 2.0.
    trainer = pl.Trainer(
        logger=logger,
        gpus=args.gpus,
        callbacks=[checkpoint_callback],
        weights_save_path=checkpoint_dir,
        resume_from_checkpoint=args.resume_checkpoint_path,
        **conf.pl.trainer
    )

    model = Trainer(conf)  # TODO: get trainer from conf
    trainer.fit(model)


if __name__ == '__main__':
    main()
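
For reference, here is a minimal sketch of the config structure this script expects, inferred only from the attribute accesses in main(); every value shown is a placeholder assumption, not a default from this repository (see configs/train_nansy.yaml for the actual settings).

# Hypothetical illustration only: the keys mirror what main() reads,
# but the values are invented placeholders.
from omegaconf import OmegaConf

conf = OmegaConf.create({
    "logging": {
        "log_dir": "logs/nansy",                   # main() appends the seed to this
        "seed": 1234,
        "save_files": ["*.py", "configs/*.yaml"],  # glob patterns snapshotted to <log_dir>/code
    },
    "pl": {
        "checkpoint": {
            "callback": {"save_top_k": -1},        # extra kwargs for ModelCheckpoint
        },
        "trainer": {"max_epochs": 100},            # extra kwargs for pl.Trainer
    },
})

The same keys can equally live in a YAML file passed via --config, since OmegaConf.load yields the same nested structure.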