# train.py
import os
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin  # used only by the commented-out DDP plugin below

from helmnet import IterativeSolver, load_settings
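
# Example invocation (the GPU IDs here are illustrative; the settings path is
# the script's actual default):
#   python train.py --gpus 0,1 --precision 32 --parameters experiments/base.json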

if __name__ == "__main__":
    # Parse command line arguments
    parser = ArgumentParser()
    parser.add_argument(
        "--accelerator",
        type=str,
        default="ddp",
        help="Distributed training backend, see https://pytorch.org/tutorials/intermediate/ddp_tutorial.html.",
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default="2,3,4,5,6,7",
        help="IDs of the GPUs to use during training, separated by commas.",
    )
    parser.add_argument(
        "--precision",
        type=int,
        default=32,
        # 16-bit precision was unstable in practice (frequent NaNs).
        help="Bit precision for calculations, either 32 or 16.",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1000,
        help="Total number of training epochs.",
    )
    parser.add_argument(
        "--parameters",
        type=str,
        default="experiments/base.json",
        help="Path to a JSON file with the training settings.",
    )
parser.add_argument("--track_arg_norm", type=bool, default=True)
parser.add_argument("--terminate_on_nan", type=bool, default=True)
parser.add_argument("--check_val_every_n_epoch", type=int, default=2)
parser.add_argument("--limit_val_batches", type=float, default=1.0)
parser.add_argument("--num_sanity_val_steps", type=int, default=1)
parser.add_argument("--benchmark", type=bool, default=True)
    # Parse input arguments
    args = parser.parse_args()

    # Load the settings file
    settings = load_settings(args.parameters)
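    # For reference, the settings JSON mirrors the keys read below (skeleton
    # only; values omitted, see experiments/base.json for the actual numbers):
    # {
    #   "training":       {"train batch size": ..., "learning rate": ..., "loss": ..., ...},
    #   "geometry":       {"grid size": ..., "PML Size": ..., "sigma max": ...},
    #   "source":         {"omega": ..., "location": ..., "amplitude": ..., "phase": ..., "smoothing": ...},
    #   "medium":         {"c0": ..., "train_set": ..., "validation_set": ..., "test_set": ...},
    #   "neural_network": {"activation function": ..., "depth": ..., "channels per layer": ..., "state channels": ..., "states depth": ...},
    #   "environment":    {"max iterations": ...}
    # }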
    # Build the solver model from the settings
    solver = IterativeSolver(
        batch_size=settings["training"]["train batch size"],
        domain_size=settings["geometry"]["grid size"],
        k=settings["source"]["omega"] / settings["medium"]["c0"],  # background wavenumber
        omega=settings["source"]["omega"],
        gradient_clip_val=settings["training"]["gradient clipping"],
        learning_rate=settings["training"]["learning rate"],
        loss=settings["training"]["loss"],
        minimum_learning_rate=settings["training"]["minimum learning rate"],
        optimizer=settings["training"]["optimizer"],
        PMLsize=settings["geometry"]["PML Size"],
        sigma_max=settings["geometry"]["sigma max"],
        source_location=settings["source"]["location"],
        source_amplitude=settings["source"]["amplitude"],
        source_phase=settings["source"]["phase"],
        source_smoothing=settings["source"]["smoothing"],
        train_data_path=settings["medium"]["train_set"],
        validation_data_path=settings["medium"]["validation_set"],
        test_data_path=settings["medium"]["test_set"],
        activation_function=settings["neural_network"]["activation function"],
        depth=settings["neural_network"]["depth"],
        features=settings["neural_network"]["channels per layer"],
        max_iterations=settings["environment"]["max iterations"],
        state_channels=settings["neural_network"]["state channels"],
        state_depth=settings["neural_network"]["states depth"],
        weight_decay=settings["training"]["weight_decay"],
        buffer_size=settings["training"]["buffer size"],
    )
    # Logger and checkpointing callback
    logger = TensorBoardLogger("logs", name="helmnet")
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(os.getcwd(), "checkpoints"),
        save_top_k=3,
        verbose=True,
        monitor="val_loss",
        mode="min",
        save_last=True,
    )
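    # The callback keeps the three best checkpoints by validation loss, plus
    # "last.ckpt", under ./checkpoints/ in the current working directory.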

    # Alternatively, every pl.Trainer flag could be exposed on the CLI with:
    # parser = pl.Trainer.add_argparse_args(parser)

    # Build the trainer from the parsed arguments
    gpu_list = [int(i) for i in args.gpus.split(",")]
    trainer = pl.Trainer.from_argparse_args(
        args,
        gpus=gpu_list,
        logger=logger,
        callbacks=[checkpoint_callback],
        # plugins=DDPPlugin(find_unused_parameters=False),  # raises errors when not running with DDP
    )

    # Train the network
    trainer.fit(solver)
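
    # Metrics land under ./logs/helmnet and can be inspected with:
    #   tensorboard --logdir logs
    # A trained model can later be restored via the standard LightningModule
    # API (assuming IterativeSolver saves its hyperparameters):
    #   model = IterativeSolver.load_from_checkpoint("checkpoints/last.ckpt")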