Skip to content

Commit

Permalink
Clean code with gpu transfer and parameters definition
Browse files Browse the repository at this point in the history
  • Loading branch information
AWehenkel committed Apr 16, 2020
1 parent fcd7040 commit 4d986ef
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions models/UMNN/UMNNMAF.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ def __init__(self, net, input_size, nb_steps=100, device="cpu", solver="CC"):
self.cc_weights = None
self.steps = None
self.solver = solver
self.pi = torch.tensor(math.pi).to(self.device)
self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False)

# Scaling could be changed to be an autoregressive network output
self.scaling = torch.zeros(input_size).to(self.device)
self.scaling = nn.Parameter(torch.zeros(input_size, device=self.device), requires_grad=False)

def to(self, device):
self.device = device
self.net.to(device)
self.pi = torch.tensor(math.pi).to(self.device)
self.scaling.to(self.device)
self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False)
self.scaling = self.scaling.to(self.device)
return self


Expand Down
4 changes: 2 additions & 2 deletions models/UMNN/UMNNMAFFlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, nb_flow=1, nb_in=1, hidden_derivative=[50, 50, 50, 50], hidde
"""
super().__init__()
self.device = device
self.pi = torch.tensor(math.pi).to(self.device)
self.pi = nn.Parameter(torch.tensor(math.pi, device=self.device), requires_grad=False)
self.nets = ListModule(self, "Flow")
for i in range(nb_flow):
auto_net = EmbeddingNetwork(nb_in, hidden_embedding, hidden_derivative, embedding_s, act_func=act_func,
Expand All @@ -66,7 +66,7 @@ def to(self, device):
for net in self.nets:
net.to(device)
self.device = device
self.pi = torch.tensor(math.pi).to(device)
self.pi = self.pi.to(device)
return self

def forward(self, x, context=None):
Expand Down

0 comments on commit 4d986ef

Please sign in to comment.