train_naive.py
from copy import deepcopy

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from models.naive import argparser
from models.naive.config import GPUS, ACCELERATOR
from models.naive.data_module import DataModule
from models.naive.model import Model

args = argparser.get_args()
if __name__ == "__main__":
    trainer = pl.Trainer(
        gpus=GPUS,
        accelerator=ACCELERATOR,
        fast_dev_run=args.dev,
        precision=32,
        default_root_dir='.log_naive',
        max_epochs=args.epoch,
        callbacks=[
            # EarlyStopping(monitor='dev_loss', patience=3),
            ModelCheckpoint(monitor='dev_loss', filename='{epoch}-{dev_loss:.2f}', save_last=True),
        ],
    )
    dm = DataModule()

    if args.from_checkpoint is None:
        model = Model()
    else:
        print('loading from checkpoint:', args.from_checkpoint)
        model = Model.load_from_checkpoint(args.from_checkpoint)
    # train: auto-scale the batch size, then fit
    if not args.run_test:
        # run the batch-size finder on a throwaway trainer copy so tuning
        # does not mutate the real trainer's state
        tuner = pl.tuner.tuning.Tuner(deepcopy(trainer))
        new_batch_size = tuner.scale_batch_size(model, datamodule=dm)
        del tuner
        model.hparams.batch_size = new_batch_size
        trainer.fit(model, datamodule=dm)
    # test-only runs evaluate the model loaded above; after training, evaluate
    # the best checkpoint ('best' requires a completed fit(), so skip it under
    # --run_test and --dev)
    trainer.test(
        model if args.run_test else None,
        datamodule=dm,
        ckpt_path=None if (args.dev or args.run_test) else 'best',
    )
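
For reference, train_naive.py reads four attributes off the parsed arguments: args.dev, args.epoch, args.from_checkpoint, and args.run_test. Below is a minimal sketch of what models/naive/argparser.py presumably provides; the flag names, types, and defaults are assumptions inferred from how the script uses them, not the repository's actual definitions.

# models/naive/argparser.py -- hypothetical sketch, inferred from usage in train_naive.py
import argparse

def get_args():
    parser = argparse.ArgumentParser(description='Train/evaluate the naive model.')
    # assumed flag: enable fast_dev_run to smoke-test the pipeline on a single batch
    parser.add_argument('--dev', action='store_true')
    # assumed flag: maximum number of epochs handed to pl.Trainer(max_epochs=...)
    parser.add_argument('--epoch', type=int, default=10)
    # assumed flag: path to a .ckpt file to load before training or testing
    parser.add_argument('--from_checkpoint', type=str, default=None)
    # assumed flag: skip training and only run trainer.test()
    parser.add_argument('--run_test', action='store_true')
    return parser.parse_args()

Under those assumptions, a fresh training run would look like "python train_naive.py --epoch 20", and a test-only run against a saved checkpoint like "python train_naive.py --run_test --from_checkpoint path/to/checkpoint.ckpt".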