diff --git a/models/UMNN/UMNNMAF.py b/models/UMNN/UMNNMAF.py index a3bff49..bef11ce 100644 --- a/models/UMNN/UMNNMAF.py +++ b/models/UMNN/UMNNMAF.py @@ -41,19 +41,16 @@ def __init__(self, net, input_size, nb_steps=100, device="cpu", solver="CC"): self.cc_weights = None self.steps = None self.solver = solver - self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False) + self.register_buffer("pi", torch.tensor(math.pi)) # Scaling could be changed to be an autoregressive network output self.scaling = nn.Parameter(torch.zeros(input_size, device=self.device), requires_grad=False) def to(self, device): self.device = device - self.net.to(device) - self.pi = nn.Parameter(torch.tensor(math.pi).to(self.device), requires_grad=False) - self.scaling = self.scaling.to(self.device) + super().to(device) return self - def forward(self, x, method=None, x0=None, context=None): x0 = x0.to(x.device) if x0 is not None else torch.zeros(x.shape).to(x.device) xT = x diff --git a/models/UMNN/UMNNMAFFlow.py b/models/UMNN/UMNNMAFFlow.py index e78b579..4ff1952 100644 --- a/models/UMNN/UMNNMAFFlow.py +++ b/models/UMNN/UMNNMAFFlow.py @@ -53,7 +53,7 @@ def __init__(self, nb_flow=1, nb_in=1, hidden_derivative=[50, 50, 50, 50], hidde """ super().__init__() self.device = device - self.pi = nn.Parameter(torch.tensor(math.pi, device=self.device), requires_grad=False) + self.register_buffer("pi", torch.tensor(math.pi)) self.nets = ListModule(self, "Flow") for i in range(nb_flow): auto_net = EmbeddingNetwork(nb_in, hidden_embedding, hidden_derivative, embedding_s, act_func=act_func, @@ -66,7 +66,7 @@ def to(self, device): for net in self.nets: net.to(device) self.device = device - self.pi = self.pi.to(device) + super().to(device) return self def forward(self, x, context=None): diff --git a/models/UMNN/__init__.py b/models/UMNN/__init__.py index b73bcfc..38d3c1a 100644 --- a/models/UMNN/__init__.py +++ b/models/UMNN/__init__.py @@ -1,4 +1,6 @@ from .UMNNMAFFlow import UMNNMAFFlow from .MonotonicNN import MonotonicNN, IntegrandNN from .UMNNMAF import IntegrandNetwork, UMNNMAF -from .made import MADE \ No newline at end of file +from .made import MADE +from .NeuralIntegral import NeuralIntegral +from .ParallelNeuralIntegral import ParallelNeuralIntegral \ No newline at end of file