-
Notifications
You must be signed in to change notification settings - Fork 0
/
INR.py
54 lines (43 loc) · 1.89 KB
/
INR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
from gabor_layer import GaborLayer2D
from symbolic_layer import SymbolicLayer
class INR(nn.Module):
    """Stack of ``GaborLayer2D`` layers followed by a ``SymbolicLayer`` head.

    The layers are chained in an ``nn.Sequential`` stored as ``self.net``:
    one first ``GaborLayer2D`` (built with ``is_first=True`` and
    ``trainable=True``), then ``hidden_layers`` further ``GaborLayer2D``
    layers, then a ``SymbolicLayer`` mapping to ``out_features``.

    Args:
        in_features: Input size handed to the first ``GaborLayer2D``.
        hidden_features: Requested hidden width. The width actually used
            is ``int(hidden_features / 2)``.
        hidden_layers: Number of ``GaborLayer2D`` layers added after the
            first one.
        out_features: Output size handed to the final ``SymbolicLayer``.
        outermost_linear: Accepted for call compatibility; not used.
        first_omega_0: ``omega0`` passed to the first layer.
        hidden_omega_0: ``omega0`` passed to every hidden layer.
        scale: ``sigma0`` passed to every ``GaborLayer2D``.
        pos_encode: Accepted for call compatibility; not used
            (``self.pos_encode`` is always ``False``).
        sidelength: Accepted for call compatibility; not used.
        fn_samples: Accepted for call compatibility; not used.
        use_nyquist: Accepted for call compatibility; not used.
    """

    def __init__(self, in_features, hidden_features,
                 hidden_layers,
                 out_features, outermost_linear=True,
                 first_omega_0=10, hidden_omega_0=10., scale=10.0,
                 pos_encode=False, sidelength=512, fn_samples=None,
                 use_nyquist=True):
        super().__init__()

        # All results in the paper were with the default complex 'gabor'
        # nonlinearity
        self.nonlin = GaborLayer2D

        # The width actually used is half of the requested one. The original
        # note motivates this with complex numbers being two real numbers.
        # NOTE(review): halving the width presumably cuts the hidden weight
        # count by ~4x (the old comment said "by 4") -- confirm against
        # GaborLayer2D.
        hidden_features = int(hidden_features / 2)

        self.complex = True
        self.wavelet = 'gabor'

        # Legacy parameter: fixed to False whatever `pos_encode` was passed.
        self.pos_encode = False

        # Layers are created in network order (first, hidden, head) so that
        # any random initialisation inside them happens in the same sequence
        # as before.
        layers = [self.nonlin(in_features,
                              hidden_features,
                              omega0=first_omega_0,
                              sigma0=scale,
                              is_first=True,
                              trainable=True)]

        for _ in range(hidden_layers):
            layers.append(self.nonlin(hidden_features,
                                      hidden_features,
                                      omega0=hidden_omega_0,
                                      sigma0=scale))

        layers.append(SymbolicLayer(hidden_features, out_features))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Run ``coords`` through ``self.net`` and return the result as is.

        No post-processing (such as taking a real part) is applied to the
        output of the final layer.
        """
        return self.net(coords)