-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
51 lines (40 loc) · 1.65 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torch import nn
from torch.nn import utils
# DeepSDF network
class DeepSDF(nn.Module):
    """Fully-connected auto-decoder mapping 3D coordinates to a signed distance.

    Architecture: one input layer (3 -> H), ``num_layers - 2`` hidden layers
    (H -> H), and one output layer (H -> 1) squashed by tanh. The middle
    hidden layer narrows its output to H - 3 so the raw input coordinates
    can be concatenated back in (a skip connection), restoring width H.

    Args:
        num_layers: total layer count, including the first and last layers.
        hidden_dim: width H of every hidden representation.
        use_weight_norm: wrap each linear layer in weight normalization.
        use_dropout: follow each ReLU with dropout.
        dropout_prob: drop probability when ``use_dropout`` is True.
    """

    def __init__(self, num_layers=8, hidden_dim=512, use_weight_norm=True,
                 use_dropout=True, dropout_prob=0.2):
        super().__init__()
        self.first_fc = nn.Linear(3, hidden_dim)
        self.inter_fc = nn.ModuleList()
        n_hidden = num_layers - 2
        self.skip_layer = n_hidden // 2
        for idx in range(n_hidden):
            # The skip layer emits H - 3 features so that concatenating the
            # 3D coordinates afterwards keeps the hidden width at H.
            width = hidden_dim - 3 if idx == self.skip_layer else hidden_dim
            self.inter_fc.append(nn.Linear(hidden_dim, width))
        self.last_fc = nn.Linear(hidden_dim, 1)
        if use_weight_norm:
            self.apply_weight_norm()
        self.activation = (
            nn.Sequential(nn.ReLU(), nn.Dropout(dropout_prob))
            if use_dropout
            else nn.ReLU()
        )

    def apply_weight_norm(self):
        """Wrap every linear layer of the network in weight normalization."""
        self.first_fc = utils.weight_norm(self.first_fc)
        self.last_fc = utils.weight_norm(self.last_fc)
        for idx in range(len(self.inter_fc)):
            self.inter_fc[idx] = utils.weight_norm(self.inter_fc[idx])

    def forward(self, coords):
        """Predict a signed distance for each point.

        Args:
            coords: 3D point coordinates, shape [B, 3].

        Returns:
            Tensor of shape [B, 1] with values in (-1, 1) (tanh output).
        """
        hidden = self.activation(self.first_fc(coords))  # [B, H]
        for idx, layer in enumerate(self.inter_fc):
            hidden = self.activation(layer(hidden))
            if idx == self.skip_layer:
                # Re-inject the raw coordinates: [B, H - 3] -> [B, H]
                hidden = torch.cat([hidden, coords], dim=-1)
        return torch.tanh(self.last_fc(hidden))  # [B, 1]