-
Notifications
You must be signed in to change notification settings - Fork 29
/
run_nerf_helpers.py
99 lines (75 loc) · 3.27 KB
/
run_nerf_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from torchsearchsorted import searchsorted
import numpy as np
import torch
# Global debug switches for the whole module.
torch.autograd.set_detect_anomaly(True)  # NOTE(review): anomaly detection slows autograd considerably — confirm it is intended outside of debugging runs
TEST = False  # presumably a test-mode toggle read elsewhere in the file — verify usage
# Misc
def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    diff = x - y
    return torch.mean(diff * diff)
def img2l1(x, y):
    """Return the mean absolute (L1) error between tensors ``x`` and ``y``."""
    return torch.mean(torch.abs(x - y))
def mse2psnr(x):
    """Convert an MSE tensor to PSNR in dB (assumes pixel values in [0, 1])."""
    # Keep the tensor-valued log base so the output shape matches the original.
    log10 = torch.log(torch.Tensor([10.]))
    return -10. * torch.log(x) / log10
def to8b(x):
    """Map a float image with values in [0, 1] to a uint8 image in [0, 255]."""
    clipped = np.clip(x, 0, 1)
    return (clipped * 255).astype(np.uint8)
def to_disp_img(disp):
    """Normalize a disparity map to [0, 1] for visualization.

    Values below the 5th percentile / above the 95th percentile are
    clipped before normalization so outliers do not dominate the scale.

    Args:
        disp: numpy array of disparities (any shape, assumed non-empty).

    Returns:
        A new numpy array of the same shape, scaled so the max value is 1.
        The input array is NOT modified (the original implementation
        mutated the caller's array in place via boolean-mask assignment).
    """
    # clip outliers
    min_disp, max_disp = np.percentile(disp, [5, 95])
    # np.clip returns a new array, so the caller's data stays untouched
    disp = np.clip(disp, min_disp, max_disp)
    return disp / disp.max()  # normalize into [0, 1]
# Ray helpers
def get_rays(H, W, focal, c2w):
    """Generate per-pixel ray origins and directions for an H x W camera.

    Args:
        H, W: image height and width in pixels.
        focal: focal length as a torch scalar tensor (``.item()`` is called).
            Values < 10 are interpreted as normalized focals and rescaled
            (hacky convention kept from the original implementation).
        c2w: camera-to-world matrix; only rows/cols [:3, :4] are used.

    Returns:
        Tuple ``(rays_o, rays_d)``, each of shape (H, W, 3), in world space.
    """
    xs, ys = torch.meshgrid(torch.linspace(0, W - 1, W),
                            torch.linspace(0, H - 1, H))
    # pytorch's meshgrid uses indexing='ij'; transpose into image layout
    xs, ys = xs.t(), ys.t()
    fx = fy = focal.item()
    if focal < 10:  # super hacky: focal is in normalized units
        # normalize to [-1, 1]
        fx *= W * .5
        fy *= H * .5
        # inside [-200, 200] (400/2), we only want to render from [-128/200, 128/200]
        fx *= 200. / 128.
        fy *= 200. / 128.
    dirs = torch.stack(
        [(xs - W * .5) / fx, -(ys - H * .5) / fy, -torch.ones_like(xs)], -1)
    # Rotate ray directions from camera to world frame:
    # per-pixel dot product, equivalent to c2w[:3, :3] @ dir
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    # Every ray starts at the camera origin expressed in world coordinates
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    """Hierarchical sampling via inverse-transform sampling (NeRF sec. 5.2).

    Draws ``N_samples`` positions per batch row from the piecewise-constant
    PDF defined by ``weights`` over the intervals in ``bins``.

    Args:
        bins: (batch, M) bin edge positions.
        weights: (batch, M-1) unnormalized per-bin weights.
        N_samples: number of samples to draw per batch row.
        det: if True, sample deterministically at evenly spaced CDF values.
        pytest: if True, overwrite the samples with numpy's seeded RNG
            for reproducible tests.

    Returns:
        (batch, N_samples) tensor of sampled positions.
    """
    # Get pdf; the epsilon prevents NaNs when a row of weights is all zero
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))
    # Take uniform samples in CDF space
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
    # Pytest: overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)
    # Invert CDF. Native torch.searchsorted (torch >= 1.6) replaces the
    # obsolete third-party torchsearchsorted package; right=True matches
    # the old side='right' semantics.
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # Guard against zero-width CDF intervals (would divide by zero)
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    # Linear interpolation within the selected bin
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples