Skip to content

Commit

Permalink
For now, raise an error if a QP isn't solved
Browse files Browse the repository at this point in the history
  • Loading branch information
bamos committed Jul 26, 2019
1 parent 796f2f8 commit de73652
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions osqpth/osqpth.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def forward(self, P_val, q_val, A_val, l_val, u_val):
\hat x = argmin_x 1/2 x' P x + q' x
subject to l <= Ax <= u
where P \in S^{n,n},
S^{n,n} is the set of all positive semi-definite matrices,
q \in R^{n}
A \in R^{m,n}
l \in R^{m}
u \in R^{m}
These parameters should all be passed to this function as
Variable- or Parameter-wrapped Tensors.
(See torch.autograd.Variable and torch.nn.parameter.Parameter)
Expand All @@ -57,7 +57,7 @@ def forward(self, P_val, q_val, A_val, l_val, u_val):
that will not change across all of the minibatch examples.
This function is able to infer such cases.
If you don't want to use any constraints, you can set the
If you don't want to use any constraints, you can set the
appropriate values to:
e = Variable(torch.Tensor())
Expand All @@ -72,12 +72,12 @@ def forward(self, P_val, q_val, A_val, l_val, u_val):
# Convert P and A to sparse matrices
# TODO (Bart): create CSC matrix during initialization. Then
# just reassign the mat.data vector with A_val and P_val

if self.n_batch == 1:
# Create lists to make the code below work
# TODO (Bart): Find a better way to do this
P_val, q_val, A_val, l_val, u_val = [P_val], [q_val], [A_val], [l_val], [u_val]

P = [spa.csc_matrix((to_numpy(P_val[i]), self.P_idx), shape=self.P_shape)
for i in range(self.n_batch)]
q = [to_numpy(q_val[i]) for i in range(self.n_batch)]
Expand All @@ -96,21 +96,26 @@ def forward(self, P_val, q_val, A_val, l_val, u_val):
m = osqp.OSQP()
m.setup(P[i], q[i], A[i], l[i], u[i], verbose=self.verbose)
result = m.solve()
status = result.info.status
if status != 'solved':
# TODO: We can replace this with something calmer and
# add some more options around potentially ignoring this.
raise RuntimeError(f"Unable to solve QP, status: {status}")
x.append(result.x)
y.append(result.y)
z.append(A[i].dot(result.x))
x_torch[i] = Tensor(result.x)

# Save stuff for backpropagation
self.backward_vars = (P, q, A, l, u, x, y, z)

# Return solutions
return x_torch

def backward(self, dl_dx_val):

Tensor = torch.cuda.DoubleTensor if dl_dx_val.is_cuda else torch.DoubleTensor

# Convert dl_dx to numpy
dl_dx = to_numpy(dl_dx_val).squeeze()

Expand Down

0 comments on commit de73652

Please sign in to comment.