-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
124 lines (97 loc) · 4.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import time
import torch
import numpy as np
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from utils import *
from options.trainopt import _get_train_opt
import nyudv2_dataloader
from loss import cal_spatial_loss, cal_temporal_loss
from resnet import resnet18
import modules
import net
cudnn.benchmark = True
args = _get_train_opt
os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
makedir(args.checkpoint_dir)
makedir(args.logdir)
logger = SummaryWriter(args.logdir)
TrainImgLoader = nyudv2_dataloader.getTrainingData_NYUDV2(args.batch_size, args.trainlist_path, args.root_path)
device = 'cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu'
Encoder = modules.E_resnet(resnet18)
if args.backbone in ['resnet50']:
model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048], refinenet=args.refinenet)
elif args.backbone in ['resnet18', 'resnet34']:
model = net.model(Encoder, num_features=512, block_channel=[64, 128, 256, 512], refinenet=args.refinenet)
model = nn.DataParallel(model).cuda()
disc = net.C_C3D_1().cuda()
optimizer = build_optimizer(model=model,
learning_rate=args.lr,
optimizer_name=args.optimizer_name,
weight_decay=args.weight_decay,
epsilon=args.epsilon,
momentum=args.momentum
)
start_epoch = 0
if args.resume:
all_saved_ckpts = [ckpt for ckpt in os.listdir(args.checkpoint_dir) if ckpt.endswith(".pth.tar")]
print(all_saved_ckpts)
all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x: int(x.split('_')[-1].split('.')[0]))
loadckpt = os.path.join(args.checkpoint_dir, all_saved_ckpts[-1])
start_epoch = int(all_saved_ckpts[-1].split('_')[-1].split('.')[0])
print("loading the lastest model in checkpoint_dir: {}".format(loadckpt))
state_dict = torch.load(loadckpt)
model.load_state_dict(state_dict)
elif args.loadckpt is not None:
print("loading model {}".format(args.loadckpt))
start_epoch = args.loadckpt.split('_')[-1].split('.')[0]
state_dict = torch.load(args.loadckpt)
model.load_state_dict(state_dict)
else:
print("start at epoch {}".format(start_epoch))
def train():
for epoch in range(start_epoch, args.epochs):
adjust_learning_rate(optimizer, epoch, args.lr)
batch_time = AverageMeter()
losses = AverageMeter()
model.train()
end = time.time()
for batch_idx, sample in enumerate(TrainImgLoader):
image, depth = sample[0], sample[1] # (b,c,d,w,h)
depth = depth.cuda()
image = image.cuda()
image = torch.autograd.Variable(image)
depth = torch.autograd.Variable(depth)
optimizer.zero_grad()
global_step = len(TrainImgLoader) * epoch + batch_idx
gt_depth = depth
pred_depth = model(image) # (b, c, d, h, w)
spatial_losses = []
for seq_idx in range(image.size(2)):
spatial_loss = cal_spatial_loss(pred_depth[:, :, seq_idx, :, :], gt_depth[:, :, seq_idx, :, :])
spatial_losses.append(spatial_loss)
spatial_loss = sum(spatial_losses)
pred_cls = disc(pred_depth)
gt_cls = disc(gt_depth)
temporal_loss = cal_temporal_loss(pred_cls, gt_cls)
loss = spatial_loss + 0.1 * temporal_loss
losses.update(loss.item(), image.size(0))
loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
batchSize = depth.size(0)
print(('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})'
.format(epoch, batch_idx, len(TrainImgLoader), batch_time=batch_time, loss=losses)))
if (epoch + 1) % 1 == 0:
save_checkpoint(model.state_dict(),
filename=args.checkpoint_dir + "ResNet18_checkpoints_small_" + str(epoch + 1) + ".pth.tar")
if __name__ == '__main__':
train()