From 4d986efda8b4a3a8feef8988abc594b4b4ec3273 Mon Sep 17 00:00:00 2001 From: awehenkel Date: Thu, 16 Apr 2020 10:57:33 +0200 Subject: [PATCH] Clean code with gpu transfer and parameters definition --- models/UMNN/UMNNMAF.py | 8 ++++---- models/UMNN/UMNNMAFFlow.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/models/UMNN/UMNNMAF.py b/models/UMNN/UMNNMAF.py index 799c66f..243cb2e 100644 --- a/models/UMNN/UMNNMAF.py +++ b/models/UMNN/UMNNMAF.py @@ -41,16 +41,16 @@ def __init__(self, net, input_size, nb_steps=100, device="cpu", solver="CC"): self.cc_weights = None self.steps = None self.solver = solver - self.pi = torch.tensor(math.pi).to(self.device) + self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False) # Scaling could be changed to be an autoregressive network output - self.scaling = torch.zeros(input_size).to(self.device) + self.scaling = nn.Parameter(torch.zeros(input_size, device=self.device), requires_grad=False) def to(self, device): self.device = device self.net.to(device) - self.pi = torch.tensor(math.pi).to(self.device) - self.scaling.to(self.device) + self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False) + self.scaling = self.scaling.to(self.device) return self diff --git a/models/UMNN/UMNNMAFFlow.py b/models/UMNN/UMNNMAFFlow.py index a98369f..e78b579 100644 --- a/models/UMNN/UMNNMAFFlow.py +++ b/models/UMNN/UMNNMAFFlow.py @@ -53,7 +53,7 @@ def __init__(self, nb_flow=1, nb_in=1, hidden_derivative=[50, 50, 50, 50], hidde """ super().__init__() self.device = device - self.pi = torch.tensor(math.pi).to(self.device) + self.pi = nn.Parameter(torch.tensor(math.pi, device=self.device), requires_grad=False) self.nets = ListModule(self, "Flow") for i in range(nb_flow): auto_net = EmbeddingNetwork(nb_in, hidden_embedding, hidden_derivative, embedding_s, act_func=act_func, @@ -66,7 +66,7 @@ def to(self, device): for net in self.nets: net.to(device) self.device = device - self.pi = torch.tensor(math.pi).to(device) + self.pi = self.pi.to(device) return self def forward(self, x, context=None):