diff --git a/python_linsys/private.c b/python_linsys/private.c index 1855d180..413b0d20 100644 --- a/python_linsys/private.c +++ b/python_linsys/private.c @@ -4,11 +4,17 @@ #include "private.h" // The following are shared with scsmodule.c, which -// sets the callbacks. +// sets the callbacks and defines helper functions. +extern PyObject *scs_init_lin_sys_work_cb; extern PyObject *scs_solve_lin_sys_cb; extern PyObject *scs_accum_by_a_cb; extern PyObject *scs_accum_by_atrans_cb; +extern PyObject *scs_normalize_a_cb; +extern PyObject *scs_un_normalize_a_cb; + extern int scs_get_float_type(void); +extern int scs_get_int_type(void); +extern PyArrayObject *scs_get_contiguous(PyArrayObject *array, int typenum); char *SCS(get_lin_sys_method)(const ScsMatrix *A, const ScsSettings *stgs) { char *str = (char *)scs_malloc(sizeof(char) * 128); @@ -73,10 +79,14 @@ void SCS(accum_by_a)(const ScsMatrix *A, ScsLinSysWork *p, const scs_float *x, ScsLinSysWork *SCS(init_lin_sys_work)(const ScsMatrix *A, const ScsSettings *stgs) { - _import_array(); // TODO: Move this somewhere else? + _import_array(); ScsLinSysWork *p = (ScsLinSysWork *)scs_calloc(1, sizeof(ScsLinSysWork)); p->total_solve_time = 0; + + PyObject *arglist = Py_BuildValue("(d)", stgs->rho_x); + PyObject_CallObject(scs_init_lin_sys_work_cb, arglist); + return p; } @@ -104,3 +114,55 @@ scs_int SCS(solve_lin_sys)(const ScsMatrix *A, const ScsSettings *stgs, p->total_solve_time += SCS(tocq)(&linsys_timer); return 0; } + + +void SCS(normalize_a)(ScsMatrix *A, const ScsSettings *stgs, + const ScsCone *k, ScsScaling *scal) { + _import_array(); + + int scs_int_type = scs_get_int_type(); + int scs_float_type = scs_get_float_type(); + + scs_int *boundaries; + npy_intp veclen[1]; + veclen[0] = SCS(get_cone_boundaries)(k, &boundaries); + PyObject *boundaries_py = PyArray_SimpleNewFromData( + 1, veclen, scs_int_type, boundaries); + PyArray_ENABLEFLAGS((PyArrayObject *)boundaries_py, NPY_ARRAY_OWNDATA); + + PyObject *arglist = Py_BuildValue("(Od)", boundaries_py, stgs->scale); + PyObject *result = PyObject_CallObject(scs_normalize_a_cb, arglist); + scs_free(boundaries); + + PyArrayObject *D_py = SCS_NULL; + PyArrayObject *E_py = SCS_NULL; + PyArg_ParseTuple(result, "O!O!dd", &PyArray_Type, &D_py, + &PyArray_Type, &E_py, + &scal->mean_norm_row_a, &scal->mean_norm_col_a); + + D_py = scs_get_contiguous(D_py, scs_float_type); + E_py = scs_get_contiguous(E_py, scs_float_type); + + scal->D = (scs_float *)PyArray_DATA(D_py); + scal->E = (scs_float *)PyArray_DATA(E_py); +} + + +void SCS(un_normalize_a)(ScsMatrix *A, const ScsSettings *stgs, + const ScsScaling *scal) { + int scs_float_type = scs_get_float_type(); + + npy_intp veclen[1]; + veclen[0] = A->m; + PyObject *D_py = PyArray_SimpleNewFromData(1, veclen, + scs_float_type, scal->D); + PyArray_ENABLEFLAGS((PyArrayObject *)D_py, NPY_ARRAY_OWNDATA); + + veclen[0] = A->n; + PyObject *E_py = PyArray_SimpleNewFromData(1, veclen, + scs_float_type, scal->E); + PyArray_ENABLEFLAGS((PyArrayObject *)E_py, NPY_ARRAY_OWNDATA); + + PyObject *arglist = Py_BuildValue("(OO)", D_py, E_py); + PyObject_CallObject(scs_un_normalize_a_cb, arglist); +} diff --git a/setup.py b/setup.py index 3598b0bc..308c6b65 100644 --- a/setup.py +++ b/setup.py @@ -141,7 +141,7 @@ def install_scs(**kwargs): 'src/scsmodule.c', ] + glob('scs/src/*.c') + glob('scs/linsys/*.c') include_dirs = ['scs/include', 'scs/linsys'] - define_macros = [('PYTHON', None), ('CTRLC', 1), ('COPYAMATRIX', None)] + define_macros = [('PYTHON', None), ('CTRLC', 1)] if system() == 'Linux': libraries += ['rt'] @@ -159,7 +159,7 @@ def install_scs(**kwargs): sources=sources + glob('scs/linsys/direct/*.c') + glob('scs/linsys/direct/external/amd/*.c') + glob('scs/linsys/direct/external/qdldl/*.c'), - define_macros=list(define_macros), + define_macros=list(define_macros) + [('COPYAMATRIX', None)], include_dirs=include_dirs + ['scs/linsys/direct/', 'scs/linsys/direct/external/'], libraries=list(libraries), @@ -168,7 +168,8 @@ def install_scs(**kwargs): _scs_indirect = Extension( name='_scs_indirect', sources=sources + glob('scs/linsys/indirect/*.c'), - define_macros=list(define_macros) + [('INDIRECT', None)], + define_macros=list(define_macros) + \ + [('COPYAMATRIX', None), ('INDIRECT', None)], include_dirs=include_dirs + ['scs/linsys/indirect/'], libraries=list(libraries), extra_compile_args=list(extra_compile_args)) diff --git a/src/scsmodule.c b/src/scsmodule.c index b7bdb30f..4d96a675 100644 --- a/src/scsmodule.c +++ b/src/scsmodule.c @@ -34,9 +34,12 @@ struct ScsPyData { }; +PyObject *scs_init_lin_sys_work_cb = SCS_NULL; PyObject *scs_solve_lin_sys_cb = SCS_NULL; PyObject *scs_accum_by_a_cb = SCS_NULL; PyObject *scs_accum_by_atrans_cb = SCS_NULL; +PyObject *scs_normalize_a_cb = SCS_NULL; +PyObject *scs_un_normalize_a_cb = SCS_NULL; /* Note, Python3.x may require special handling for the scs_int and scs_float * types. */ @@ -68,7 +71,7 @@ int scs_get_float_type(void) { } } -static PyArrayObject *get_contiguous(PyArrayObject *array, int typenum) { +PyArrayObject *scs_get_contiguous(PyArrayObject *array, int typenum) { /* gets the pointer to the block of contiguous C memory */ /* the overhead should be small unless the numpy array has been */ /* reordered in some way or the data type doesn't quite match */ @@ -126,7 +129,7 @@ static scs_int get_warm_start(char *key, scs_float **x, scs_int l, PySys_WriteStderr("Error parsing warm-start input\n"); return 0; } else { - PyArrayObject *px0 = get_contiguous(x0, scs_get_float_type()); + PyArrayObject *px0 = scs_get_contiguous(x0, scs_get_float_type()); memcpy(*x, (scs_float *)PyArray_DATA(px0), l * sizeof(scs_float)); Py_DECREF(px0); return 1; @@ -300,7 +303,7 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) { char *argparse_string = "(ll)O!O!O!O!O!O!|O!O!O!lffffflz"; char *outarg_string = "{s:l,s:l,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:s}"; #else - char *argparse_string = "(ll)O!O!O!O!O!O!|O!O!O!ldddddlz(OOO)"; + char *argparse_string = "(ll)O!O!O!O!O!O!|O!O!O!ldddddlz(OOOOOO)"; char *outarg_string = "{s:l,s:l,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:s}"; #endif #else @@ -335,7 +338,9 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) { &(d->stgs->cg_rate), &(d->stgs->alpha), &(d->stgs->rho_x), &(d->stgs->acceleration_lookback), &(d->stgs->write_data_filename), - &scs_solve_lin_sys_cb, &scs_accum_by_a_cb, &scs_accum_by_atrans_cb)) { + &scs_init_lin_sys_work_cb, &scs_solve_lin_sys_cb, + &scs_accum_by_a_cb, &scs_accum_by_atrans_cb, + &scs_normalize_a_cb, &scs_un_normalize_a_cb)) { PySys_WriteStderr("error parsing inputs\n"); return SCS_NULL; } @@ -360,9 +365,9 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) { if (!PyArray_ISINTEGER(Ap) || PyArray_NDIM(Ap) != 1) { return finish_with_error(d, k, &ps, "Ap must be a numpy array of ints"); } - ps.Ax = get_contiguous(Ax, scs_float_type); - ps.Ai = get_contiguous(Ai, scs_int_type); - ps.Ap = get_contiguous(Ap, scs_int_type); + ps.Ax = scs_get_contiguous(Ax, scs_float_type); + ps.Ai = scs_get_contiguous(Ai, scs_int_type); + ps.Ap = scs_get_contiguous(Ap, scs_int_type); A = (ScsMatrix *)scs_malloc(sizeof(ScsMatrix)); A->n = d->n; @@ -379,7 +384,7 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) { if (PyArray_DIM(c, 0) != d->n) { return finish_with_error(d, k, &ps, "c has incompatible dimension with A"); } - ps.c = get_contiguous(c, scs_float_type); + ps.c = scs_get_contiguous(c, scs_float_type); d->c = (scs_float *)PyArray_DATA(ps.c); /* set b */ if (!PyArray_ISFLOAT(b) || PyArray_NDIM(b) != 1) { @@ -389,7 +394,7 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) { if (PyArray_DIM(b, 0) != d->m) { return finish_with_error(d, k, &ps, "b has incompatible dimension with A"); } - ps.b = get_contiguous(b, scs_float_type); + ps.b = scs_get_contiguous(b, scs_float_type); d->b = (scs_float *)PyArray_DATA(ps.b); if (get_pos_int_param("f", &(k->f), 0, cone) < 0) { @@ -449,6 +454,11 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) { } #ifdef PYTHON_LINSYS + if (!PyCallable_Check(scs_init_lin_sys_work_cb)) { + PyErr_SetString(PyExc_ValueError, "scs_init_lin_sys_work_cb not a valid callback"); + return SCS_NULL; + } + if (!PyCallable_Check(scs_solve_lin_sys_cb)) { PyErr_SetString(PyExc_ValueError, "scs_solve_lin_sys_cb not a valid callback"); return SCS_NULL; @@ -463,6 +473,16 @@ static PyObject *csolve(PyObject *self, PyObject *args, PyObject *kwargs) { PyErr_SetString(PyExc_ValueError, "scs_accum_by_atrans_cb not a valid callback"); return SCS_NULL; } + + if (!PyCallable_Check(scs_normalize_a_cb)) { + PyErr_SetString(PyExc_ValueError, "scs_normalize_a_cb not a valid callback"); + return SCS_NULL; + } + + if (!PyCallable_Check(scs_un_normalize_a_cb)) { + PyErr_SetString(PyExc_ValueError, "scs_un_normalize_a_cb not a valid callback"); + return SCS_NULL; + } #endif