Skip to content

Commit

Permalink
Simplify warm start routine
Browse files Browse the repository at this point in the history
  • Loading branch information
imciner2 committed Dec 4, 2023
1 parent b2b3be0 commit 38b984f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 42 deletions.
18 changes: 3 additions & 15 deletions @osqp/warm_start.m
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,10 @@ function warm_start(this, varargin)
assert(isempty(x) || length(x) == n, 'input ''x'' is the wrong size');
assert(isempty(y) || length(y) == m, 'input ''y'' is the wrong size');


% Decide which function to call
if (~isempty(x) && isempty(y))
osqp_mex('warm_start_x', this.objectHandle, x);
return;
end

if (isempty(x) && ~isempty(y))
osqp_mex('warm_start_y', this.objectHandle, y);
end

if (~isempty(x) && ~isempty(y))
% Only call when there is a vector to update
if (~isempty(x) || ~isempty(y))
osqp_mex('warm_start', this.objectHandle, x, y);
end

if (isempty(x) && isempty(y))
else
error('Unrecognized fields');
end
end
38 changes: 13 additions & 25 deletions c_sources/osqp_mex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static void setToNaN(double* arr_out, OSQPInt len){

// Main mex function
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
{
// OSQP solver wrapper
OsqpData* osqpData;

Expand All @@ -48,7 +48,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

if (nrhs < 1 || mxGetString(prhs[0], cmd, sizeof(cmd)))
mexErrMsgTxt("First input should be a command string less than 64 characters long.");

/*
* First check to see if a new object was requested
*/
Expand Down Expand Up @@ -156,7 +156,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

// delete the object and its data
if (!strcmp("delete", cmd)) {

osqp_cleanup(osqpData->solver);
destroyObject<OsqpData>(prhs[1]);
// Warn if other commands were ignored
Expand Down Expand Up @@ -202,7 +202,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if(!osqpData->solver){
mexErrMsgTxt("Solver is uninitialized. No settings have been configured.");
}

OSQPFloat rho = (OSQPFloat)mxGetScalar(prhs[2]);

osqp_update_rho(osqpData->solver, rho);
Expand Down Expand Up @@ -355,14 +355,14 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

if (!exitflag && (!mxIsEmpty(q) || !mxIsEmpty(l) || !mxIsEmpty(u))) {
exitflag = osqp_update_data_vec(osqpData->solver, q_vec, l_vec, u_vec);
if (exitflag) exitflag=1;
if (exitflag) exitflag=1;
}

if (!exitflag && (!mxIsEmpty(Px) || !mxIsEmpty(Ax))) {
exitflag = osqp_update_data_mat(osqpData->solver, Px_vec, Px_idx_vec, Px_n, Ax_vec, Ax_idx_vec, Ax_n);
if (exitflag) exitflag=2;
}


// Free vectors
if(!mxIsEmpty(q)) c_free(q_vec);
Expand All @@ -384,36 +384,24 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
return;
}

if (!strcmp("warm_start", cmd) || !strcmp("warm_start_x", cmd) || !strcmp("warm_start_y", cmd)) {
if (!strcmp("warm_start", cmd)) {

//throw an error if this is called before solver is configured
if(!osqpData->solver){
mexErrMsgTxt("Solver has not been initialized.");
}

// Fill x and y
const mxArray *x = NULL;
const mxArray *y = NULL;
if (!strcmp("warm_start", cmd)) {
x = prhs[2];
y = prhs[3];
}
else if (!strcmp("warm_start_x", cmd)) {
x = prhs[2];
y = NULL;
}

else if (!strcmp("warm_start_y", cmd)) {
x = NULL;
y = prhs[2];
}
// Fill x and y
const mxArray *x = prhs[2];
const mxArray *y = prhs[3];

// Copy vectors to ensure they are cast as OSQPFloat
OSQPFloat *x_vec = NULL;
OSQPFloat *y_vec = NULL;

OSQPInt n, m;
osqp_get_dimensions(osqpData->solver, &m, &n);

if(!mxIsEmpty(x)){
x_vec = cloneVector<OSQPFloat>(mxGetPr(x),n);
}
Expand Down
90 changes: 88 additions & 2 deletions unittests/warm_start_tests.m
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ function setup_problem(testCase)
end

methods (Test)
function test_warm_start(testCase)

function test_warm_start_zeros(testCase)
% big example
rng(4)
testCase.n = 100;
Expand Down Expand Up @@ -63,12 +62,99 @@ function test_warm_start(testCase)
testCase.solver.warm_start('x', zeros(testCase.n, 1), 'y', zeros(testCase.m, 1));
results = testCase.solver.solve();
testCase.verifyEqual(results.info.iter, tot_iter, 'AbsTol', testCase.tol)
end

function test_warm_start_optimal(testCase)
% big example
rng(4)
testCase.n = 100;
testCase.m = 200;
Pt = sprandn(testCase.n, testCase.n, 0.6);
testCase.P = Pt' * Pt;
testCase.q = randn(testCase.n, 1);
testCase.A = sprandn(testCase.m, testCase.n, 0.8);
testCase.u = 2*rand(testCase.m, 1);
testCase.l = -2*rand(testCase.m, 1);

% Setup solver
testCase.solver = osqp;
testCase.solver.setup(testCase.P, testCase.q, ...
testCase.A, testCase.l, testCase.u, testCase.options);

% Solve with OSQP
results = testCase.solver.solve();

% Store optimal values
x_opt = results.x;
y_opt = results.y;
tot_iter = results.info.iter;

% Warm start with optimal values and check that number of iterations is < 10
testCase.solver.warm_start('x', x_opt, 'y', y_opt);
results = testCase.solver.solve();
testCase.verifyThat(results.info.iter, matlab.unittest.constraints.IsLessThan(10));
end

function test_warm_start_duals(testCase)
% big example
rng(4)
testCase.n = 100;
testCase.m = 200;
Pt = sprandn(testCase.n, testCase.n, 0.6);
testCase.P = Pt' * Pt;
testCase.q = randn(testCase.n, 1);
testCase.A = sprandn(testCase.m, testCase.n, 0.8);
testCase.u = 2*rand(testCase.m, 1);
testCase.l = -2*rand(testCase.m, 1);

% Setup solver
testCase.solver = osqp;
testCase.solver.setup(testCase.P, testCase.q, ...
testCase.A, testCase.l, testCase.u, testCase.options);

% Solve with OSQP
results = testCase.solver.solve();

% Store optimal values
x_opt = results.x;
y_opt = results.y;
tot_iter = results.info.iter;

% Warm start with zeros for dual variables
testCase.solver.warm_start('y', zeros(testCase.m, 1));
results = testCase.solver.solve();
testCase.verifyEqual(results.y, y_opt, 'AbsTol', testCase.tol)
end

function test_warm_start_primal(testCase)
% big example
rng(4)
testCase.n = 100;
testCase.m = 200;
Pt = sprandn(testCase.n, testCase.n, 0.6);
testCase.P = Pt' * Pt;
testCase.q = randn(testCase.n, 1);
testCase.A = sprandn(testCase.m, testCase.n, 0.8);
testCase.u = 2*rand(testCase.m, 1);
testCase.l = -2*rand(testCase.m, 1);

% Setup solver
testCase.solver = osqp;
testCase.solver.setup(testCase.P, testCase.q, ...
testCase.A, testCase.l, testCase.u, testCase.options);

% Solve with OSQP
results = testCase.solver.solve();

% Store optimal values
x_opt = results.x;
y_opt = results.y;
tot_iter = results.info.iter;

% Warm start with zeros for primal variables
testCase.solver.warm_start('x', zeros(testCase.n, 1));
results = testCase.solver.solve();
testCase.verifyEqual(results.x, x_opt, 'AbsTol', testCase.tol)
end

end
Expand Down

0 comments on commit 38b984f

Please sign in to comment.