Skip to content

Commit

Permalink
Merge pull request #5 from oxfordcontrol/2019.08.21.bda
Browse files: browse the repository at this point in the history
Add full KKT differentiation mode, split into new-style PyTorch Module/Function, automatically infer some batch sizes
  • Loading branch information
bstellato authored Aug 31, 2019
2 parents c001ec7 + 4702306 commit 007680a
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 93 deletions.
220 changes: 148 additions & 72 deletions osqpth/osqpth.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,46 @@
import torch
from torch.nn import Module
from torch.autograd import Function
import osqp
import numpy as np
import scipy.sparse as spa
import scipy.sparse.linalg as sla

from enum import IntEnum

from .util import to_numpy

DiffModes = IntEnum('DiffModes', 'ACTIVE FULL')

class OSQP(Function):
def __init__(self,
P_idx, P_shape,
A_idx, A_shape,
eps_rel=1e-05,
eps_abs=1e-05,
verbose=False,
max_iter=10000):
class OSQP(Module):
def __init__(self, P_idx, P_shape, A_idx, A_shape,
eps_rel=1e-5, eps_abs=1e-5, verbose=False,
max_iter=10000, diff_mode=DiffModes.ACTIVE):
super().__init__()
self.eps_abs = eps_abs
self.eps_rel = eps_rel
self.verbose = verbose
self.max_iter = max_iter
self.P_idx, self.P_shape = P_idx, P_shape
self.A_idx, self.A_shape = A_idx, A_shape

# TODO: Perform OSQP Setup first to allocate memory?
self.diff_mode = diff_mode

def forward(self, P_val, q_val, A_val, l_val, u_val):
return _OSQP.apply(
P_val, q_val, A_val, l_val, u_val,
self.P_idx, self.P_shape,
self.A_idx, self.A_shape,
self.eps_rel, self.eps_abs,
self.verbose, self.max_iter,
self.diff_mode,
)


class _OSQP(Function):
@staticmethod
def forward(ctx, P_val, q_val, A_val, l_val, u_val,
P_idx, P_shape, A_idx, A_shape,
eps_rel, eps_abs, verbose, max_iter, diff_mode):
"""Solve a batch of QPs using OSQP.
This function solves a batch of QPs, each optimizing over
Expand Down Expand Up @@ -64,9 +80,27 @@ def forward(self, P_val, q_val, A_val, l_val, u_val):
"""

ctx.eps_abs = eps_abs
ctx.eps_rel = eps_rel
ctx.verbose = verbose
ctx.max_iter = max_iter
ctx.P_idx, ctx.P_shape = P_idx, P_shape
ctx.A_idx, ctx.A_shape = A_idx, A_shape
ctx.diff_mode = diff_mode

params = [P_val, q_val, A_val, l_val, u_val]

for p in params:
assert p.ndimension() <= 2, 'Unexpected number of dimensions'

# Convert batches to sparse matrices/vectors
self.n_batch = P_val.size(0) if len(P_val.size()) > 1 else 1
self.m, self.n = self.A_shape # Problem size
batch_mode = np.all([t.ndimension() == 1 for t in params])
if batch_mode:
ctx.n_batch = 1
else:
batch_sizes = [t.size(0) if t.ndimension() == 2 else 1 for t in params]
ctx.n_batch = max(batch_sizes)
ctx.m, ctx.n = ctx.A_shape # Problem size

dtype = P_val.dtype
device = P_val.device
Expand All @@ -75,28 +109,29 @@ def forward(self, P_val, q_val, A_val, l_val, u_val):
# TODO (Bart): create CSC matrix during initialization. Then
# just reassign the mat.data vector with A_val and P_val

if self.n_batch == 1:
# Create lists to make the code below work
# TODO (Bart): Find a better way to do this
P_val, q_val, A_val, l_val, u_val = [P_val], [q_val], [A_val], [l_val], [u_val]
for i, p in enumerate(params):
if p.ndimension() == 1:
params[i] = p.unsqueeze(0).expand(ctx.n_batch, p.size(0))

[P_val, q_val, A_val, l_val, u_val] = params

P = [spa.csc_matrix((to_numpy(P_val[i]), self.P_idx), shape=self.P_shape)
for i in range(self.n_batch)]
q = [to_numpy(q_val[i]) for i in range(self.n_batch)]
A = [spa.csc_matrix((to_numpy(A_val[i]), self.A_idx), shape=self.A_shape)
for i in range(self.n_batch)]
l = [to_numpy(l_val[i]) for i in range(self.n_batch)]
u = [to_numpy(u_val[i]) for i in range(self.n_batch)]
P = [spa.csc_matrix((to_numpy(P_val[i]), ctx.P_idx), shape=ctx.P_shape)
for i in range(ctx.n_batch)]
q = [to_numpy(q_val[i]) for i in range(ctx.n_batch)]
A = [spa.csc_matrix((to_numpy(A_val[i]), ctx.A_idx), shape=ctx.A_shape)
for i in range(ctx.n_batch)]
l = [to_numpy(l_val[i]) for i in range(ctx.n_batch)]
u = [to_numpy(u_val[i]) for i in range(ctx.n_batch)]

# Perform forward step solving the QPs
x_torch = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device)
x_torch = torch.zeros((ctx.n_batch, ctx.n), dtype=dtype, device=device)

x, y, z = [], [], []
for i in range(self.n_batch):
for i in range(ctx.n_batch):
# Solve QP
# TODO: Cache solver object in between
m = osqp.OSQP()
m.setup(P[i], q[i], A[i], l[i], u[i], verbose=self.verbose)
m.setup(P[i], q[i], A[i], l[i], u[i], verbose=ctx.verbose)
result = m.solve()
status = result.info.status
if status != 'solved':
Expand All @@ -112,75 +147,116 @@ def forward(self, P_val, q_val, A_val, l_val, u_val):
x_torch[i] = torch.from_numpy(result.x)

# Save stuff for backpropagation
self.backward_vars = (P, q, A, l, u, x, y, z)
ctx.backward_vars = (P, q, A, l, u, x, y, z)

# Return solutions
if not batch_mode:
x_torch = x_torch.squeeze(0)
return x_torch

def backward(self, dl_dx_val):

@staticmethod
def backward(ctx, dl_dx_val):
dtype = dl_dx_val.dtype
device = dl_dx_val.device

batch_mode = dl_dx_val.ndimension() == 2
if not batch_mode:
dl_dx_val = dl_dx_val.unsqueeze(0)

# Convert dl_dx to numpy
dl_dx = to_numpy(dl_dx_val).squeeze()
dl_dx = to_numpy(dl_dx_val)

# Extract data from forward pass
P, q, A, l, u, x, y, z = self.backward_vars
P, q, A, l, u, x, y, z = ctx.backward_vars

# Convert to torch tensors
nnz_P = len(self.P_idx[0])
nnz_A = len(self.A_idx[0])
dP = torch.zeros((self.n_batch, nnz_P), dtype=dtype, device=device)
dq = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device)
dA = torch.zeros((self.n_batch, nnz_A), dtype=dtype, device=device)
dl = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device)
du = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device)
nnz_P = len(ctx.P_idx[0])
nnz_A = len(ctx.A_idx[0])
dP = torch.zeros((ctx.n_batch, nnz_P), dtype=dtype, device=device)
dq = torch.zeros((ctx.n_batch, ctx.n), dtype=dtype, device=device)
dA = torch.zeros((ctx.n_batch, nnz_A), dtype=dtype, device=device)
dl = torch.zeros((ctx.n_batch, ctx.m), dtype=dtype, device=device)
du = torch.zeros((ctx.n_batch, ctx.m), dtype=dtype, device=device)

# TODO: Improve this, reuse OSQP, port it in C

for i in range(self.n_batch):
for i in range(ctx.n_batch):
# Construct linear system
# Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717
# Guess which linear constraints are lower-active, upper-active, free
ind_low = np.where(z[i] - l[i] < - y[i])[0]
ind_upp = np.where(u[i] - z[i] < y[i])[0]
n_low = len(ind_low)
n_upp = len(ind_upp)

# Form A_red from the assumed active constraints
A_red = spa.vstack([A[i][ind_low], A[i][ind_upp]])

# Form KKT linear system
KKT = spa.vstack([spa.hstack([P[i], A_red.T]),
spa.hstack([A_red, spa.csc_matrix((n_low + n_upp, n_low + n_upp))])])
rhs = np.hstack([dl_dx.squeeze(), np.zeros(n_low + n_upp)])

# Get solution
r_sol = sla.spsolve(KKT, rhs)
r_x = r_sol[:self.n]
r_yl = r_sol[self.n:self.n + n_low]
r_yu = r_sol[self.n + n_low:]
r_y = np.zeros(self.m)
r_y[ind_low] = r_yl
r_y[ind_upp] = r_yu

if ctx.diff_mode == DiffModes.ACTIVE:
# Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717
# Guess which linear constraints are lower-active, upper-active, free
ind_low = np.where(z[i] - l[i] < - y[i])[0]
ind_upp = np.where(u[i] - z[i] < y[i])[0]
n_low = len(ind_low)
n_upp = len(ind_upp)

# Form A_red from the assumed active constraints
A_red = spa.vstack([A[i][ind_low], A[i][ind_upp]])

# Form KKT linear system
KKT = spa.vstack([spa.hstack([P[i], A_red.T]),
spa.hstack([A_red, spa.csc_matrix((n_low + n_upp, n_low + n_upp))])])
rhs = np.hstack([dl_dx[i], np.zeros(n_low + n_upp)])

# Get solution
# r_sol = sla.spsolve(KKT, rhs)
r_sol = sla.lsqr(KKT, rhs)[0]

r_x = r_sol[:ctx.n]
r_yl = r_sol[ctx.n:ctx.n + n_low]
r_yu = r_sol[ctx.n + n_low:]
r_y = np.zeros(ctx.m)
r_y[ind_low] = r_yl
r_y[ind_upp] = r_yu

t = np.hstack([r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0
for j in range(ctx.m)])
dl[i] = torch.tensor(t)
t = np.hstack([r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0
for j in range(ctx.m)])
du[i] = torch.tensor(t)
elif ctx.diff_mode == DiffModes.FULL:
KKT_22 = np.zeros(ctx.m)
J = y[i] < 0.
KKT_22[J] = (A[i].dot(x[i]) - l[i])[J]
KKT_22[~J] = (A[i].dot(x[i]) - u[i])[~J]

# TODO: Better handle when the bounds are \pm\infty.
KKT_22[np.isinf(KKT_22)] = 0.

KKT_T = spa.vstack([spa.hstack([P[i], A[i].T.dot(spa.diags(y[i]))]),
spa.hstack([A[i], spa.diags(KKT_22)])])
rhs = np.hstack([dl_dx[i], np.zeros(ctx.m)])

# Get solution
r_sol = sla.lsqr(KKT_T, rhs)[0]

r_x = r_sol[:ctx.n]
r_y = r_sol[ctx.n:] * y[i]

dl[i][J] = torch.from_numpy(r_y[J]).to(dl.device)
du[i][~J] = torch.from_numpy(r_y[~J]).to(du.device)
else:
raise RuntimeError(f"Unrecognized differentiation mode")


# Extract derivatives
rows, cols = self.P_idx
rows, cols = ctx.P_idx
values = -.5 * (r_x[rows] * x[i][cols] + r_x[cols] * x[i][rows])
dP[i] = torch.from_numpy(values)

rows, cols = self.A_idx
rows, cols = ctx.A_idx
values = -(y[i][rows] * r_x[cols] + r_y[rows] * x[i][cols])
dA[i] = torch.from_numpy(values)
dq[i] = torch.from_numpy(-r_x)
dl[i] = torch.tensor(
[r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0
for j in range(self.m)])
du[i] = torch.tensor(
[r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0
for j in range(self.m)])

grads = (dP, dq, dA, dl, du)
grads = [dP, dq, dA, dl, du]

if not batch_mode:
for i, g in enumerate(grads):
grads[i] = g.squeeze()

grads += [None]*9

return grads
return tuple(grads)
45 changes: 24 additions & 21 deletions tests/test_osqpth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
import osqp
from osqpth.osqpth import OSQP
from osqpth.osqpth import OSQP, DiffModes
import numpy.random as npr
import numpy as np
import torch
Expand All @@ -19,7 +19,8 @@


def get_grads(n_batch=1, n=10, m=3, P_scale=1.,
A_scale=1., u_scale=1., l_scale=1.):
A_scale=1., u_scale=1., l_scale=1.,
diff_mode=DiffModes.FULL):
assert(n_batch == 1)
npr.seed(1)
L = np.random.randn(n, n)
Expand All @@ -35,11 +36,11 @@ def get_grads(n_batch=1, n=10, m=3, P_scale=1.,
P, q, A, l, u, true_x = [x.astype(np.float64) for x in
[P, q, A, l, u, true_x]]

grads = get_grads_torch(P, q, A, l, u, true_x)
grads = get_grads_torch(P, q, A, l, u, true_x, diff_mode)
return [P, q, A, l, u, true_x], grads


def get_grads_torch(P, q, A, l, u, true_x):
def get_grads_torch(P, q, A, l, u, true_x, diff_mode):

P_idx = P.nonzero()
P_shape = P.shape
Expand All @@ -56,7 +57,7 @@ def get_grads_torch(P, q, A, l, u, true_x):
for x in [P_torch, q_torch, A_torch, l_torch, u_torch]:
x.requires_grad = True

x_hats = OSQP(P_idx, P_shape, A_idx, A_shape)(P_torch, q_torch, A_torch, l_torch, u_torch)
x_hats = OSQP(P_idx, P_shape, A_idx, A_shape, diff_mode=diff_mode)(P_torch, q_torch, A_torch, l_torch, u_torch)

dl_dxhat = x_hats.data - true_x_torch
x_hats.backward(dl_dxhat)
Expand All @@ -67,19 +68,21 @@ def get_grads_torch(P, q, A, l, u, true_x):

def test_dl_dp():
n, m = 5, 5
[P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads(
n=n, m=m, P_scale=100., A_scale=100.)

def f(q):
m = osqp.OSQP()
m.setup(P, q, A, l, u, verbose=verbose)
res = m.solve()
x_hat = res.x

return 0.5 * np.sum(np.square(x_hat - true_x))

dq_fd = nd.Gradient(f)(q)
if verbose:
print('dq_fd: ', np.round(dq_fd, decimals=4))
print('dq: ', np.round(dq, decimals=4))
npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL)
for diff_mode in DiffModes:
[P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads(
n=n, m=m, P_scale=100., A_scale=100., diff_mode=diff_mode)
print(f'--- {diff_mode.name}')

def f(q):
m = osqp.OSQP()
m.setup(P, q, A, l, u, verbose=False)
res = m.solve()
x_hat = res.x

return 0.5 * np.sum(np.square(x_hat - true_x))

dq_fd = nd.Gradient(f)(q)
if verbose:
print('dq_fd: ', np.round(dq_fd, decimals=4))
print('dq: ', np.round(dq, decimals=4))
npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL)

0 comments on commit 007680a

Please sign in to comment.