-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
124 lines (105 loc) · 5.29 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import sys
sys.path.insert(0, '.')
import os
import numpy as np
import argparse
import torch
import torch.backends.cudnn
import scipy.io as scio
from libs.builder import get_model_dataset_by_name
from libs.train import train
from libs.test import test
from libs.utils.set_seed import set_seed
def parser_args(argv=None):
    """Parse command-line options and derive the output paths for one run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour
            of parsing ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options plus derived fields:
        ``save_dir`` (``<save_dir>/<model_name>/<data_name>/``, always ending
        in a path separator), ``pre_dir`` and ``post_dir`` (sub-directories of
        ``save_dir``), ``log_file_loc``, ``model_file_name``
        (``<save_dir>/best_model.pth``), and the multi-scale flags
        ``is_multi_scale_training`` / ``is_multi_scale_testing`` (both fixed
        to 0 here).
    """
    parser = argparse.ArgumentParser(description='An open source change detection toolbox based on PyTorch')
    parser.add_argument('--data_name', default="SECOND",
                        choices=['LEVIR', 'LEVIR+', 'SYSU', 'S2Looking',
                                 'SECOND', 'LandsatSCD',
                                 'xview2'],
                        help='Dataset name')
    parser.add_argument('--model_name', default="A2NetMvit",
                        choices=['TFIGR', 'A2NetBCD', 'ARCDNetBCD', 'ChangeStar',
                                 'A2Net', 'A2NetMvit', 'A2Net34', 'SSCDL', 'TED', 'BiSRNet', 'SCanNet',
                                 'ChangeOS', 'ChangeOS-GRM', 'ARCDNet'],
                        help='Name of method')
    parser.add_argument('--dataloader_name', default="bs_8",
                        choices=['bs_8', 'bs_16', 'bs_32'],
                        help='Batch size')
    parser.add_argument('--is_train', type=int, default=1,
                        choices=[0, 1],
                        help='Is train model')
    parser.add_argument('--save_dir', default='./weights/',
                        help='Directory to save the results')
    parser.add_argument('--log_file', default='trainLog.txt',
                        help='File that stores the training and validation logs')
    cmd_cfg = parser.parse_args(argv)

    # Join path components instead of concatenating strings: a --save_dir
    # given without a trailing slash must still produce
    # '<save_dir>/<model>/<data>/', and sub-directories must not get a doubled
    # separator. The trailing '' keeps the trailing slash on directory paths,
    # since other code may build file paths by plain string concatenation
    # (NOTE(review): not verifiable from this file).
    cmd_cfg.save_dir = os.path.join(cmd_cfg.save_dir, cmd_cfg.model_name, cmd_cfg.data_name, '')
    cmd_cfg.pre_dir = os.path.join(cmd_cfg.save_dir, 'pre', '')
    cmd_cfg.post_dir = os.path.join(cmd_cfg.save_dir, 'post', '')
    cmd_cfg.log_file_loc = os.path.join(cmd_cfg.save_dir, cmd_cfg.log_file)
    cmd_cfg.model_file_name = os.path.join(cmd_cfg.save_dir, 'best_model.pth')

    # Multi-scale training/testing is not exposed on the command line.
    cmd_cfg.is_multi_scale_training = 0
    cmd_cfg.is_multi_scale_testing = 0
    print('Called with cmd_cfg:')
    print(cmd_cfg)
    return cmd_cfg
def _count_parameters(model):
    """Print and return the model's parameter counts, in millions.

    Returns:
        (total, trainable): the number of elements across all
        ``model.parameters()`` and across those with ``requires_grad`` set,
        each divided by 1e6.
    """
    total_params = sum(p.numel() for p in model.parameters()) / 1e6
    print('Total network parameters (excluding idr): ' + str(total_params))
    total_params_to_update = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    print('Total parameters to update: ' + str(total_params_to_update))
    return total_params, total_params_to_update


def _train_epochs(cmd_cfg, logger, model, train_loader, val_loader, optimizer, scaler,
                  optimizer_cfg, task_type, task_cfg):
    """Run the epoch loop: train, checkpoint each epoch, and periodically validate.

    Validation runs at most once per epoch, at the end of the first epoch
    whose cumulative iteration count reaches the next multiple of
    ``optimizer_cfg['eva_per_iter']``.
    """
    logger.write_header()
    max_batches = len(train_loader)
    print('For each epoch, we have {} batches'.format(max_batches))
    max_epochs = optimizer_cfg['max_epoch']
    eva_per_iters = optimizer_cfg['eva_per_iter']
    cur_iter = 0        # iterations completed so far (one per batch)
    cur_val_count = 0   # validations performed so far
    for epoch in range(max_epochs):
        # cur_iter is passed so train() knows the global iteration offset
        # (presumably for LR scheduling — NOTE(review): confirm in libs.train).
        loss_tr, score_tr, lr = train(cmd_cfg, task_type, task_cfg, optimizer_cfg, train_loader,
                                      model, scaler, optimizer, max_batches, cur_iter)
        cur_iter += max_batches
        torch.cuda.empty_cache()
        logger.save_checkpoint(epoch, model, optimizer, loss_tr, score_tr, lr)
        if cur_iter >= eva_per_iters * (cur_val_count + 1):
            score_val = test(cmd_cfg, task_type, task_cfg, val_loader, model)
            logger.write_val(epoch, loss_tr, score_tr, score_val)
            logger.save_model(epoch, model, score_val)
            torch.cuda.empty_cache()
            cur_val_count += 1


def _evaluate_best_model(cmd_cfg, logger, model, task_type, task_cfg, test_loader):
    """Load ``cmd_cfg.model_file_name`` into *model*, score it on the test loader,
    log the scores and save them to ``<save_dir>/results.mat``.

    Returns:
        Whatever ``test()`` returned for the test loader.
    """
    state_dict = torch.load(cmd_cfg.model_file_name)
    model.load_state_dict(state_dict)
    score_test = test(cmd_cfg, task_type, task_cfg, test_loader, model)
    logger.write_test(score_test)
    scio.savemat(os.path.join(cmd_cfg.save_dir, 'results.mat'), score_test)
    return score_test


def main():
    """Script entry point.

    Seeds RNGs, parses the command line and creates the output directories.
    With ``--is_train 1`` it trains (checkpointing and validating along the
    way) and then tests the checkpoint at ``model_file_name``. With
    ``--is_train 0`` it only tests that checkpoint.
    """
    set_seed()
    cmd_cfg = parser_args()
    for directory in (cmd_cfg.save_dir, cmd_cfg.pre_dir, cmd_cfg.post_dir):
        os.makedirs(directory, exist_ok=True)

    is_train = cmd_cfg.is_train > 0
    # get_model_dataset_by_name returns a different tuple depending on the mode.
    if is_train:
        (logger, model, train_loader, val_loader, test_loader, optimizer, scaler, optimizer_cfg,
         task_type, task_cfg) = get_model_dataset_by_name(cmd_cfg)
    else:
        logger, model, test_loader, task_type, task_cfg = get_model_dataset_by_name(cmd_cfg)

    try:
        model = model.cuda()
        total_params, total_params_to_update = _count_parameters(model)
        if is_train:
            logger.write_parameters(total_params, total_params_to_update)
            _train_epochs(cmd_cfg, logger, model, train_loader, val_loader, optimizer, scaler,
                          optimizer_cfg, task_type, task_cfg)
        _evaluate_best_model(cmd_cfg, logger, model, task_type, task_cfg, test_loader)
    finally:
        # Always release the logger, even if training or testing raises.
        logger.close_logger()
# Run the full pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()