diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 2d2eb14..12fa1b5 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -24,7 +24,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ ubuntu-20.04, macos-11, windows-2022 ]
+        os: [ ubuntu-20.04, macos-12, windows-2022 ]
         python-version: [ 3.8, 3.9, "3.10", "3.11" ]
 
     env:
@@ -64,13 +64,13 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ ubuntu-20.04, macos-11, windows-2022 ]
-        python-version: [ 3.7, 3.9, "3.10", "3.11" ]
+        os: [ ubuntu-20.04, macos-12, windows-2022 ]
+        python-version: [ 3.8, 3.9, "3.10", "3.11" ]
         include:
           - os: ubuntu-20.04
             python-version: 3.8
             single_action_config: "True"
-          - os: macos-11
+          - os: macos-12
             python-version: 3.8
           - os: windows-2019
             python-version: 3.8
diff --git a/diffcp/cone_program.py b/diffcp/cone_program.py
index acdd5a6..9919108 100644
--- a/diffcp/cone_program.py
+++ b/diffcp/cone_program.py
@@ -146,6 +146,62 @@ def DTi(i):
     return xs, ys, ss, D_batch, DT_batch
 
 
+def solve_only_wrapper(A, b, c, cone_dict, warm_start, kwargs):
+    """A wrapper around solve_only for the batch function."""
+    return solve_only(
+        A, b, c, cone_dict, warm_start=warm_start, **kwargs)
+
+
+def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1,
+                     warm_starts=None, **kwargs):
+    """
+    Solves a batch of cone programs.
+
+    Uses a ThreadPool to perform the solves across the batch in parallel.
+
+    For more information on the arguments and return values, see the
+    docstring for `solve_and_derivative_batch`.
+
+    This function provides only the forward-solve half of
+    `solve_and_derivative_batch`; it cannot be used to differentiate
+    through a cone program. It exists because cvxpylayers calls
+    `solve_and_derivative_batch` to solve the problem and populate the
+    backward function (in the PyTorch sense) during a forward pass
+    through a layer. At inference time no gradient information is
+    needed, so `solve_only_batch` performs the same solves at a lower
+    computational cost.
+ """ + batch_size = len(As) + if warm_starts is None: + warm_starts = [None] * batch_size + if n_jobs_forward == -1: + n_jobs_forward = mp.cpu_count() + n_jobs_forward = min(batch_size, n_jobs_forward) + + if n_jobs_forward == 1: + #serial + xs, ys, ss = [], [], [] + for i in range(batch_size): + x, y, s = solve_only(As[i], bs[i], cs[i], cone_dicts[i], + warm_starts[i], **kwargs) + xs += [x] + ys += [y] + ss += [s] + else: + # thread pool + pool = ThreadPool(processes=n_jobs_forward) + args = [(A, b, c, cone_dict, warm_start, kwargs) for A, b, c, cone_dict, warm_start in + zip(As, bs, cs, cone_dicts, warm_starts)] + with threadpool_limits(limits=1): + results = pool.starmap(solve_only_wrapper, args) + pool.close() + xs = [r[0] for r in results] + ys = [r[1] for r in results] + ss = [r[2] for r in results] + + return xs, ys, ss + + class SolverError(Exception): pass @@ -225,26 +281,32 @@ def solve_and_derivative(A, b, c, cone_dict, warm_start=None, mode='lsqr', return x, y, s, D, DT -def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None, - warm_start=None, mode='lsqr', raise_on_error=True, **kwargs): - if mode not in ["dense", "lsqr", "lsmr"]: - raise ValueError("Unsupported mode {}; the supported modes are " - "'dense', 'lsqr' and 'lsmr'".format(mode)) +def solve_only(A, b, c, cone_dict, warm_start=None, + solve_method='SCS', **kwargs): + """ + Solves a cone program and returns its solution. + + For more information on the arguments and return values, + see the docstring for `solve_and_derivative` function. However, note + that only x, y, and s are being returned from this function. + This is another function which was created for the benefit of cvxpylayers. + """ if np.isnan(A.data).any(): raise RuntimeError("Found a NaN in A.") + A.eliminate_zeros() + + result = solve_internal( + A, b, c, cone_dict, warm_start=warm_start, + solve_method=solve_method, **kwargs) + x = result["x"] + y = result["y"] + s = result["s"] + return x, y, s - # set explicit 0s in A to np.nan - A.data[A.data == 0] = np.nan - - # compute rows and cols of nonzeros in A - rows, cols = A.nonzero() - - # reset np.nan entries in A to 0.0 - A.data[np.isnan(A.data)] = 0.0 - # eliminate explicit zeros in A, we no longer need them - A.eliminate_zeros() +def solve_internal(A, b, c, cone_dict, solve_method=None, + warm_start=None, raise_on_error=True, **kwargs): if solve_method is None: psd_cone = ('s' in cone_dict) and (cone_dict['s'] != []) @@ -304,10 +366,6 @@ def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None, result["DT"] = None return result - x = result["x"] - y = result["y"] - s = result["s"] - elif solve_method == "ECOS": if warm_start is not None: raise ValueError('ECOS does not support warmstart.') @@ -430,8 +488,6 @@ def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None, result["y"] = np.array(solution.z) result["s"] = np.array(solution.s) - x, y, s = result["x"], result["y"], result["s"] - CLARABEL2SCS_STATUS_MAP = { "Solved": "Solved", "PrimalInfeasible": "Infeasible", @@ -450,7 +506,35 @@ def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None, } else: raise ValueError("Solver %s not supported." 
                          % solve_method)
+
+    return result
+
+def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,
+                                  warm_start=None, mode='lsqr', raise_on_error=True, **kwargs):
+    if mode not in ["dense", "lsqr", "lsmr"]:
+        raise ValueError("Unsupported mode {}; the supported modes are "
+                         "'dense', 'lsqr' and 'lsmr'".format(mode))
+    if np.isnan(A.data).any():
+        raise RuntimeError("Found a NaN in A.")
+
+    # set explicit 0s in A to np.nan
+    A.data[A.data == 0] = np.nan
+
+    # compute rows and cols of nonzeros in A
+    rows, cols = A.nonzero()
+
+    # reset np.nan entries in A to 0.0
+    A.data[np.isnan(A.data)] = 0.0
+    # eliminate explicit zeros in A, we no longer need them
+    A.eliminate_zeros()
+
+    result = solve_internal(A, b, c, cone_dict, solve_method=solve_method,
+                            warm_start=warm_start, raise_on_error=raise_on_error, **kwargs)
+    x = result["x"]
+    y = result["y"]
+    s = result["s"]
+
     # pre-compute quantities for the derivative
     m, n = A.shape
     N = m + n + 1
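
For reference, a minimal usage sketch of the new entry points (not part of the patch). It imports `solve_only` and `solve_only_batch` directly from `diffcp.cone_program`, since the shown hunks do not cover any package-level re-exports, and it solves a tiny hand-built LP in SCS form; the problem data, cone sizes, and batch size are illustrative assumptions, not taken from the repository.

import numpy as np
from scipy import sparse

from diffcp.cone_program import solve_only, solve_only_batch

# Tiny LP in SCS form (min c^T x  s.t.  Ax + s = b,  s in the nonnegative cone):
# minimize x1 + x2 subject to x1 >= 1 and x2 >= 1, so the optimum is x = (1, 1).
A = sparse.csc_matrix(-np.eye(2))
b = np.array([-1.0, -1.0])
c = np.array([1.0, 1.0])
cone_dict = {"l": 2}  # two-dimensional nonnegative (positive-orthant) cone

# Single solve: only the primal/dual solution is returned; no derivative maps are built.
x, y, s = solve_only(A, b, c, cone_dict)
print(np.round(x, 3))  # approximately [1. 1.]

# Batched solve of three copies of the same problem. solve_only calls
# A.eliminate_zeros(), which modifies the matrix in place, so each problem
# gets its own copy of A. With n_jobs_forward != 1 the solves run in a
# ThreadPool, mirroring solve_and_derivative_batch.
xs, ys, ss = solve_only_batch([A.copy() for _ in range(3)], [b] * 3, [c] * 3,
                              [cone_dict] * 3, n_jobs_forward=2)
print(len(xs))  # 3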