From c2df5307e5a8ea11ee3b0d52bdcc68f4a658ae54 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Tue, 23 Jul 2019 16:06:38 -0700 Subject: [PATCH] Make the interface more flexible w.r.t. the device and dtype --- osqpth/osqpth.py | 61 ++++++++++++++++++++++++-------------------- tests/test_osqpth.py | 0 2 files changed, 34 insertions(+), 27 deletions(-) mode change 100644 => 100755 tests/test_osqpth.py diff --git a/osqpth/osqpth.py b/osqpth/osqpth.py index 54a8cdf..a043811 100644 --- a/osqpth/osqpth.py +++ b/osqpth/osqpth.py @@ -35,14 +35,14 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): \hat x = argmin_x 1/2 x' P x + q' x subject to l <= Ax <= u - + where P \in S^{n,n}, S^{n,n} is the set of all positive semi-definite matrices, q \in R^{n} A \in R^{m,n} l \in R^{m} u \in R^{m} - + These parameters should all be passed to this function as Variable- or Parameter-wrapped Tensors. (See torch.autograd.Variable and torch.nn.parameter.Parameter) @@ -57,7 +57,7 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): that will not change across all of the minibatch examples. This function is able to infer such cases. - If you don't want to use any constraints, you can set the + If you don't want to use any constraints, you can set the appropriate values to: e = Variable(torch.Tensor()) @@ -67,17 +67,19 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): # Convert batches to sparse matrices/vectors self.n_batch = P_val.size(0) if len(P_val.size()) > 1 else 1 self.m, self.n = self.A_shape # Problem size - Tensor = torch.cuda.DoubleTensor if P_val.is_cuda else torch.DoubleTensor + + dtype = P_val.dtype + device = P_val.device # Convert P and A to sparse matrices # TODO (Bart): create CSC matrix during initialization. Then # just reassign the mat.data vector with A_val and P_val - + if self.n_batch == 1: # Create lists to make the code below work # TODO (Bart): Find a better way to do this P_val, q_val, A_val, l_val, u_val = [P_val], [q_val], [A_val], [l_val], [u_val] - + P = [spa.csc_matrix((to_numpy(P_val[i]), self.P_idx), shape=self.P_shape) for i in range(self.n_batch)] q = [to_numpy(q_val[i]) for i in range(self.n_batch)] @@ -87,7 +89,7 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): u = [to_numpy(u_val[i]) for i in range(self.n_batch)] # Perform forward step solving the QPs - x_torch = Tensor().new_empty((self.n_batch, self.n), dtype=torch.double) + x_torch = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device) x, y, z = [], [], [] for i in range(self.n_batch): @@ -99,18 +101,21 @@ def forward(self, P_val, q_val, A_val, l_val, u_val): x.append(result.x) y.append(result.y) z.append(A[i].dot(result.x)) - x_torch[i] = Tensor(result.x) + + # This is silently converting result.x to the same + # dtype and device as x_torch. + x_torch[i] = torch.from_numpy(result.x) # Save stuff for backpropagation self.backward_vars = (P, q, A, l, u, x, y, z) - + # Return solutions return x_torch def backward(self, dl_dx_val): - - Tensor = torch.cuda.DoubleTensor if dl_dx_val.is_cuda else torch.DoubleTensor - + dtype = dl_dx_val.dtype + device = dl_dx_val.device + # Convert dl_dx to numpy dl_dx = to_numpy(dl_dx_val).squeeze() @@ -119,15 +124,15 @@ def backward(self, dl_dx_val): # Convert to torch tensors nnz_P = len(self.P_idx[0]) - nnz_A = len(self.A_idx[0]) - dP = Tensor().new_empty((self.n_batch, nnz_P), dtype=torch.double) - dq = Tensor().new_empty((self.n_batch, self.n), dtype=torch.double) - dA = Tensor().new_empty((self.n_batch, nnz_A), dtype=torch.double) - dl = Tensor().new_empty((self.n_batch, self.m), dtype=torch.double) - du = Tensor().new_empty((self.n_batch, self.m), dtype=torch.double) - + nnz_A = len(self.A_idx[0]) + dP = torch.zeros((self.n_batch, nnz_P), dtype=dtype, device=device) + dq = torch.zeros((self.n_batch, self.n), dtype=dtype, device=device) + dA = torch.zeros((self.n_batch, nnz_A), dtype=dtype, device=device) + dl = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device) + du = torch.zeros((self.n_batch, self.m), dtype=dtype, device=device) + # TODO: Improve this, reuse OSQP, port it in C - + for i in range(self.n_batch): # Construct linear system # Taken from https://github.com/oxfordcontrol/osqp-python/blob/0363d028b2321017049360d2eb3c0cf206028c43/modulepurepy/_osqp.py#L1717 @@ -157,16 +162,18 @@ def backward(self, dl_dx_val): # Extract derivatives rows, cols = self.P_idx values = -.5 * (r_x[rows] * x[i][cols] + r_x[cols] * x[i][rows]) - dP[i] = Tensor(values) + dP[i] = torch.from_numpy(values) rows, cols = self.A_idx values = -(y[i][rows] * r_x[cols] + r_y[rows] * x[i][cols]) - dA[i] = Tensor(values) - dq[i] = Tensor(-r_x) - dl[i] = Tensor([r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 - for j in range(self.m)]) - du[i] = Tensor([r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 - for j in range(self.m)]) + dA[i] = torch.from_numpy(values) + dq[i] = torch.from_numpy(-r_x) + dl[i] = torch.tensor( + [r_yl[np.where(ind_low == j)[0]] if j in ind_low else 0 + for j in range(self.m)]) + du[i] = torch.tensor( + [r_yu[np.where(ind_upp == j)[0]] if j in ind_upp else 0 + for j in range(self.m)]) grads = (dP, dq, dA, dl, du) diff --git a/tests/test_osqpth.py b/tests/test_osqpth.py old mode 100644 new mode 100755