From ddb6f1026e8e724bc6a811ac5fb2c1afd45b85ac Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Wed, 21 Aug 2019 17:55:02 -0700 Subject: [PATCH 1/7] Split osqpth out into Module/Function, add some automatic batch detection. --- osqpth/osqpth.py | 149 ++++++++++++++++++++++++++++++----------------- 1 file changed, 95 insertions(+), 54 deletions(-) diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index d7cf83a..58980c3 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -1,4 +1,5 @@ import torch +from torch.nn import Module from torch.autograd import Function import osqp import numpy as np @@ -6,15 +7,11 @@ import scipy.sparse.linalg as sla from .util import to_numpy - -class OSQP(Function): - def __init__(self, - P_idx, P_shape, - A_idx, A_shape, - eps_rel=1e-05, - eps_abs=1e-05, - verbose=False, +class OSQP(Module): + def __init__(self, P_idx, P_shape, A_idx, A_shape, + eps_rel=1e-5, eps_abs=1e-5, verbose=False, max_iter=10000): + super().__init__() self.eps_abs = eps_abs self.eps_rel = eps_rel self.verbose = verbose @@ -22,9 +19,22 @@ def __init__(self, self.P_idx, self.P_shape = P_idx, P_shape self.A_idx, self.A_shape = A_idx, A_shape - # TODO: Perform OSQP Setup first to allocate memory? - def forward(self, P_val, q_val, A_val, l_val, u_val): + return _OSQP.apply( + P_val, q_val, A_val, l_val, u_val, + self.P_idx, self.P_shape, + self.A_idx, self.A_shape, + self.eps_rel, self.eps_abs, + self.verbose, self.max_iter + ) + + +class _OSQP(Function): + @staticmethod + def forward(ctx, P_val, q_val, A_val, l_val, u_val, + A_idx, A_shape, P_idx, P_shape, + eps_rel=1e-5, eps_abs=1e-5, + verbose=False, max_iter=10000): """Solve a batch of QPs using OSQP. This function solves a batch of QPs, each optimizing over @@ -64,9 +74,26 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): """ + ctx.eps_abs = eps_abs + ctx.eps_rel = eps_rel + ctx.verbose = verbose + ctx.max_iter = max_iter + ctx.P_idx, ctx.P_shape = P_idx, P_shape + ctx.A_idx, ctx.A_shape = A_idx, A_shape + + params = [P_val, q_val, A_val, l_val, u_val] + + for p in params: + assert p.ndimension() <= 2, 'Unexpected number of dimensions' + # Convert batches to sparse matrices/vectors - self.n_batch = P_val.size(0) if len(P_val.size()) > 1 else 1 - self.m, self.n = self.A_shape # Problem size + batch_mode = np.all([t.ndimension() == 1 for t in params]) + if batch_mode: + ctx.n_batch = 1 + else: + batch_sizes = [t.size(0) if t.ndimension() == 2 else 1 for t in params] + ctx.n_batch = max(batch_sizes) + ctx.m, ctx.n = ctx.A_shape # Problem size dtype = P_val.dtype device = P_val.device @@ -75,28 +102,29 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): # TODO (Bart): create CSC matrix during initialization. Then # just reassign the mat.data vector with A_val and P_val - if self.n_batch == 1: - # Create lists to make the code below work - # TODO (Bart): Find a better way to do this - P_val, q_val, A_val, l_val, u_val = [P_val], [q_val], [A_val], [l_val], [u_val] + for i, p in enumerate(params): + if p.ndimension() == 1: + params[i] = p.unsqueeze(0).expand(ctx.n_batch, p.size(0)) + + [P_val, q_val, A_val, l_val, u_val] = params - P = [spa.csc_matrix((to_numpy(P_val[i]), self.P_idx), shape=self.P_shape) - for i in range(self.n_batch)] - q = [to_numpy(q_val[i]) for i in range(self.n_batch)] - A = [spa.csc_matrix((to_numpy(A_val[i]), self.A_idx), shape=self.A_shape) - for i in range(self.n_batch)] - l = [to_numpy(l_val[i]) for i in range(self.n_batch)] - u = [to_numpy(u_val[i]) for i in range(self.n_batch)] + P = [spa.csc_matrix((to_numpy(P_val[i]), ctx.P_idx), shape=ctx.P_shape) + for i in range(ctx.n_batch)] + q = [to_numpy(q_val[i]) for i in range(ctx.n_batch)] + A = [spa.csc_matrix((to_numpy(A_val[i]), ctx.A_idx), shape=ctx.A_shape) + for i in range(ctx.n_batch)] + l = [to_numpy(l_val[i]) for i in range(ctx.n_batch)] + u = [to_numpy(u_val[i]) for i in range(ctx.n_batch)] # Perform forward step solving the QPs - x_torch = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device) + x_torch = torch.zeros((ctx.n_batch, ctx.n), dtype=dtype, device=device) x, y, z = [], [], [] - for i in range(self.n_batch): + for i in range(ctx.n_batch): # Solve QP # TODO: Cache solver object in between m = osqp.OSQP() - m.setup(P[i], q[i], A[i], l[i], u[i], verbose=self.verbose) + m.setup(P[i], q[i], A[i], l[i], u[i], verbose=ctx.verbose) result = m.solve() status = result.info.status if status != 'solved': @@ -112,34 +140,40 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): x_torch[i] = torch.from_numpy(result.x) # Save stuff for backpropagation - self.backward_vars = (P, q, A, l, u, x, y, z) + ctx.backward_vars = (P, q, A, l, u, x, y, z) # Return solutions + if not batch_mode: + x_torch = x_torch.squeeze(0) return x_torch - def backward(self, dl_dx_val): - + @staticmethod + def backward(ctx, dl_dx_val): dtype = dl_dx_val.dtype device = dl_dx_val.device + batch_mode = dl_dx_val.ndimension() == 2 + if not batch_mode: + dl_dx_val = dl_dx_val.unsqueeze(0) + # Convert dl_dx to numpy - dl_dx = to_numpy(dl_dx_val).squeeze() + dl_dx = to_numpy(dl_dx_val) # Extract data from forward pass - P, q, A, l, u, x, y, z = self.backward_vars + P, q, A, l, u, x, y, z = ctx.backward_vars # Convert to torch tensors - nnz_P = len(self.P_idx[0]) - nnz_A = len(self.A_idx[0]) - dP = torch.zeros((self.n_batch, nnz_P), dtype=dtype, device=device) - dq = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device) - dA = torch.zeros((self.n_batch, nnz_A), dtype=dtype, device=device) - dl = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device) - du = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device) + nnz_P = len(ctx.P_idx[0]) + nnz_A = len(ctx.A_idx[0]) + dP = torch.zeros((ctx.n_batch, nnz_P), dtype=dtype, device=device) + dq = torch.zeros((ctx.n_batch, ctx.n), dtype=dtype, device=device) + dA = torch.zeros((ctx.n_batch, nnz_A), dtype=dtype, device=device) + dl = torch.zeros((ctx.n_batch, ctx.m), dtype=dtype, device=device) + du = torch.zeros((ctx.n_batch, ctx.m), dtype=dtype, device=device) # TODO: Improve this, reuse OSQP, port it in C - for i in range(self.n_batch): + for i in range(ctx.n_batch): # Construct linear system # Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717 # Guess which linear constraints are lower-active, upper-active, free @@ -154,33 +188,40 @@ def backward(self, dl_dx_val): # Form KKT linear system KKT = spa.vstack([spa.hstack([P[i], A_red.T]), spa.hstack([A_red, spa.csc_matrix((n_low + n_upp, n_low + n_upp))])]) - rhs = np.hstack([dl_dx.squeeze(), np.zeros(n_low + n_upp)]) + rhs = np.hstack([dl_dx[i], np.zeros(n_low + n_upp)]) # Get solution r_sol = sla.spsolve(KKT, rhs) - r_x = r_sol[:self.n] - r_yl = r_sol[self.n:self.n + n_low] - r_yu = r_sol[self.n + n_low:] - r_y = np.zeros(self.m) + + r_x = r_sol[:ctx.n] + r_yl = r_sol[ctx.n:ctx.n + n_low] + r_yu = r_sol[ctx.n + n_low:] + r_y = np.zeros(ctx.m) r_y[ind_low] = r_yl r_y[ind_upp] = r_yu # Extract derivatives - rows, cols = self.P_idx + rows, cols = ctx.P_idx values = -.5 * (r_x[rows] * x[i][cols] + r_x[cols] * x[i][rows]) dP[i] = torch.from_numpy(values) - rows, cols = self.A_idx + rows, cols = ctx.A_idx values = -(y[i][rows] * r_x[cols] + r_y[rows] * x[i][cols]) dA[i] = torch.from_numpy(values) dq[i] = torch.from_numpy(-r_x) - dl[i] = torch.tensor( - [r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 - for j in range(self.m)]) - du[i] = torch.tensor( - [r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 - for j in range(self.m)]) + t = np.hstack([r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 + for j in range(ctx.m)]) + dl[i] = torch.tensor(t) + t = np.hstack([r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 + for j in range(ctx.m)]) + du[i] = torch.tensor(t) + + grads = [dP, dq, dA, dl, du] + + if not batch_mode: + for i, g in enumerate(grads): + grads[i] = g.squeeze() - grads = (dP, dq, dA, dl, du) + grads += [None]*8 - return grads + return tuple(grads) From 5a340e5bd19a164fd3b9d6d0621c2a20371693d3 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Wed, 21 Aug 2019 18:05:24 -0700 Subject: [PATCH 2/7] Stub in differentiation modes --- osqpth/osqpth.py | 55 ++++++++++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index 58980c3..e4520de 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -5,12 +5,17 @@ import numpy as np import scipy.sparse as spa import scipy.sparse.linalg as sla + +from enum import IntEnum + from .util import to_numpy +DiffModes = IntEnum('DiffModes', 'ACTIVE FULL') + class OSQP(Module): def __init__(self, P_idx, P_shape, A_idx, A_shape, eps_rel=1e-5, eps_abs=1e-5, verbose=False, - max_iter=10000): + max_iter=10000, diff_mode=DiffModes.ACTIVE): super().__init__() self.eps_abs = eps_abs self.eps_rel = eps_rel @@ -18,6 +23,7 @@ def __init__(self, P_idx, P_shape, A_idx, A_shape, self.max_iter = max_iter self.P_idx, self.P_shape = P_idx, P_shape self.A_idx, self.A_shape = A_idx, A_shape + self.diff_mode = diff_mode def forward(self, P_val, q_val, A_val, l_val, u_val): return _OSQP.apply( @@ -25,7 +31,8 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): self.P_idx, self.P_shape, self.A_idx, self.A_shape, self.eps_rel, self.eps_abs, - self.verbose, self.max_iter + self.verbose, self.max_iter, + self.diff_mode, ) @@ -33,8 +40,7 @@ class _OSQP(Function): @staticmethod def forward(ctx, P_val, q_val, A_val, l_val, u_val, A_idx, A_shape, P_idx, P_shape, - eps_rel=1e-5, eps_abs=1e-5, - verbose=False, max_iter=10000): + eps_rel, eps_abs, verbose, max_iter, diff_mode): """Solve a batch of QPs using OSQP. This function solves a batch of QPs, each optimizing over @@ -80,6 +86,7 @@ def forward(ctx, P_val, q_val, A_val, l_val, u_val, ctx.max_iter = max_iter ctx.P_idx, ctx.P_shape = P_idx, P_shape ctx.A_idx, ctx.A_shape = A_idx, A_shape + ctx.diff_mode = diff_mode params = [P_val, q_val, A_val, l_val, u_val] @@ -175,23 +182,29 @@ def backward(ctx, dl_dx_val): for i in range(ctx.n_batch): # Construct linear system - # Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717 - # Guess which linear constraints are lower-active, upper-active, free - ind_low = np.where(z[i] - l[i] < - y[i])[0] - ind_upp = np.where(u[i] - z[i] < y[i])[0] - n_low = len(ind_low) - n_upp = len(ind_upp) - - # Form A_red from the assumed active constraints - A_red = spa.vstack([A[i][ind_low], A[i][ind_upp]]) - - # Form KKT linear system - KKT = spa.vstack([spa.hstack([P[i], A_red.T]), - spa.hstack([A_red, spa.csc_matrix((n_low + n_upp, n_low + n_upp))])]) - rhs = np.hstack([dl_dx[i], np.zeros(n_low + n_upp)]) - # Get solution - r_sol = sla.spsolve(KKT, rhs) + if ctx.diff_mode == DiffModes.ACTIVE: + # Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717 + # Guess which linear constraints are lower-active, upper-active, free + ind_low = np.where(z[i] - l[i] < - y[i])[0] + ind_upp = np.where(u[i] - z[i] < y[i])[0] + n_low = len(ind_low) + n_upp = len(ind_upp) + + # Form A_red from the assumed active constraints + A_red = spa.vstack([A[i][ind_low], A[i][ind_upp]]) + + # Form KKT linear system + KKT = spa.vstack([spa.hstack([P[i], A_red.T]), + spa.hstack([A_red, spa.csc_matrix((n_low + n_upp, n_low + n_upp))])]) + rhs = np.hstack([dl_dx[i], np.zeros(n_low + n_upp)]) + + # Get solution + r_sol = sla.spsolve(KKT, rhs) + elif ctx.diff_mode == DiffModes.FULL: + raise NotImplementedError + else: + raise RuntimeError(f"Unrecognized differentiation mode") r_x = r_sol[:ctx.n] r_yl = r_sol[ctx.n:ctx.n + n_low] @@ -222,6 +235,6 @@ def backward(ctx, dl_dx_val): for i, g in enumerate(grads): grads[i] = g.squeeze() - grads += [None]*8 + grads += [None]*9 return tuple(grads) From 55f967ee3ff71d1fd79607c5e95bc3aa254b1785 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Wed, 21 Aug 2019 18:37:43 -0700 Subject: [PATCH 3/7] Add in full KKT system diff mode. --- osqpth/osqpth.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index e4520de..fc81e20 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -39,7 +39,7 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): class _OSQP(Function): @staticmethod def forward(ctx, P_val, q_val, A_val, l_val, u_val, - A_idx, A_shape, P_idx, P_shape, + P_idx, P_shape, A_idx, A_shape, eps_rel, eps_abs, verbose, max_iter, diff_mode): """Solve a batch of QPs using OSQP. @@ -201,17 +201,37 @@ def backward(ctx, dl_dx_val): # Get solution r_sol = sla.spsolve(KKT, rhs) + r_x = r_sol[:ctx.n] + r_yl = r_sol[ctx.n:ctx.n + n_low] + r_yu = r_sol[ctx.n + n_low:] + r_y = np.zeros(ctx.m) + r_y[ind_low] = r_yl + r_y[ind_upp] = r_yu + + t = np.hstack([r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 + for j in range(ctx.m)]) + dl[i] = torch.tensor(t) + t = np.hstack([r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 + for j in range(ctx.m)]) + du[i] = torch.tensor(t) elif ctx.diff_mode == DiffModes.FULL: - raise NotImplementedError + # TODO: Add in kkt_eps as an option? + kkt_eps = 1e-6 + KKT = spa.vstack([spa.hstack([P[i], A[i].T.dot(spa.diags(y[i]))]), + spa.hstack([A[i], -kkt_eps*spa.eye(ctx.m)])]) + rhs = np.hstack([dl_dx[i], np.zeros(ctx.m)]) + + # Get solution + r_sol = sla.spsolve(KKT, rhs) + r_x = r_sol[:ctx.n] + r_y = r_sol[ctx.n:] * y[i] + + J = y[i] < 0. + dl[i][J] = torch.from_numpy(r_y[J]) + du[i][~J] = torch.from_numpy(r_y[~J]) else: raise RuntimeError(f"Unrecognized differentiation mode") - r_x = r_sol[:ctx.n] - r_yl = r_sol[ctx.n:ctx.n + n_low] - r_yu = r_sol[ctx.n + n_low:] - r_y = np.zeros(ctx.m) - r_y[ind_low] = r_yl - r_y[ind_upp] = r_yu # Extract derivatives rows, cols = ctx.P_idx @@ -222,12 +242,6 @@ def backward(ctx, dl_dx_val): values = -(y[i][rows] * r_x[cols] + r_y[rows] * x[i][cols]) dA[i] = torch.from_numpy(values) dq[i] = torch.from_numpy(-r_x) - t = np.hstack([r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 - for j in range(ctx.m)]) - dl[i] = torch.tensor(t) - t = np.hstack([r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 - for j in range(ctx.m)]) - du[i] = torch.tensor(t) grads = [dP, dq, dA, dl, du] From 99cd961c0c924831d478b03c514f8f5072441785 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Wed, 21 Aug 2019 18:39:15 -0700 Subject: [PATCH 4/7] Fix device issue --- osqpth/osqpth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index fc81e20..7e47566 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -227,8 +227,8 @@ def backward(ctx, dl_dx_val): r_y = r_sol[ctx.n:] * y[i] J = y[i] < 0. - dl[i][J] = torch.from_numpy(r_y[J]) - du[i][~J] = torch.from_numpy(r_y[~J]) + dl[i][J] = torch.from_numpy(r_y[J]).to(dl.device) + du[i][~J] = torch.from_numpy(r_y[~J]).to(du.device) else: raise RuntimeError(f"Unrecognized differentiation mode") From 85ed03873233ffdc6eeb1201e05599484b997d5c Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Wed, 21 Aug 2019 18:46:08 -0700 Subject: [PATCH 5/7] Add FULL diff_mode to tests. --- tests/test_osqpth.py | 45 +++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/test_osqpth.py b/tests/test_osqpth.py index 4b85aa3..ef56f89 100755 --- a/tests/test_osqpth.py +++ b/tests/test_osqpth.py @@ -3,7 +3,7 @@ import pytest import osqp -from osqpth.osqpth import OSQP +from osqpth.osqpth import OSQP, DiffModes import numpy.random as npr import numpy as np import torch @@ -19,7 +19,8 @@ def get_grads(n_batch=1, n=10, m=3, P_scale=1., - A_scale=1., u_scale=1., l_scale=1.): + A_scale=1., u_scale=1., l_scale=1., + diff_mode=DiffModes.FULL): assert(n_batch == 1) npr.seed(1) L = np.random.randn(n, n) @@ -35,11 +36,11 @@ def get_grads(n_batch=1, n=10, m=3, P_scale=1., P, q, A, l, u, true_x = [x.astype(np.float64) for x in [P, q, A, l, u, true_x]] - grads = get_grads_torch(P, q, A, l, u, true_x) + grads = get_grads_torch(P, q, A, l, u, true_x, diff_mode) return [P, q, A, l, u, true_x], grads -def get_grads_torch(P, q, A, l, u, true_x): +def get_grads_torch(P, q, A, l, u, true_x, diff_mode): P_idx = P.nonzero() P_shape = P.shape @@ -56,7 +57,7 @@ def get_grads_torch(P, q, A, l, u, true_x): for x in [P_torch, q_torch, A_torch, l_torch, u_torch]: x.requires_grad = True - x_hats = OSQP(P_idx, P_shape, A_idx, A_shape)(P_torch, q_torch, A_torch, l_torch, u_torch) + x_hats = OSQP(P_idx, P_shape, A_idx, A_shape, diff_mode=diff_mode)(P_torch, q_torch, A_torch, l_torch, u_torch) dl_dxhat = x_hats.data - true_x_torch x_hats.backward(dl_dxhat) @@ -67,19 +68,21 @@ def get_grads_torch(P, q, A, l, u, true_x): def test_dl_dp(): n, m = 5, 5 - [P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads( - n=n, m=m, P_scale=100., A_scale=100.) - - def f(q): - m = osqp.OSQP() - m.setup(P, q, A, l, u, verbose=verbose) - res = m.solve() - x_hat = res.x - - return 0.5 * np.sum(np.square(x_hat - true_x)) - - dq_fd = nd.Gradient(f)(q) - if verbose: - print('dq_fd: ', np.round(dq_fd, decimals=4)) - print('dq: ', np.round(dq, decimals=4)) - npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL) + for diff_mode in DiffModes: + [P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads( + n=n, m=m, P_scale=100., A_scale=100., diff_mode=diff_mode) + print(f'--- {diff_mode.name}') + + def f(q): + m = osqp.OSQP() + m.setup(P, q, A, l, u, verbose=False) + res = m.solve() + x_hat = res.x + + return 0.5 * np.sum(np.square(x_hat - true_x)) + + dq_fd = nd.Gradient(f)(q) + if verbose: + print('dq_fd: ', np.round(dq_fd, decimals=4)) + print('dq: ', np.round(dq, decimals=4)) + npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL) From a4ea32ffa500e6086974538db367d778f7248e41 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Wed, 21 Aug 2019 19:04:21 -0700 Subject: [PATCH 6/7] Use lsqr for the solves --- osqpth/osqpth.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index 7e47566..7bff67f 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -200,7 +200,9 @@ def backward(ctx, dl_dx_val): rhs = np.hstack([dl_dx[i], np.zeros(n_low + n_upp)]) # Get solution - r_sol = sla.spsolve(KKT, rhs) + # r_sol = sla.spsolve(KKT, rhs) + r_sol = sla.lsqr(KKT, rhs)[0] + r_x = r_sol[:ctx.n] r_yl = r_sol[ctx.n:ctx.n + n_low] r_yu = r_sol[ctx.n + n_low:] @@ -215,14 +217,14 @@ def backward(ctx, dl_dx_val): for j in range(ctx.m)]) du[i] = torch.tensor(t) elif ctx.diff_mode == DiffModes.FULL: - # TODO: Add in kkt_eps as an option? - kkt_eps = 1e-6 KKT = spa.vstack([spa.hstack([P[i], A[i].T.dot(spa.diags(y[i]))]), - spa.hstack([A[i], -kkt_eps*spa.eye(ctx.m)])]) + spa.hstack([A[i], spa.csc_matrix((ctx.m, ctx.m))])]) rhs = np.hstack([dl_dx[i], np.zeros(ctx.m)]) # Get solution - r_sol = sla.spsolve(KKT, rhs) + # r_sol = sla.spsolve(KKT, rhs) + r_sol = sla.lsqr(KKT, rhs)[0] + r_x = r_sol[:ctx.n] r_y = r_sol[ctx.n:] * y[i] From 47023069b459e266934376af0bc0d013ab68af2e Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Thu, 22 Aug 2019 07:57:38 -0700 Subject: [PATCH 7/7] Full bw pass: Add KKT_22 --- osqpth/osqpth.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index 7bff67f..72042b9 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -217,18 +217,24 @@ def backward(ctx, dl_dx_val): for j in range(ctx.m)]) du[i] = torch.tensor(t) elif ctx.diff_mode == DiffModes.FULL: - KKT = spa.vstack([spa.hstack([P[i], A[i].T.dot(spa.diags(y[i]))]), - spa.hstack([A[i], spa.csc_matrix((ctx.m, ctx.m))])]) + KKT_22 = np.zeros(ctx.m) + J = y[i] < 0. + KKT_22[J] = (A[i].dot(x[i]) - l[i])[J] + KKT_22[~J] = (A[i].dot(x[i]) - u[i])[~J] + + # TODO: Better handle when the bounds are \pm\infty. + KKT_22[np.isinf(KKT_22)] = 0. + + KKT_T = spa.vstack([spa.hstack([P[i], A[i].T.dot(spa.diags(y[i]))]), + spa.hstack([A[i], spa.diags(KKT_22)])]) rhs = np.hstack([dl_dx[i], np.zeros(ctx.m)]) # Get solution - # r_sol = sla.spsolve(KKT, rhs) - r_sol = sla.lsqr(KKT, rhs)[0] + r_sol = sla.lsqr(KKT_T, rhs)[0] r_x = r_sol[:ctx.n] r_y = r_sol[ctx.n:] * y[i] - J = y[i] < 0. dl[i][J] = torch.from_numpy(r_y[J]).to(dl.device) du[i][~J] = torch.from_numpy(r_y[~J]).to(du.device) else: