Skip to content

Commit

Permalink
Correct mistake with parameters instead of buffers.
Browse files Browse the repository at this point in the history
  • Loading branch information
AWehenkel committed Apr 18, 2020
1 parent 806d1b0 commit 8e0e63c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
7 changes: 2 additions & 5 deletions models/UMNN/UMNNMAF.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,16 @@ def __init__(self, net, input_size, nb_steps=100, device="cpu", solver="CC"):
self.cc_weights = None
self.steps = None
self.solver = solver
self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False)
self.register_buffer("pi", torch.tensor(math.pi))

# Scaling could be changed to be an autoregressive network output
self.scaling = nn.Parameter(torch.zeros(input_size, device=self.device), requires_grad=False)

def to(self, device):
self.device = device
self.net.to(device)
self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False)
self.scaling = self.scaling.to(self.device)
super().to(device)
return self


def forward(self, x, method=None, x0=None, context=None):
x0 = x0.to(x.device) if x0 is not None else torch.zeros(x.shape).to(x.device)
xT = x
Expand Down
4 changes: 2 additions & 2 deletions models/UMNN/UMNNMAFFlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, nb_flow=1, nb_in=1, hidden_derivative=[50, 50, 50, 50], hidde
"""
super().__init__()
self.device = device
self.pi = nn.Parameter(torch.tensor(math.pi, device=self.device), requires_grad=False)
self.register_buffer("pi", torch.tensor(math.pi))
self.nets = ListModule(self, "Flow")
for i in range(nb_flow):
auto_net = EmbeddingNetwork(nb_in, hidden_embedding, hidden_derivative, embedding_s, act_func=act_func,
Expand All @@ -66,7 +66,7 @@ def to(self, device):
for net in self.nets:
net.to(device)
self.device = device
self.pi = self.pi.to(device)
super().to(device)
return self

def forward(self, x, context=None):
Expand Down
4 changes: 3 additions & 1 deletion models/UMNN/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .UMNNMAFFlow import UMNNMAFFlow
from .MonotonicNN import MonotonicNN, IntegrandNN
from .UMNNMAF import IntegrandNetwork, UMNNMAF
from .made import MADE
from .made import MADE
from .NeuralIntegral import NeuralIntegral
from .ParallelNeuralIntegral import ParallelNeuralIntegral

0 comments on commit 8e0e63c

Please sign in to comment.