diff --git a/models/UMNN/NeuralIntegral.py b/models/UMNN/NeuralIntegral.py
index c30188c..99f80b5 100644
--- a/models/UMNN/NeuralIntegral.py
+++ b/models/UMNN/NeuralIntegral.py
@@ -29,7 +29,10 @@ def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=
     #Clenshaw-Curtis Quadrature Method
     cc_weights, steps = compute_cc_weights(nb_steps)
 
-    device = x0.get_device() if x0.is_cuda else "cpu"
+    device = x0.get_device() if x0.is_cuda or x0.is_mps else "cpu"
+    if x0.is_mps:
+        device = 'mps'
+
     cc_weights, steps = cc_weights.to(device), steps.to(device)
 
     if compute_grad:
diff --git a/models/UMNN/ParallelNeuralIntegral.py b/models/UMNN/ParallelNeuralIntegral.py
index 4fa2038..de26f4e 100644
--- a/models/UMNN/ParallelNeuralIntegral.py
+++ b/models/UMNN/ParallelNeuralIntegral.py
@@ -25,11 +25,14 @@ def compute_cc_weights(nb_steps):
     return cc_weights, steps
 
 
-def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=None):
+def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=None, inv_f=False):
     #Clenshaw-Curtis Quadrature Method
     cc_weights, steps = compute_cc_weights(nb_steps)
 
-    device = x0.get_device() if x0.is_cuda else "cpu"
+    device = x0.get_device() if x0.is_cuda or x0.is_mps else "cpu"
+    if x0.is_mps:
+        device = 'mps'
+
     cc_weights, steps = cc_weights.to(device), steps.to(device)
 
     xT = x0 + nb_steps*step_sizes
@@ -41,7 +44,10 @@ def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=
         X_steps = x0_t + (xT_t-x0_t)*(steps_t + 1)/2
         X_steps = X_steps.contiguous().view(-1, x0_t.shape[2])
         h_steps = h_steps.contiguous().view(-1, h.shape[1])
-        dzs = integrand(X_steps, h_steps)
+        if inv_f:
+            dzs = 1/integrand(X_steps, h_steps)
+        else:
+            dzs = integrand(X_steps, h_steps)
         dzs = dzs.view(xT_t.shape[0], nb_steps+1, -1)
         dzs = dzs*cc_weights.unsqueeze(0).expand(dzs.shape)
         z_est = dzs.sum(1)
@@ -59,14 +65,18 @@ def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=
         h_steps = h_steps.contiguous().view(-1, h.shape[1])
         x_tot_steps = x_tot_steps.contiguous().view(-1, x_tot.shape[1])
-        g_param, g_h = computeIntegrand(X_steps, h_steps, integrand, x_tot_steps, nb_steps+1)
+        g_param, g_h = computeIntegrand(X_steps, h_steps, integrand, x_tot_steps, nb_steps+1, inv_f=inv_f)
         return g_param, g_h
 
 
-def computeIntegrand(x, h, integrand, x_tot, nb_steps):
+def computeIntegrand(x, h, integrand, x_tot, nb_steps, inv_f=False):
     h.requires_grad_(True)
     with torch.enable_grad():
-        f = integrand.forward(x, h)
+        if inv_f:
+            f = 1/integrand.forward(x, h)
+        else:
+            f = integrand.forward(x, h)
+
         g_param = _flatten(torch.autograd.grad(f, integrand.parameters(), x_tot, create_graph=True, retain_graph=True))
         g_h = _flatten(torch.autograd.grad(f, h, x_tot))
 
 
@@ -76,12 +86,13 @@ class ParallelNeuralIntegral(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x0, x, integrand, flat_params, h, nb_steps=20):
+    def forward(ctx, x0, x, integrand, flat_params, h, nb_steps=20, inv_f=False):
         with torch.no_grad():
-            x_tot = integrate(x0, nb_steps, (x - x0)/nb_steps, integrand, h, False)
+            x_tot = integrate(x0, nb_steps, (x - x0)/nb_steps, integrand, h, False, inv_f=inv_f)
             # Save for backward
             ctx.integrand = integrand
             ctx.nb_steps = nb_steps
+            ctx.inv_f = inv_f
             ctx.save_for_backward(x0.clone(), x.clone(), h)
         return x_tot
 
@@ -90,7 +101,8 @@ def backward(ctx, grad_output):
         x0, x, h = ctx.saved_tensors
         integrand = ctx.integrand
         nb_steps = ctx.nb_steps
-        integrand_grad, h_grad = integrate(x0, nb_steps, x/nb_steps, integrand, h, True, grad_output)
+        inv_f = ctx.inv_f
+        integrand_grad, h_grad = integrate(x0, nb_steps, x/nb_steps, integrand, h, True, grad_output, inv_f)
         x_grad = integrand(x, h)
         x0_grad = integrand(x0, h)
         # Leibniz formula