diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index d7cf83a..72042b9 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -1,30 +1,46 @@ import torch +from torch.nn import Module from torch.autograd import Function import osqp import numpy as np import scipy.sparse as spa import scipy.sparse.linalg as sla + +from enum import IntEnum + from .util import to_numpy +DiffModes = IntEnum('DiffModes', 'ACTIVE FULL') -class OSQP(Function): - def __init__(self, - P_idx, P_shape, - A_idx, A_shape, - eps_rel=1e-05, - eps_abs=1e-05, - verbose=False, - max_iter=10000): +class OSQP(Module): + def __init__(self, P_idx, P_shape, A_idx, A_shape, + eps_rel=1e-5, eps_abs=1e-5, verbose=False, + max_iter=10000, diff_mode=DiffModes.ACTIVE): + super().__init__() self.eps_abs = eps_abs self.eps_rel = eps_rel self.verbose = verbose self.max_iter = max_iter self.P_idx, self.P_shape = P_idx, P_shape self.A_idx, self.A_shape = A_idx, A_shape - - # TODO: Perform OSQP Setup first to allocate memory? + self.diff_mode = diff_mode def forward(self, P_val, q_val, A_val, l_val, u_val): + return _OSQP.apply( + P_val, q_val, A_val, l_val, u_val, + self.P_idx, self.P_shape, + self.A_idx, self.A_shape, + self.eps_rel, self.eps_abs, + self.verbose, self.max_iter, + self.diff_mode, + ) + + +class _OSQP(Function): + @staticmethod + def forward(ctx, P_val, q_val, A_val, l_val, u_val, + P_idx, P_shape, A_idx, A_shape, + eps_rel, eps_abs, verbose, max_iter, diff_mode): """Solve a batch of QPs using OSQP. This function solves a batch of QPs, each optimizing over @@ -64,9 +80,27 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): """ + ctx.eps_abs = eps_abs + ctx.eps_rel = eps_rel + ctx.verbose = verbose + ctx.max_iter = max_iter + ctx.P_idx, ctx.P_shape = P_idx, P_shape + ctx.A_idx, ctx.A_shape = A_idx, A_shape + ctx.diff_mode = diff_mode + + params = [P_val, q_val, A_val, l_val, u_val] + + for p in params: + assert p.ndimension() <= 2, 'Unexpected number of dimensions' + # Convert batches to sparse matrices/vectors - self.n_batch = P_val.size(0) if len(P_val.size()) > 1 else 1 - self.m, self.n = self.A_shape # Problem size + batch_mode = np.all([t.ndimension() == 1 for t in params]) + if batch_mode: + ctx.n_batch = 1 + else: + batch_sizes = [t.size(0) if t.ndimension() == 2 else 1 for t in params] + ctx.n_batch = max(batch_sizes) + ctx.m, ctx.n = ctx.A_shape # Problem size dtype = P_val.dtype device = P_val.device @@ -75,28 +109,29 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): # TODO (Bart): create CSC matrix during initialization. Then # just reassign the mat.data vector with A_val and P_val - if self.n_batch == 1: - # Create lists to make the code below work - # TODO (Bart): Find a better way to do this - P_val, q_val, A_val, l_val, u_val = [P_val], [q_val], [A_val], [l_val], [u_val] + for i, p in enumerate(params): + if p.ndimension() == 1: + params[i] = p.unsqueeze(0).expand(ctx.n_batch, p.size(0)) + + [P_val, q_val, A_val, l_val, u_val] = params - P = [spa.csc_matrix((to_numpy(P_val[i]), self.P_idx), shape=self.P_shape) - for i in range(self.n_batch)] - q = [to_numpy(q_val[i]) for i in range(self.n_batch)] - A = [spa.csc_matrix((to_numpy(A_val[i]), self.A_idx), shape=self.A_shape) - for i in range(self.n_batch)] - l = [to_numpy(l_val[i]) for i in range(self.n_batch)] - u = [to_numpy(u_val[i]) for i in range(self.n_batch)] + P = [spa.csc_matrix((to_numpy(P_val[i]), ctx.P_idx), shape=ctx.P_shape) + for i in range(ctx.n_batch)] + q = [to_numpy(q_val[i]) for i in range(ctx.n_batch)] + A = [spa.csc_matrix((to_numpy(A_val[i]), ctx.A_idx), shape=ctx.A_shape) + for i in range(ctx.n_batch)] + l = [to_numpy(l_val[i]) for i in range(ctx.n_batch)] + u = [to_numpy(u_val[i]) for i in range(ctx.n_batch)] # Perform forward step solving the QPs - x_torch = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device) + x_torch = torch.zeros((ctx.n_batch, ctx.n), dtype=dtype, device=device) x, y, z = [], [], [] - for i in range(self.n_batch): + for i in range(ctx.n_batch): # Solve QP # TODO: Cache solver object in between m = osqp.OSQP() - m.setup(P[i], q[i], A[i], l[i], u[i], verbose=self.verbose) + m.setup(P[i], q[i], A[i], l[i], u[i], verbose=ctx.verbose) result = m.solve() status = result.info.status if status != 'solved': @@ -112,75 +147,116 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): x_torch[i] = torch.from_numpy(result.x) # Save stuff for backpropagation - self.backward_vars = (P, q, A, l, u, x, y, z) + ctx.backward_vars = (P, q, A, l, u, x, y, z) # Return solutions + if not batch_mode: + x_torch = x_torch.squeeze(0) return x_torch - def backward(self, dl_dx_val): - + @staticmethod + def backward(ctx, dl_dx_val): dtype = dl_dx_val.dtype device = dl_dx_val.device + batch_mode = dl_dx_val.ndimension() == 2 + if not batch_mode: + dl_dx_val = dl_dx_val.unsqueeze(0) + # Convert dl_dx to numpy - dl_dx = to_numpy(dl_dx_val).squeeze() + dl_dx = to_numpy(dl_dx_val) # Extract data from forward pass - P, q, A, l, u, x, y, z = self.backward_vars + P, q, A, l, u, x, y, z = ctx.backward_vars # Convert to torch tensors - nnz_P = len(self.P_idx[0]) - nnz_A = len(self.A_idx[0]) - dP = torch.zeros((self.n_batch, nnz_P), dtype=dtype, device=device) - dq = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device) - dA = torch.zeros((self.n_batch, nnz_A), dtype=dtype, device=device) - dl = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device) - du = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device) + nnz_P = len(ctx.P_idx[0]) + nnz_A = len(ctx.A_idx[0]) + dP = torch.zeros((ctx.n_batch, nnz_P), dtype=dtype, device=device) + dq = torch.zeros((ctx.n_batch, ctx.n), dtype=dtype, device=device) + dA = torch.zeros((ctx.n_batch, nnz_A), dtype=dtype, device=device) + dl = torch.zeros((ctx.n_batch, ctx.m), dtype=dtype, device=device) + du = torch.zeros((ctx.n_batch, ctx.m), dtype=dtype, device=device) # TODO: Improve this, reuse OSQP, port it in C - for i in range(self.n_batch): + for i in range(ctx.n_batch): # Construct linear system - # Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717 - # Guess which linear constraints are lower-active, upper-active, free - ind_low = np.where(z[i] - l[i] < - y[i])[0] - ind_upp = np.where(u[i] - z[i] < y[i])[0] - n_low = len(ind_low) - n_upp = len(ind_upp) - - # Form A_red from the assumed active constraints - A_red = spa.vstack([A[i][ind_low], A[i][ind_upp]]) - - # Form KKT linear system - KKT = spa.vstack([spa.hstack([P[i], A_red.T]), - spa.hstack([A_red, spa.csc_matrix((n_low + n_upp, n_low + n_upp))])]) - rhs = np.hstack([dl_dx.squeeze(), np.zeros(n_low + n_upp)]) - - # Get solution - r_sol = sla.spsolve(KKT, rhs) - r_x = r_sol[:self.n] - r_yl = r_sol[self.n:self.n + n_low] - r_yu = r_sol[self.n + n_low:] - r_y = np.zeros(self.m) - r_y[ind_low] = r_yl - r_y[ind_upp] = r_yu + + if ctx.diff_mode == DiffModes.ACTIVE: + # Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717 + # Guess which linear constraints are lower-active, upper-active, free + ind_low = np.where(z[i] - l[i] < - y[i])[0] + ind_upp = np.where(u[i] - z[i] < y[i])[0] + n_low = len(ind_low) + n_upp = len(ind_upp) + + # Form A_red from the assumed active constraints + A_red = spa.vstack([A[i][ind_low], A[i][ind_upp]]) + + # Form KKT linear system + KKT = spa.vstack([spa.hstack([P[i], A_red.T]), + spa.hstack([A_red, spa.csc_matrix((n_low + n_upp, n_low + n_upp))])]) + rhs = np.hstack([dl_dx[i], np.zeros(n_low + n_upp)]) + + # Get solution + # r_sol = sla.spsolve(KKT, rhs) + r_sol = sla.lsqr(KKT, rhs)[0] + + r_x = r_sol[:ctx.n] + r_yl = r_sol[ctx.n:ctx.n + n_low] + r_yu = r_sol[ctx.n + n_low:] + r_y = np.zeros(ctx.m) + r_y[ind_low] = r_yl + r_y[ind_upp] = r_yu + + t = np.hstack([r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 + for j in range(ctx.m)]) + dl[i] = torch.tensor(t) + t = np.hstack([r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 + for j in range(ctx.m)]) + du[i] = torch.tensor(t) + elif ctx.diff_mode == DiffModes.FULL: + KKT_22 = np.zeros(ctx.m) + J = y[i] < 0. + KKT_22[J] = (A[i].dot(x[i]) - l[i])[J] + KKT_22[~J] = (A[i].dot(x[i]) - u[i])[~J] + + # TODO: Better handle when the bounds are \pm\infty. + KKT_22[np.isinf(KKT_22)] = 0. + + KKT_T = spa.vstack([spa.hstack([P[i], A[i].T.dot(spa.diags(y[i]))]), + spa.hstack([A[i], spa.diags(KKT_22)])]) + rhs = np.hstack([dl_dx[i], np.zeros(ctx.m)]) + + # Get solution + r_sol = sla.lsqr(KKT_T, rhs)[0] + + r_x = r_sol[:ctx.n] + r_y = r_sol[ctx.n:] * y[i] + + dl[i][J] = torch.from_numpy(r_y[J]).to(dl.device) + du[i][~J] = torch.from_numpy(r_y[~J]).to(du.device) + else: + raise RuntimeError(f"Unrecognized differentiation mode") + # Extract derivatives - rows, cols = self.P_idx + rows, cols = ctx.P_idx values = -.5 * (r_x[rows] * x[i][cols] + r_x[cols] * x[i][rows]) dP[i] = torch.from_numpy(values) - rows, cols = self.A_idx + rows, cols = ctx.A_idx values = -(y[i][rows] * r_x[cols] + r_y[rows] * x[i][cols]) dA[i] = torch.from_numpy(values) dq[i] = torch.from_numpy(-r_x) - dl[i] = torch.tensor( - [r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 - for j in range(self.m)]) - du[i] = torch.tensor( - [r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 - for j in range(self.m)]) - grads = (dP, dq, dA, dl, du) + grads = [dP, dq, dA, dl, du] + + if not batch_mode: + for i, g in enumerate(grads): + grads[i] = g.squeeze() + + grads += [None]*9 - return grads + return tuple(grads) diff --git a/tests/test_osqpth.py b/tests/test_osqpth.py index 4b85aa3..ef56f89 100755 --- a/tests/test_osqpth.py +++ b/tests/test_osqpth.py @@ -3,7 +3,7 @@ import pytest import osqp -from osqpth.osqpth import OSQP +from osqpth.osqpth import OSQP, DiffModes import numpy.random as npr import numpy as np import torch @@ -19,7 +19,8 @@ def get_grads(n_batch=1, n=10, m=3, P_scale=1., - A_scale=1., u_scale=1., l_scale=1.): + A_scale=1., u_scale=1., l_scale=1., + diff_mode=DiffModes.FULL): assert(n_batch == 1) npr.seed(1) L = np.random.randn(n, n) @@ -35,11 +36,11 @@ def get_grads(n_batch=1, n=10, m=3, P_scale=1., P, q, A, l, u, true_x = [x.astype(np.float64) for x in [P, q, A, l, u, true_x]] - grads = get_grads_torch(P, q, A, l, u, true_x) + grads = get_grads_torch(P, q, A, l, u, true_x, diff_mode) return [P, q, A, l, u, true_x], grads -def get_grads_torch(P, q, A, l, u, true_x): +def get_grads_torch(P, q, A, l, u, true_x, diff_mode): P_idx = P.nonzero() P_shape = P.shape @@ -56,7 +57,7 @@ def get_grads_torch(P, q, A, l, u, true_x): for x in [P_torch, q_torch, A_torch, l_torch, u_torch]: x.requires_grad = True - x_hats = OSQP(P_idx, P_shape, A_idx, A_shape)(P_torch, q_torch, A_torch, l_torch, u_torch) + x_hats = OSQP(P_idx, P_shape, A_idx, A_shape, diff_mode=diff_mode)(P_torch, q_torch, A_torch, l_torch, u_torch) dl_dxhat = x_hats.data - true_x_torch x_hats.backward(dl_dxhat) @@ -67,19 +68,21 @@ def get_grads_torch(P, q, A, l, u, true_x): def test_dl_dp(): n, m = 5, 5 - [P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads( - n=n, m=m, P_scale=100., A_scale=100.) - - def f(q): - m = osqp.OSQP() - m.setup(P, q, A, l, u, verbose=verbose) - res = m.solve() - x_hat = res.x - - return 0.5 * np.sum(np.square(x_hat - true_x)) - - dq_fd = nd.Gradient(f)(q) - if verbose: - print('dq_fd: ', np.round(dq_fd, decimals=4)) - print('dq: ', np.round(dq, decimals=4)) - npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL) + for diff_mode in DiffModes: + [P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads( + n=n, m=m, P_scale=100., A_scale=100., diff_mode=diff_mode) + print(f'--- {diff_mode.name}') + + def f(q): + m = osqp.OSQP() + m.setup(P, q, A, l, u, verbose=False) + res = m.solve() + x_hat = res.x + + return 0.5 * np.sum(np.square(x_hat - true_x)) + + dq_fd = nd.Gradient(f)(q) + if verbose: + print('dq_fd: ', np.round(dq_fd, decimals=4)) + print('dq: ', np.round(dq, decimals=4)) + npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL)