Skip to content

Commit

Permalink
add compatibility with mps
Browse files Browse the repository at this point in the history
  • Loading branch information
AWehenkel committed Dec 18, 2023
1 parent e893457 commit 9e489b5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
5 changes: 4 additions & 1 deletion models/UMNN/NeuralIntegral.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=
#Clenshaw-Curtis Quadrature Method
cc_weights, steps = compute_cc_weights(nb_steps)

device = x0.get_device() if x0.is_cuda else "cpu"
device = x0.get_device() if x0.is_cuda or x0.is_mps else "cpu"
if x0.is_mps:
device = 'mps'

cc_weights, steps = cc_weights.to(device), steps.to(device)

if compute_grad:
Expand Down
30 changes: 21 additions & 9 deletions models/UMNN/ParallelNeuralIntegral.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@ def compute_cc_weights(nb_steps):
return cc_weights, steps


def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=None):
def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=None, inv_f=False):
#Clenshaw-Curtis Quadrature Method
cc_weights, steps = compute_cc_weights(nb_steps)

device = x0.get_device() if x0.is_cuda else "cpu"
device = x0.get_device() if x0.is_cuda or x0.is_mps else "cpu"
if x0.is_mps:
device = 'mps'

cc_weights, steps = cc_weights.to(device), steps.to(device)

xT = x0 + nb_steps*step_sizes
Expand All @@ -41,7 +44,10 @@ def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=
X_steps = x0_t + (xT_t-x0_t)*(steps_t + 1)/2
X_steps = X_steps.contiguous().view(-1, x0_t.shape[2])
h_steps = h_steps.contiguous().view(-1, h.shape[1])
dzs = integrand(X_steps, h_steps)
if inv_f:
dzs = 1/integrand(X_steps, h_steps)
else:
dzs = integrand(X_steps, h_steps)
dzs = dzs.view(xT_t.shape[0], nb_steps+1, -1)
dzs = dzs*cc_weights.unsqueeze(0).expand(dzs.shape)
z_est = dzs.sum(1)
Expand All @@ -59,14 +65,18 @@ def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=
h_steps = h_steps.contiguous().view(-1, h.shape[1])
x_tot_steps = x_tot_steps.contiguous().view(-1, x_tot.shape[1])

g_param, g_h = computeIntegrand(X_steps, h_steps, integrand, x_tot_steps, nb_steps+1)
g_param, g_h = computeIntegrand(X_steps, h_steps, integrand, x_tot_steps, nb_steps+1, inv_f=inv_f)
return g_param, g_h


def computeIntegrand(x, h, integrand, x_tot, nb_steps):
def computeIntegrand(x, h, integrand, x_tot, nb_steps, inv_f=False):
h.requires_grad_(True)
with torch.enable_grad():
f = integrand.forward(x, h)
if inv_f:
f = 1/integrand.forward(x, h)
else:
f = integrand.forward(x, h)

g_param = _flatten(torch.autograd.grad(f, integrand.parameters(), x_tot, create_graph=True, retain_graph=True))
g_h = _flatten(torch.autograd.grad(f, h, x_tot))

Expand All @@ -76,12 +86,13 @@ def computeIntegrand(x, h, integrand, x_tot, nb_steps):
class ParallelNeuralIntegral(torch.autograd.Function):

@staticmethod
def forward(ctx, x0, x, integrand, flat_params, h, nb_steps=20):
def forward(ctx, x0, x, integrand, flat_params, h, nb_steps=20, inv_f=False):
with torch.no_grad():
x_tot = integrate(x0, nb_steps, (x - x0)/nb_steps, integrand, h, False)
x_tot = integrate(x0, nb_steps, (x - x0)/nb_steps, integrand, h, False, inv_f=inv_f)
# Save for backward
ctx.integrand = integrand
ctx.nb_steps = nb_steps
ctx.inv_f = inv_f
ctx.save_for_backward(x0.clone(), x.clone(), h)
return x_tot

Expand All @@ -90,7 +101,8 @@ def backward(ctx, grad_output):
x0, x, h = ctx.saved_tensors
integrand = ctx.integrand
nb_steps = ctx.nb_steps
integrand_grad, h_grad = integrate(x0, nb_steps, x/nb_steps, integrand, h, True, grad_output)
inv_f = ctx.inv_f
integrand_grad, h_grad = integrate(x0, nb_steps, x/nb_steps, integrand, h, True, grad_output, inv_f)
x_grad = integrand(x, h)
x0_grad = integrand(x0, h)
# Leibniz formula
Expand Down

0 comments on commit 9e489b5

Please sign in to comment.