-
Notifications
You must be signed in to change notification settings - Fork 4
/
export_colorray.py
107 lines (89 loc) · 3.48 KB
/
export_colorray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import cv2
import glob2
import tqdm
import torch
# torch.set_default_dtype(torch.float16)
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
# import skimage
import matplotlib.pyplot as plt
import time
from functools import partial
import dataio, modules
import configargparse
# ---- Command-line interface ------------------------------------------------
p = configargparse.ArgumentParser()
# p.add_argument('--c', required=True, type=int)
# Output directory for the exported .npy arrays (created just below).
p.add_argument('--save_dir', required=True, type=str)
# /home/dejia/repo/siren/data/div2k_color_hole_grad
# Run name (may contain glob wildcards); checkpoints are searched under
# ./logs/<load>/checkpoints/ further down.
p.add_argument('--load', required=True, type=str)
# Index into the sorted checkpoint list at which processing starts.
p.add_argument('--offset', default=0, type=int)
# When set, process exactly one checkpoint (the one at --offset).
p.add_argument('--single', action='store_true')
# div2k_*.png_color_hole
args = p.parse_args()
os.makedirs(args.save_dir, exist_ok=True)
def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.

    sidelen: int, number of samples per spatial dimension.
    dim: int, number of spatial dimensions.
    Returns a (sidelen**dim, dim) float tensor of coordinates, with the first
    coordinate varying slowest ('ij' ordering).'''
    tensors = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])
    # indexing='ij' reproduces the pre-1.10 default of torch.meshgrid and
    # silences the deprecation warning newer torch versions emit when the
    # argument is omitted; the produced grid is identical.
    mgrid = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1)
    mgrid = mgrid.reshape(-1, dim)
    return mgrid
# Collect the final checkpoint of every matching run under ./logs/.
# (args.load may contain glob wildcards, e.g. 'div2k_*.png_color_hole'.)
li = glob2.glob(f'./logs/{args.load}/checkpoints/model_final.pth')
print(len(li))
# data
# Side length of the square coordinate grid the images are sampled on.
sz = 256
# NOTE(review): dataset classes come from the project-local `dataio` module;
# the exact contents of model_input/gt batches are not visible in this file.
img_dataset = dataio.NoisyCamera_multimlp_rays(img_path='div2k', img_num=1)
coord_dataset = dataio.Implicit2DWrapper_multimlp_ray(img_dataset, sidelength=sz, compute_diff='blur_x', sigma=5)
image_resolution = (sz, sz)
dataloader = DataLoader(coord_dataset, shuffle=False, batch_size=4096, pin_memory=False, num_workers=4)
# len is num of image
device = torch.device('cuda:0')
# One SIREN-style coordinate MLP (project-local `modules`); its weights are
# replaced per checkpoint before each export.
model = modules.SingleBVPNet(type='sine', mode='mlp', sidelength=image_resolution, hidden_features=256, num_hidden_layers=3, out_features=3).to(device)
def extract_image(model, dataloader, image_resolution, device):
    """Run `model` over every batch of `dataloader` and return the stacked
    predictions as one numpy array of shape (total_samples, out_features).

    `image_resolution` is accepted for interface compatibility but unused here.
    The 'ckpt' entry of each batch dict, if present, is not moved to `device`
    or forwarded to the model.
    """
    chunks = []
    for model_input, _gt in dataloader:
        batch = {k: v.to(device) for k, v in model_input.items() if k != 'ckpt'}
        pred = model(batch)
        chunks.append(pred['model_out'].detach().cpu().numpy())
    return np.concatenate(chunks, 0)
import diff_operators
def gradient(y, x, grad_outputs=None):
    """Vector-Jacobian product of `y` w.r.t. `x` (defaults to an all-ones
    cotangent, i.e. the gradient of y.sum()). The graph is retained
    (create_graph=True) so higher-order derivatives remain possible."""
    grad_outputs = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)
    return grad
def new_grad_lastdim(y, x, sz=256, num=31):
    """Stack the first three channels of `y` with derivatives along the last dim.

    Starts from [y[...,0], y[...,1], y[...,2]] and, for each of the first
    `num` entries of the growing list, appends the first two components of
    d(entry / sz)/dx. Because every iteration appends two entries, later
    iterations differentiate previously computed derivatives, producing
    higher-order terms. Returns a tensor with 3 + 2*num channels stacked
    along the last dimension.
    """
    channels = [y[..., c] for c in range(3)]
    for idx in range(num):
        src = channels[idx]
        d = torch.autograd.grad(src / sz, x, torch.ones_like(src), create_graph=True)[0]
        channels.extend((d[..., 0], d[..., 1]))
    return torch.stack(channels, dim=-1)
def grad_model(model, model_input):
    """Forward `model_input` through `model`, then replace the output with the
    channel values plus their derivatives stacked along the last dimension
    (via new_grad_lastdim on model_out w.r.t. model_in).

    NOTE(review): 63 * 3 is the third positional argument, so it binds to
    new_grad_lastdim's `sz` parameter, not `num` (which keeps its default
    of 31) — confirm this is the intended call.
    """
    result = model(model_input)
    stacked = new_grad_lastdim(result['model_out'], result['model_in'], 63 * 3)
    return {'model_out': stacked}
def process(fname):
    """Restore the checkpoint at `fname` into the module-level `model`, run the
    gradient-stacking forward pass over the whole dataloader, and save the
    result as a single .npy file in args.save_dir, named after the run
    directory (with any '.png' in the name rewritten to '.npy')."""
    model.load_state_dict(torch.load(fname))
    # Bind the freshly loaded model into the one-argument callable that
    # extract_image expects.
    forward = lambda batch: grad_model(model, batch)
    arr = extract_image(forward, dataloader, image_resolution, device)
    print(arr.shape)
    # fname looks like ./logs/<run>/checkpoints/model_final.pth, so the third
    # path component from the end is the run directory name.
    run_dir = fname.split('/')[-3]
    out_path = os.path.join(args.save_dir, os.path.basename(run_dir).replace('.png', f'.npy'))
    np.save(out_path, arr)
# Total number of discovered checkpoints. NOTE(review): `num` is never used
# after this line.
num = len(li)
if args.single:
    # Export only the checkpoint at position --offset in sorted order.
    li = sorted(li)[args.offset:args.offset + 1]
else:
    # Export every checkpoint from --offset onwards, in sorted order.
    li = sorted(li)[args.offset:]
print(li)
# Export each selected checkpoint, with a progress bar.
for cur in tqdm.tqdm(li):
    process(cur)