-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathmain.py
56 lines (46 loc) · 1.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
from pytorch_lightning import Trainer
import logging
import sys
sys.path.append("./models")
sys.path.append("./utils")
from interface import *
from utils import *
from vae import VAE
from beta_vae import betaVAE
def main(args):
    """Driver function: parse the configuration, build the requested VAE
    variant, train and test it, then produce sample/reconstruction outputs.

    :param args: parsed command-line namespace; must provide ``config``
        (path to a configuration file) and ``variation`` (one of
        "VAE" or "B-VAE").
    :raises SystemExit: with a non-zero code on an invalid config path,
        a dataset that fails to prepare, or an unknown variation.
    """
    # Parameters parsing
    if filepath_is_not_valid(args.config):
        logging.error("The path {} is not a file. Aborting..".format(args.config))
        # sys.exit(1) instead of bare exit(): signals failure to the shell
        sys.exit(1)

    configuration, architecture, hyperparameters = parse_config_file(args.config, args.variation)

    dataset_info = prepare_dataset(configuration)
    if dataset_info is None:
        sys.exit(1)

    # Initialization: dispatch on the requested model variation
    if args.variation == "VAE":
        model = VAE(architecture, hyperparameters, dataset_info)
    elif args.variation == "B-VAE":
        model = betaVAE(architecture, hyperparameters, dataset_info)
    else:
        # Previously an unknown variation left model = None and crashed
        # later inside trainer.fit() with a confusing error; fail fast.
        logging.error("Unknown variation {!r}. Aborting..".format(args.variation))
        sys.exit(1)

    # here you can change the gpus parameter into the amount of gpus you want the model to use
    trainer = Trainer(max_epochs=hyperparameters["epochs"], gpus=None, fast_dev_run=False)

    # Training and testing
    trainer.fit(model)
    trainer.test(model)

    # Model needs to be transferred to the cpu as sample and reconstruct are custom methods
    model = model.cpu()
    model.sample(5)
    model.reconstruct(5)
if __name__ == "__main__":
""" call main() function here """
print()
# configure the level of the logging and the format of the messages
logging.basicConfig(level=logging.ERROR, format="%(levelname)s: %(message)s\n")
# parse the command line input
args = parse_cmd_args()
# call the main() driver function
main(args)
print("\n")