forked from mynameischaos/GCC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
selflabel.py
136 lines (115 loc) · 5.02 KB
/
selflabel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Authors: Huasong Zhong
Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
"""
import argparse
import os
import torch
from utils.config import create_config
from utils.common_config import get_train_dataset, get_train_transformations,\
get_val_dataset, get_val_transformations,\
get_train_dataloader, get_val_dataloader,\
get_optimizer, get_model, adjust_learning_rate,\
get_criterion
from utils.ema import EMA
from utils.evaluate_utils import get_predictions, hungarian_evaluate
from utils.train_utils import selflabel_train
from termcolor import colored
# Command-line interface: two optional flags, each naming a config file.
# Parsing happens at import time because main() reads the module-level `args`.
parser = argparse.ArgumentParser(description='Self-labeling')
parser.add_argument('--config_env', help='Config file for the environment')
parser.add_argument('--config_exp', help='Config file for the experiment')
args = parser.parse_args()
def main():
    """Run the self-labeling stage from configuration to final evaluation.

    Steps, in order:

    1. Build the config from the two command-line config files and append
       it to the run log (``log_output_file``).
    2. Construct the model (wrapped in ``DataParallel``), the criterion,
       the optimizer and the train / validation data loaders.
    3. Resume from ``selflabel_checkpoint`` when that file exists,
       otherwise start at epoch 0.
    4. For each remaining epoch: adjust the learning rate, train, evaluate,
       append the clustering statistics to the run log, and write both a
       resumable checkpoint and the unwrapped model weights.
    5. Evaluate once more after the loop and save the model weights again.

    Reads the module-level ``args``. The model and the criterion are moved
    to the GPU with ``.cuda()`` unconditionally, so a CUDA device is needed.
    """
    # Configuration
    cfg = create_config(args.config_env, args.config_exp)
    print(colored(cfg, 'red'))

    def _append_to_log(entry):
        # Append str(entry) as a single line to the run log.
        with open(cfg['log_output_file'], 'a+') as log_handle:
            log_handle.write(str(entry) + "\n")

    _append_to_log(cfg)

    # Model. get_model receives the 'end2end_checkpoint' config entry
    # (presumably a pretrained-weights path -- confirm in utils.common_config).
    print(colored('Retrieve model', 'blue'))
    backbone = get_model(cfg, cfg['end2end_checkpoint'])
    print(backbone)
    model = torch.nn.DataParallel(backbone).cuda()

    # Loss
    print(colored('Get loss', 'blue'))
    criterion = get_criterion(cfg)
    criterion.cuda()
    print(criterion)

    # Let cuDNN benchmark and pick its kernels.
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Optimizer
    print(colored('Retrieve optimizer', 'blue'))
    optimizer = get_optimizer(cfg, model)
    print(optimizer)

    # Data: the training set is handed both the validation transform
    # ('standard') and the training transform ('augment').
    print(colored('Retrieve dataset', 'blue'))
    strong_transforms = get_train_transformations(cfg)
    val_transforms = get_val_transformations(cfg)
    transform_map = {'standard': val_transforms, 'augment': strong_transforms}
    train_dataset = get_train_dataset(cfg, transform_map,
                                      split='train', to_augmented_dataset=True)
    train_dataloader = get_train_dataloader(cfg, train_dataset)
    val_dataset = get_val_dataset(cfg, val_transforms)
    val_dataloader = get_val_dataloader(cfg, val_dataset)
    sample_counts = (len(train_dataset), len(val_dataset))
    print(colored('Train samples %d - Val samples %d' % sample_counts, 'yellow'))

    # Resume from an earlier run if a checkpoint file is present.
    ckpt_path = cfg['selflabel_checkpoint']
    if not os.path.exists(ckpt_path):
        print(colored('No checkpoint file at {}'.format(ckpt_path), 'blue'))
        start_epoch = 0
    else:
        print(colored('Restart from checkpoint {}'.format(ckpt_path), 'blue'))
        saved_state = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(saved_state['model'])
        optimizer.load_state_dict(saved_state['optimizer'])
        start_epoch = saved_state['epoch']

    # Optional EMA helper, handed to the training routine below.
    ema = EMA(model, alpha=cfg['ema_alpha']) if cfg['use_ema'] else None

    # Main loop
    print(colored('Starting main loop', 'blue'))
    best_acc = 0.0
    best_res = None
    num_epochs = cfg['epochs']

    for epoch in range(start_epoch, num_epochs):
        print(colored('Epoch %d/%d' %(epoch+1, num_epochs), 'yellow'))
        print(colored('-'*10, 'yellow'))

        # Learning-rate schedule
        lr = adjust_learning_rate(cfg, optimizer, epoch)
        print('Adjusted learning rate to {:.5f}'.format(lr))

        # One epoch of self-labeling
        print('Train ...')
        selflabel_train(train_dataloader, model, criterion, optimizer, epoch,
                        ema=ema, output_file=cfg['log_output_file'])

        # Evaluation here only monitors progress; it is not model selection.
        print('Evaluate ...')
        predictions = get_predictions(cfg, val_dataloader, model)
        epoch_stats = hungarian_evaluate(0, predictions, compute_confusion_matrix=False)
        print(epoch_stats)
        if epoch_stats["ACC"] > best_acc:
            best_acc, best_res = epoch_stats["ACC"], epoch_stats
        # Every epoch's statistics are logged, not only the improvements.
        _append_to_log(epoch_stats)
        print("best: {}".format(best_res))

        # Persist a resumable state plus the unwrapped model weights.
        print('Checkpoint ...')
        resume_state = {'optimizer': optimizer.state_dict(),
                        'model': model.state_dict(),
                        'epoch': epoch + 1}
        torch.save(resume_state, ckpt_path)
        torch.save(model.module.state_dict(), cfg['selflabel_model'])

    # Final evaluation and save.
    # NOTE(review): confusion_matrix_file is passed although
    # compute_confusion_matrix is False -- confirm whether the plot is wanted.
    print(colored('Evaluate model at the end', 'blue'))
    predictions = get_predictions(cfg, val_dataloader, model, self_labeling=True)
    final_stats = hungarian_evaluate(
        0, predictions,
        class_names=val_dataset.classes,
        compute_confusion_matrix=False,
        confusion_matrix_file=os.path.join(cfg['selflabel_dir'], 'confusion_matrix.png'))
    print(final_stats)
    torch.save(model.module.state_dict(), cfg['selflabel_model'])
# Start the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()