-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
292 lines (241 loc) · 11.6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import torch
from torch import nn
from opt import get_opts
import os
import glob
import imageio
import numpy as np
import cv2
from einops import rearrange
# data
from torch.utils.data import DataLoader
from datasets import dataset_dict
from datasets.ray_utils import axisangle_to_R, get_rays
# models
from kornia.utils.grid import create_meshgrid3d
from models.networks import NGP
from models.rendering import render, MAX_SAMPLES
# optimizer, losses
from apex.optimizers import FusedAdam
from torch.optim.lr_scheduler import CosineAnnealingLR
from losses import NeRFLoss
# metrics
from torchmetrics import (
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
# pytorch-lightning
# from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.utilities. import all_gather_ddp_if_available
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from utils import slim_ckpt, load_ckpt
import warnings; warnings.filterwarnings("ignore")
def depth2img(depth):
    """Convert a depth map to a colour-mapped 8-bit image for visualization.

    The depth values are min-max normalized to [0, 1] and passed through
    OpenCV's TURBO colormap.

    Args:
        depth: numpy array of depth values (the caller passes an (h, w) map).

    Returns:
        The uint8 image produced by ``cv2.applyColorMap``.
    """
    d_min, d_max = depth.min(), depth.max()
    d_range = d_max - d_min
    if d_range > 0:
        depth = (depth - d_min) / d_range
    else:
        # A constant depth map would otherwise produce 0/0 = NaN, whose
        # uint8 cast is undefined; render it as uniform zero depth instead.
        depth = np.zeros_like(depth, dtype=np.float64)
    depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
                                  cv2.COLORMAP_TURBO)
    return depth_img
class NeRFSystem(LightningModule):
    """Lightning module that trains and evaluates an NGP-based NeRF.

    Owns the ``NGP`` network and its density-grid buffers, the ``NeRFLoss``,
    and the PSNR / SSIM / (optional) LPIPS metrics. When
    ``hparams.optimize_ext`` is set, per-training-image pose corrections
    (``dR``, ``dT``) are learned alongside the network.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        # Density-grid refresh schedule, in optimizer steps (see training_step).
        self.warmup_steps = 256
        self.update_interval = 16

        self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w)
        self.train_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1)
        if self.hparams.eval_lpips:
            self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg')
            # LPIPS is only used for evaluation; keep its backbone frozen.
            for p in self.val_lpips.net.parameters():
                p.requires_grad = False

        # Per-image metric dicts collected by validation_step and reduced in
        # on_validation_epoch_end.
        self.validation_step_outputs = []

        # With exposure modelling the RGB head gets no activation ('None');
        # otherwise its output is squashed with a sigmoid.
        rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid'
        self.model = NGP(scale=self.hparams.scale, rgb_act=rgb_act)
        G = self.model.grid_size
        # One density value per cell of a G^3 grid for each cascade, plus the
        # integer (x, y, z) coordinates of every cell.
        self.model.register_buffer('density_grid',
            torch.zeros(self.model.cascades, G**3))
        self.model.register_buffer('grid_coords',
            create_meshgrid3d(G, G, G, False, dtype=torch.int32).reshape(-1, 3))

    def forward(self, batch, split):
        """Render the rays described by ``batch``.

        Args:
            batch: dataloader dict. For ``split == 'train'`` it must contain
                ``img_idxs`` and ``pix_idxs`` (indices into the cached
                ``poses`` / ``directions`` buffers); otherwise it must contain
                ``pose``. ``img_idxs`` is also read when extrinsics are
                optimized, and ``exposure`` when ``use_exposure`` is set.
            split: ``'train'`` for training rays; anything else renders the
                full image at test time.

        Returns:
            The result dict returned by ``render``.
        """
        if split == 'train':
            poses = self.poses[batch['img_idxs']]
            directions = self.directions[batch['pix_idxs']]
        else:
            poses = batch['pose']
            directions = self.directions

        if self.hparams.optimize_ext:
            # Apply the learned corrections out of place: writing into `poses`
            # would mutate the caller's batch['pose'] at test time, and
            # overwriting a tensor the matmul may have saved for backward can
            # trip autograd's in-place version check.
            img_idxs = batch['img_idxs']
            dR = axisangle_to_R(self.dR[img_idxs])
            rot = (dR @ poses[..., :3]).to(poses.dtype)
            trans = poses[..., 3] + self.dT[img_idxs]
            poses = torch.cat([rot, trans[..., None]], dim=-1)

        rays_o, rays_d = get_rays(directions, poses)

        kwargs = {'test_time': split != 'train',
                  'random_bg': self.hparams.random_bg}
        if self.hparams.scale > 0.5:
            # NOTE(review): presumably switches render() to exponentially
            # growing step sizes for larger scenes -- confirm in render().
            kwargs['exp_step_factor'] = 1/256
        if self.hparams.use_exposure:
            kwargs['exposure'] = batch['exposure']

        return render(self.model, rays_o, rays_d, **kwargs)

    def setup(self, stage):
        """Instantiate the train and test datasets selected by the hparams."""
        dataset = dataset_dict[self.hparams.dataset_name]
        kwargs = {'root_dir': self.hparams.root_dir,
                  'downsample': self.hparams.downsample}
        self.train_dataset = dataset(split=self.hparams.split, **kwargs)
        self.train_dataset.batch_size = self.hparams.batch_size
        self.train_dataset.ray_sampling_strategy = self.hparams.ray_sampling_strategy

        self.test_dataset = dataset(split='test', **kwargs)

    def configure_optimizers(self):
        """Register ray/pose state, load initial weights and build optimizers.

        Returns:
            ``(optimizers, schedulers)``. The first optimizer is a FusedAdam
            over the network parameters. With ``optimize_ext`` a second one
            over ``dR`` / ``dT`` follows. A cosine-annealing schedule drives
            the first optimizer.
        """
        # Cache training rays/poses on the module so forward() can index them.
        self.register_buffer('directions', self.train_dataset.directions.to(self.device))
        self.register_buffer('poses', self.train_dataset.poses.to(self.device))

        if self.hparams.optimize_ext:
            N = len(self.train_dataset.poses)
            # Per-image axis-angle rotation and translation corrections,
            # zero-initialized.
            self.register_parameter('dR',
                nn.Parameter(torch.zeros(N, 3, device=self.device)))
            self.register_parameter('dT',
                nn.Parameter(torch.zeros(N, 3, device=self.device)))

        load_ckpt(self.model, self.hparams.weight_path)

        # Everything except the pose corrections is trained by net_opt.
        net_params = [p for n, p in self.named_parameters()
                      if n not in ('dR', 'dT')]

        opts = []
        self.net_opt = FusedAdam(net_params, self.hparams.lr, eps=1e-15)
        opts += [self.net_opt]
        if self.hparams.optimize_ext:
            opts += [FusedAdam([self.dR, self.dT], 1e-6)] # learning rate is hard-coded
        net_sch = CosineAnnealingLR(self.net_opt,
                                    self.hparams.num_epochs,
                                    self.hparams.lr/30)

        return opts, [net_sch]

    def train_dataloader(self):
        """Return the training loader; batch_size=None means the dataset yields whole batches."""
        return DataLoader(self.train_dataset,
                          num_workers=16,
                          persistent_workers=True,
                          batch_size=None,
                          pin_memory=True)

    def val_dataloader(self):
        """Return the test-set loader (one pre-built sample per item)."""
        return DataLoader(self.test_dataset,
                          num_workers=8,
                          batch_size=None,
                          pin_memory=True)

    def on_train_start(self):
        """Mark density-grid cells not seen by any training camera."""
        self.model.mark_invisible_cells(self.train_dataset.K.to(self.device),
                                        self.poses,
                                        self.train_dataset.img_wh)

    def training_step(self, batch, batch_nb, *args):
        """Run one optimization step and log training statistics.

        Returns:
            The scalar loss (sum of the means of every loss term).
        """
        if self.global_step%self.update_interval == 0:
            self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5,
                                           warmup=self.global_step<self.warmup_steps,
                                           erode=self.hparams.dataset_name=='colmap')

        results = self(batch, split='train')
        loss_d = self.loss(results, batch)
        if self.hparams.use_exposure:
            # Regularize: zero log-radiance at exposure 1 should map to the
            # dataset's reference colour (train_dataset.unit_exposure_rgb).
            zero_radiance = torch.zeros(1, 3, device=self.device)
            unit_exposure_rgb = self.model.log_radiance_to_rgb(zero_radiance,
                                    **{'exposure': torch.ones(1, 1, device=self.device)})
            loss_d['unit_exposure'] = \
                0.5*(unit_exposure_rgb-self.train_dataset.unit_exposure_rgb)**2
        loss = sum(lo.mean() for lo in loss_d.values())

        with torch.no_grad():
            self.train_psnr(results['rgb'], batch['rgb'])
        self.log('lr', self.net_opt.param_groups[0]['lr'])
        self.log('train/loss', loss)
        # ray marching samples per ray (occupied space on the ray)
        self.log('train/rm_s', results['rm_samples']/len(batch['rgb']), True)
        # volume rendering samples per ray (stops marching when transmittance drops below 1e-4)
        self.log('train/vr_s', results['vr_samples']/len(batch['rgb']), True)
        self.log('train/psnr', self.train_psnr, True)

        return loss

    def on_validation_start(self):
        """Free cached GPU memory, reset metric buffers and prepare the output dir."""
        torch.cuda.empty_cache()
        self.validation_step_outputs.clear()
        if not self.hparams.no_save_test:
            self.val_dir = f'results/{self.hparams.dataset_name}/{self.hparams.exp_name}'
            os.makedirs(self.val_dir, exist_ok=True)

    def validation_step(self, batch, batch_nb):
        """Render one test image, compute its metrics and optionally save it.

        Returns:
            Dict with per-image ``psnr``, ``ssim`` and (if enabled) ``lpips``.
            The dict is also stored for ``on_validation_epoch_end``.
        """
        rgb_gt = batch['rgb']
        results = self(batch, split='test')

        logs = {}
        # compute each metric per image
        self.val_psnr(results['rgb'], rgb_gt)
        logs['psnr'] = self.val_psnr.compute()
        self.val_psnr.reset()

        # NOTE(review): assumes test images share the training resolution.
        w, h = self.train_dataset.img_wh
        rgb_pred = rearrange(results['rgb'], '(h w) c -> 1 c h w', h=h)
        rgb_gt = rearrange(rgb_gt, '(h w) c -> 1 c h w', h=h)
        self.val_ssim(rgb_pred, rgb_gt)
        logs['ssim'] = self.val_ssim.compute()
        self.val_ssim.reset()
        if self.hparams.eval_lpips:
            # LPIPS takes inputs in [-1, 1].
            self.val_lpips(torch.clip(rgb_pred*2-1, -1, 1),
                           torch.clip(rgb_gt*2-1, -1, 1))
            logs['lpips'] = self.val_lpips.compute()
            self.val_lpips.reset()

        if not self.hparams.no_save_test: # save test image to disk
            idx = batch['img_idxs']
            rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h)
            rgb_pred = (rgb_pred*255).astype(np.uint8)
            depth = depth2img(rearrange(results['depth'].cpu().numpy(), '(h w) -> h w', h=h))
            imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}.png'), rgb_pred)
            imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}_d.png'), depth)

        self.validation_step_outputs.append(logs)
        return logs

    def on_validation_epoch_end(self):
        """Average the per-image metrics over all devices and log them."""
        outputs = self.validation_step_outputs
        if not outputs:
            return

        psnrs = torch.stack([x['psnr'] for x in outputs])
        mean_psnr = self.all_gather(psnrs).mean()
        self.log('test/psnr', mean_psnr, True)

        ssims = torch.stack([x['ssim'] for x in outputs])
        mean_ssim = self.all_gather(ssims).mean()
        self.log('test/ssim', mean_ssim)

        if self.hparams.eval_lpips:
            lpipss = torch.stack([x['lpips'] for x in outputs])
            mean_lpips = self.all_gather(lpipss).mean()
            self.log('test/lpips_vgg', mean_lpips)

        outputs.clear()

    def get_progress_bar_dict(self):
        """Hide the ``v_num`` entry from the progress bar.

        NOTE(review): newer Lightning versions no longer call this hook.
        """
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items
if __name__ == '__main__':
    # Entry point: train (or, with --val_only, only evaluate) a NeRFSystem,
    # then write a slimmed checkpoint and, for NSVF Synthetic scenes, videos.
    hparams = get_opts()
    if hparams.val_only and (not hparams.ckpt_path):
        raise ValueError('You need to provide a @ckpt_path for validation!')
    system = NeRFSystem(hparams)

    ckpt_dir = f'ckpts/{hparams.dataset_name}/{hparams.exp_name}'
    ckpt_cb = ModelCheckpoint(dirpath=ckpt_dir,
                              filename='{epoch:d}',
                              save_weights_only=True,
                              every_n_epochs=hparams.num_epochs,
                              save_on_train_epoch_end=True,
                              save_top_k=-1)
    callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)]

    logger = TensorBoardLogger(save_dir=f"logs/{hparams.dataset_name}",
                               name=hparams.exp_name,
                               default_hp_metric=False)

    trainer = Trainer(max_epochs=hparams.num_epochs,
                      check_val_every_n_epoch=hparams.num_epochs,
                      callbacks=callbacks,
                      logger=logger,
                      enable_model_summary=False,
                      accelerator='gpu',
                      devices=hparams.num_gpus,
                      # In val_only mode the full sanity-check pass is the evaluation.
                      num_sanity_val_steps=-1 if hparams.val_only else 0,
                      precision=16)

    trainer.fit(system, ckpt_path=hparams.ckpt_path)

    # Post-processing writes shared files; under multi-GPU DDP every rank
    # reaches this point, so only the global-zero process may do it.
    if trainer.is_global_zero:
        if not hparams.val_only: # save slimmed ckpt for the last epoch
            last_epoch = hparams.num_epochs - 1
            ckpt_ = \
                slim_ckpt(f'{ckpt_dir}/epoch={last_epoch}.ckpt',
                          save_poses=hparams.optimize_ext)
            torch.save(ckpt_, f'{ckpt_dir}/epoch={last_epoch}_slim.ckpt')

        if (not hparams.no_save_test) and \
           hparams.dataset_name=='nsvf' and \
           'Synthetic' in hparams.root_dir: # save video
            def _frame_idx(path):
                # '012.png' and '012_d.png' both map to frame 12.
                return int(os.path.basename(path).split('_')[0].split('.')[0])

            imgs = glob.glob(os.path.join(system.val_dir, '*.png'))
            # Split by suffix and sort numerically, so frames stay in order and
            # rgb/depth stay separated even past index 999.
            rgb_imgs = sorted((p for p in imgs if not p.endswith('_d.png')),
                              key=_frame_idx)
            depth_imgs = sorted((p for p in imgs if p.endswith('_d.png')),
                                key=_frame_idx)
            imageio.mimsave(os.path.join(system.val_dir, 'rgb.mp4'),
                            [imageio.imread(img) for img in rgb_imgs],
                            fps=30, macro_block_size=1)
            imageio.mimsave(os.path.join(system.val_dir, 'depth.mp4'),
                            [imageio.imread(img) for img in depth_imgs],
                            fps=30, macro_block_size=1)