Skip to content

Commit

Permalink
Python linsys: Add normalization for bodono#13.
Browse files Browse the repository at this point in the history
  • Loading branch information
Brandon Amos committed Apr 8, 2019
1 parent 7634ea9 commit 4f128fc
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 14 deletions.
66 changes: 64 additions & 2 deletions python_linsys/private.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@
#include "private.h"

// The following are shared with scsmodule.c, which
// sets the callbacks.
// sets the callbacks and defines helper functions.
extern PyObject *scs_init_lin_sys_work_cb;
extern PyObject *scs_solve_lin_sys_cb;
extern PyObject *scs_accum_by_a_cb;
extern PyObject *scs_accum_by_atrans_cb;
extern PyObject *scs_normalize_a_cb;
extern PyObject *scs_un_normalize_a_cb;

extern int scs_get_float_type(void);
extern int scs_get_int_type(void);
extern PyArrayObject *scs_get_contiguous(PyArrayObject *array, int typenum);

char *SCS(get_lin_sys_method)(const ScsMatrix *A, const ScsSettings *stgs) {
char *str = (char *)scs_malloc(sizeof(char) * 128);
Expand Down Expand Up @@ -73,10 +79,14 @@ void SCS(accum_by_a)(const ScsMatrix *A, ScsLinSysWork *p, const scs_float *x,

ScsLinSysWork *SCS(init_lin_sys_work)(const ScsMatrix *A,
const ScsSettings *stgs) {
_import_array(); // TODO: Move this somewhere else?
_import_array();

ScsLinSysWork *p = (ScsLinSysWork *)scs_calloc(1, sizeof(ScsLinSysWork));
p->total_solve_time = 0;

PyObject *arglist = Py_BuildValue("(d)", stgs->rho_x);
PyObject_CallObject(scs_init_lin_sys_work_cb, arglist);

return p;
}

Expand Down Expand Up @@ -104,3 +114,55 @@ scs_int SCS(solve_lin_sys)(const ScsMatrix *A, const ScsSettings *stgs,
p->total_solve_time += SCS(tocq)(&linsys_timer);
return 0;
}


void SCS(normalize_a)(ScsMatrix *A, const ScsSettings *stgs,
const ScsCone *k, ScsScaling *scal) {
_import_array();

int scs_int_type = scs_get_int_type();
int scs_float_type = scs_get_float_type();

scs_int *boundaries;
npy_intp veclen[1];
veclen[0] = SCS(get_cone_boundaries)(k, &boundaries);
PyObject *boundaries_py = PyArray_SimpleNewFromData(
1, veclen, scs_int_type, boundaries);
PyArray_ENABLEFLAGS((PyArrayObject *)boundaries_py, NPY_ARRAY_OWNDATA);

PyObject *arglist = Py_BuildValue("(Od)", boundaries_py, stgs->scale);
PyObject *result = PyObject_CallObject(scs_normalize_a_cb, arglist);
scs_free(boundaries);

PyArrayObject *D_py = SCS_NULL;
PyArrayObject *E_py = SCS_NULL;
PyArg_ParseTuple(result, "O!O!dd", &PyArray_Type, &D_py,
&PyArray_Type, &E_py,
&scal->mean_norm_row_a, &scal->mean_norm_col_a);

D_py = scs_get_contiguous(D_py, scs_float_type);
E_py = scs_get_contiguous(E_py, scs_float_type);

scal->D = (scs_float *)PyArray_DATA(D_py);
scal->E = (scs_float *)PyArray_DATA(E_py);
}


void SCS(un_normalize_a)(ScsMatrix *A, const ScsSettings *stgs,
const ScsScaling *scal) {
int scs_float_type = scs_get_float_type();

npy_intp veclen[1];
veclen[0] = A->m;
PyObject *D_py = PyArray_SimpleNewFromData(1, veclen,
scs_float_type, scal->D);
PyArray_ENABLEFLAGS((PyArrayObject *)D_py, NPY_ARRAY_OWNDATA);

veclen[0] = A->n;
PyObject *E_py = PyArray_SimpleNewFromData(1, veclen,
scs_float_type, scal->E);
PyArray_ENABLEFLAGS((PyArrayObject *)E_py, NPY_ARRAY_OWNDATA);

PyObject *arglist = Py_BuildValue("(OO)", D_py, E_py);
PyObject_CallObject(scs_un_normalize_a_cb, arglist);
}
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def install_scs(**kwargs):
'src/scsmodule.c',
] + glob('scs/src/*.c') + glob('scs/linsys/*.c')
include_dirs = ['scs/include', 'scs/linsys']
define_macros = [('PYTHON', None), ('CTRLC', 1), ('COPYAMATRIX', None)]
define_macros = [('PYTHON', None), ('CTRLC', 1)]

if system() == 'Linux':
libraries += ['rt']
Expand All @@ -159,7 +159,7 @@ def install_scs(**kwargs):
sources=sources + glob('scs/linsys/direct/*.c') +
glob('scs/linsys/direct/external/amd/*.c') +
glob('scs/linsys/direct/external/qdldl/*.c'),
define_macros=list(define_macros),
define_macros=list(define_macros) + [('COPYAMATRIX', None)],
include_dirs=include_dirs +
['scs/linsys/direct/', 'scs/linsys/direct/external/'],
libraries=list(libraries),
Expand All @@ -168,7 +168,8 @@ def install_scs(**kwargs):
_scs_indirect = Extension(
name='_scs_indirect',
sources=sources + glob('scs/linsys/indirect/*.c'),
define_macros=list(define_macros) + [('INDIRECT', None)],
define_macros=list(define_macros) + \
[('COPYAMATRIX', None), ('INDIRECT', None)],
include_dirs=include_dirs + ['scs/linsys/indirect/'],
libraries=list(libraries),
extra_compile_args=list(extra_compile_args))
Expand Down
38 changes: 29 additions & 9 deletions src/scsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ struct ScsPyData {
};


PyObject *scs_init_lin_sys_work_cb = SCS_NULL;
PyObject *scs_solve_lin_sys_cb = SCS_NULL;
PyObject *scs_accum_by_a_cb = SCS_NULL;
PyObject *scs_accum_by_atrans_cb = SCS_NULL;
PyObject *scs_normalize_a_cb = SCS_NULL;
PyObject *scs_un_normalize_a_cb = SCS_NULL;

/* Note, Python3.x may require special handling for the scs_int and scs_float
* types. */
Expand Down Expand Up @@ -68,7 +71,7 @@ int scs_get_float_type(void) {
}
}

static PyArrayObject *get_contiguous(PyArrayObject *array, int typenum) {
PyArrayObject *scs_get_contiguous(PyArrayObject *array, int typenum) {
/* gets the pointer to the block of contiguous C memory */
/* the overhead should be small unless the numpy array has been */
/* reordered in some way or the data type doesn't quite match */
Expand Down Expand Up @@ -126,7 +129,7 @@ static scs_int get_warm_start(char *key, scs_float **x, scs_int l,
PySys_WriteStderr("Error parsing warm-start input\n");
return 0;
} else {
PyArrayObject *px0 = get_contiguous(x0, scs_get_float_type());
PyArrayObject *px0 = scs_get_contiguous(x0, scs_get_float_type());
memcpy(*x, (scs_float *)PyArray_DATA(px0), l * sizeof(scs_float));
Py_DECREF(px0);
return 1;
Expand Down Expand Up @@ -300,7 +303,7 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) {
char *argparse_string = "(ll)O!O!O!O!O!O!|O!O!O!lffffflz";
char *outarg_string = "{s:l,s:l,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:s}";
#else
char *argparse_string = "(ll)O!O!O!O!O!O!|O!O!O!ldddddlz(OOO)";
char *argparse_string = "(ll)O!O!O!O!O!O!|O!O!O!ldddddlz(OOOOOO)";
char *outarg_string = "{s:l,s:l,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:s}";
#endif
#else
Expand Down Expand Up @@ -335,7 +338,9 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) {
&(d->stgs->cg_rate), &(d->stgs->alpha), &(d->stgs->rho_x),
&(d->stgs->acceleration_lookback),
&(d->stgs->write_data_filename),
&scs_solve_lin_sys_cb, &scs_accum_by_a_cb, &scs_accum_by_atrans_cb)) {
&scs_init_lin_sys_work_cb, &scs_solve_lin_sys_cb,
&scs_accum_by_a_cb, &scs_accum_by_atrans_cb,
&scs_normalize_a_cb, &scs_un_normalize_a_cb)) {
PySys_WriteStderr("error parsing inputs\n");
return SCS_NULL;
}
Expand All @@ -360,9 +365,9 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) {
if (!PyArray_ISINTEGER(Ap) || PyArray_NDIM(Ap) != 1) {
return finish_with_error(d, k, &ps, "Ap must be a numpy array of ints");
}
ps.Ax = get_contiguous(Ax, scs_float_type);
ps.Ai = get_contiguous(Ai, scs_int_type);
ps.Ap = get_contiguous(Ap, scs_int_type);
ps.Ax = scs_get_contiguous(Ax, scs_float_type);
ps.Ai = scs_get_contiguous(Ai, scs_int_type);
ps.Ap = scs_get_contiguous(Ap, scs_int_type);

A = (ScsMatrix *)scs_malloc(sizeof(ScsMatrix));
A->n = d->n;
Expand All @@ -379,7 +384,7 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) {
if (PyArray_DIM(c, 0) != d->n) {
return finish_with_error(d, k, &ps, "c has incompatible dimension with A");
}
ps.c = get_contiguous(c, scs_float_type);
ps.c = scs_get_contiguous(c, scs_float_type);
d->c = (scs_float *)PyArray_DATA(ps.c);
/* set b */
if (!PyArray_ISFLOAT(b) || PyArray_NDIM(b) != 1) {
Expand All @@ -389,7 +394,7 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) {
if (PyArray_DIM(b, 0) != d->m) {
return finish_with_error(d, k, &ps, "b has incompatible dimension with A");
}
ps.b = get_contiguous(b, scs_float_type);
ps.b = scs_get_contiguous(b, scs_float_type);
d->b = (scs_float *)PyArray_DATA(ps.b);

if (get_pos_int_param("f", &(k->f), 0, cone) < 0) {
Expand Down Expand Up @@ -449,6 +454,11 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) {
}

#ifdef PYTHON_LINSYS
if (!PyCallable_Check(scs_init_lin_sys_work_cb)) {
PyErr_SetString(PyExc_ValueError, "scs_init_lin_sys_work_cb not a valid callback");
return SCS_NULL;
}

if (!PyCallable_Check(scs_solve_lin_sys_cb)) {
PyErr_SetString(PyExc_ValueError, "scs_solve_lin_sys_cb not a valid callback");
return SCS_NULL;
Expand All @@ -463,6 +473,16 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) {
PyErr_SetString(PyExc_ValueError, "scs_accum_by_atrans_cb not a valid callback");
return SCS_NULL;
}

if (!PyCallable_Check(scs_normalize_a_cb)) {
PyErr_SetString(PyExc_ValueError, "scs_normalize_a_cb not a valid callback");
return SCS_NULL;
}

if (!PyCallable_Check(scs_un_normalize_a_cb)) {
PyErr_SetString(PyExc_ValueError, "scs_un_normalize_a_cb not a valid callback");
return SCS_NULL;
}
#endif


Expand Down

0 comments on commit 4f128fc

Please sign in to comment.