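"""Run flow-field reconstruction with a guided diffusion model.

Parses CLI arguments and a YAML config, prepares the log directory, file
logging, and random seeds, then calls Diffusion.reconstruct() from
runners.rs256_guided_diffusion. argparse picks this docstring up as the CLI
description via globals()['__doc__'].
"""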
import argparse
import copy
import logging
import os
import shutil
import sys
import traceback

import numpy as np
import torch
import torch.utils.tensorboard as tb
import yaml

from runners.rs256_guided_diffusion import Diffusion


def parse_args_and_config():
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    parser.add_argument('--repeat_run', type=int, default=1, help='Number of repeated runs')
    parser.add_argument('--sample_step', type=int, default=1, help='Total sampling steps')
    parser.add_argument('--t', type=int, default=400, help='Sampling noise scale')
    parser.add_argument('--r', dest='reverse_steps', type=int, default=20, help='Reverse steps')
    parser.add_argument('--comment', type=str, default='', help='Comment')
    args = parser.parse_args()

    # Parse the config file.
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.safe_load(f)
    config = dict2namespace(config)
    os.makedirs(config.log_dir, exist_ok=True)

    # Name the output directory after the data and the sampling settings.
    if config.model.type == 'conditional':
        dir_name = 'recons_{}_t{}_r{}_w{}'.format(config.data.data_kw,
                                                  args.t, args.reverse_steps,
                                                  config.sampling.guidance_weight)
    else:
        dir_name = 'recons_{}_t{}_r{}_lam{}'.format(config.data.data_kw,
                                                    args.t, args.reverse_steps,
                                                    config.sampling.lambda_)
    if config.model.type == 'conditional':
        print('Using residual gradient guidance during sampling')
        dir_name = 'guided_' + dir_name
    elif config.sampling.lambda_ > 0:
        print('Using residual gradient penalty during sampling')
        dir_name = 'pi_' + dir_name
    else:
        print('Not using physical gradient during sampling')

    log_dir = os.path.join(config.log_dir, dir_name)
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, 'config.yml'), 'w') as outfile:
        yaml.dump(config, outfile)

    # Set up a file logger for this run.
    logger = logging.getLogger("LOG")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, 'logging_info'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Add device.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logger.info("Using device: {}".format(device))
    config.device = device

    # Set the random seed.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True

    return args, config, logger, log_dir


def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace
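
# For illustration: dict2namespace turns the nested dict loaded from YAML into
# attribute access, e.g.
#   dict2namespace({'model': {'type': 'conditional'}}).model.type
# evaluates to 'conditional'.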


def main():
    args, config, logger, log_dir = parse_args_and_config()
    print(">" * 80)
    logger.info("Exp instance id = {}".format(os.getpid()))
    logger.info("Exp comment = {}".format(args.comment))
    logger.info("Config = {}".format(config))
    print("<" * 80)
    try:
        runner = Diffusion(args, config, logger, log_dir)
        runner.reconstruct()
    except Exception:
        logger.error(traceback.format_exc())
    return 0


if __name__ == '__main__':
    sys.exit(main())
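
# Example invocation (the config filename below is hypothetical; substitute one
# of the YAML files under configs/ in this repo):
#   python main.py --config example_config.yml --seed 1234 --t 400 --r 20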