Merge pull request #64 from healeyq3/master
refactored cone_program for solve_only
PTNobel authored Jul 3, 2024
2 parents cdef024 + 095d910 commit c943f38
Showing 2 changed files with 109 additions and 25 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
@@ -24,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-20.04, macos-11, windows-2022 ]
os: [ ubuntu-20.04, macos-12, windows-2022 ]
python-version: [ 3.8, 3.9, "3.10", "3.11" ]

env:
@@ -64,13 +64,13 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-20.04, macos-11, windows-2022 ]
python-version: [ 3.7, 3.9, "3.10", "3.11" ]
os: [ ubuntu-20.04, macos-12, windows-2022 ]
python-version: [ 3.8, 3.9, "3.10", "3.11" ]
include:
- os: ubuntu-20.04
python-version: 3.8
single_action_config: "True"
- os: macos-11
- os: macos-12
python-version: 3.8
- os: windows-2019
python-version: 3.8
126 changes: 105 additions & 21 deletions diffcp/cone_program.py
@@ -146,6 +146,62 @@ def DTi(i):
return xs, ys, ss, D_batch, DT_batch


def solve_only_wrapper(A, b, c, cone_dict, warm_start, kwargs):
"""A wrapper around solve_only for the batch function"""
return solve_only(
A, b, c, cone_dict, warm_start=warm_start, **kwargs)


def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1,
warm_starts=None, **kwargs):
"""
Solves a batch of cone programs, using a ThreadPool to process
the problems in parallel.
For the arguments and return values, see the docstring for
`solve_and_derivative_batch`; this function performs only the solve
half of that routine and returns no derivative information, so it
cannot be used to differentiate through a cone program.
It exists because cvxpylayers calls `solve_and_derivative_batch`
during a forward pass to solve the problem and populate the backward
function (in PyTorch terms); at inference time no gradient information
is needed, and `solve_only_batch` provides the cheaper, solve-only path.
"""
batch_size = len(As)
if warm_starts is None:
warm_starts = [None] * batch_size
if n_jobs_forward == -1:
n_jobs_forward = mp.cpu_count()
n_jobs_forward = min(batch_size, n_jobs_forward)

if n_jobs_forward == 1:
# serial
xs, ys, ss = [], [], []
for i in range(batch_size):
x, y, s = solve_only(As[i], bs[i], cs[i], cone_dicts[i],
warm_starts[i], **kwargs)
xs += [x]
ys += [y]
ss += [s]
else:
# thread pool
pool = ThreadPool(processes=n_jobs_forward)
args = [(A, b, c, cone_dict, warm_start, kwargs) for A, b, c, cone_dict, warm_start in
zip(As, bs, cs, cone_dicts, warm_starts)]
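# threadpool_limits (from threadpoolctl) caps each solve at one BLAS/OpenMP thread,
# so the worker threads below do not oversubscribe the CPU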
with threadpool_limits(limits=1):
results = pool.starmap(solve_only_wrapper, args)
pool.close()
xs = [r[0] for r in results]
ys = [r[1] for r in results]
ss = [r[2] for r in results]

return xs, ys, ss
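As a rough usage sketch of the new batch entry point, the example below solves several copies of a small LP in parallel. The problem data (A, b, c, cone_dict) is invented for illustration and is not part of this commit.

import numpy as np
from scipy import sparse
from diffcp.cone_program import solve_only_batch

# hypothetical toy LP in SCS form: minimize c^T x subject to b - A x >= 0
A = sparse.csc_matrix(np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]))
b = np.array([1.0, 1.0, 0.0])
c = np.array([1.0, 1.0])
cone_dict = {"l": 3}  # three nonnegative-cone rows

# solve four copies of the problem across two worker threads
xs, ys, ss = solve_only_batch(
    [A] * 4, [b] * 4, [c] * 4, [cone_dict] * 4, n_jobs_forward=2)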


class SolverError(Exception):
pass

@@ -225,26 +281,32 @@ def solve_and_derivative(A, b, c, cone_dict, warm_start=None, mode='lsqr',
return x, y, s, D, DT


def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,
warm_start=None, mode='lsqr', raise_on_error=True, **kwargs):
if mode not in ["dense", "lsqr", "lsmr"]:
raise ValueError("Unsupported mode {}; the supported modes are "
"'dense', 'lsqr' and 'lsmr'".format(mode))
def solve_only(A, b, c, cone_dict, warm_start=None,
solve_method='SCS', **kwargs):
"""
Solves a cone program and returns its solution.
For the arguments and return values, see the docstring for
`solve_and_derivative`; note that this function returns only
x, y, and s. Like `solve_only_batch`, it exists primarily for
the benefit of cvxpylayers.
"""
if np.isnan(A.data).any():
raise RuntimeError("Found a NaN in A.")
A.eliminate_zeros()

result = solve_internal(
A, b, c, cone_dict, warm_start=warm_start,
solve_method=solve_method, **kwargs)
x = result["x"]
y = result["y"]
s = result["s"]
return x, y, s
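A single-problem sketch of `solve_only`, reusing the same invented toy LP as above (again, the data is illustrative only, not taken from the repository):

import numpy as np
from scipy import sparse
from diffcp.cone_program import solve_only

A = sparse.csc_matrix(np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]))
b = np.array([1.0, 1.0, 0.0])
c = np.array([1.0, 1.0])
x, y, s = solve_only(A, b, c, {"l": 3})  # primal, dual, and slack variables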

# set explicit 0s in A to np.nan
A.data[A.data == 0] = np.nan

# compute rows and cols of nonzeros in A
rows, cols = A.nonzero()

# reset np.nan entries in A to 0.0
A.data[np.isnan(A.data)] = 0.0

# eliminate explicit zeros in A, we no longer need them
A.eliminate_zeros()
def solve_internal(A, b, c, cone_dict, solve_method=None,
warm_start=None, raise_on_error=True, **kwargs):
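# shared solve routine: picks a default solver from cone_dict when solve_method is None,
# hands the problem to SCS, ECOS, or Clarabel, and returns the solver's result dict,
# from which callers read "x", "y", and "s"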

if solve_method is None:
psd_cone = ('s' in cone_dict) and (cone_dict['s'] != [])
@@ -304,10 +366,6 @@ def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,
result["DT"] = None
return result

x = result["x"]
y = result["y"]
s = result["s"]

elif solve_method == "ECOS":
if warm_start is not None:
raise ValueError('ECOS does not support warmstart.')
@@ -430,8 +488,6 @@ def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,
result["y"] = np.array(solution.z)
result["s"] = np.array(solution.s)

x, y, s = result["x"], result["y"], result["s"]

CLARABEL2SCS_STATUS_MAP = {
"Solved": "Solved",
"PrimalInfeasible": "Infeasible",
@@ -450,7 +506,35 @@ def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,
}
else:
raise ValueError("Solver %s not supported." % solve_method)

return result

def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,
warm_start=None, mode='lsqr', raise_on_error=True, **kwargs):
if mode not in ["dense", "lsqr", "lsmr"]:
raise ValueError("Unsupported mode {}; the supported modes are "
"'dense', 'lsqr' and 'lsmr'".format(mode))
if np.isnan(A.data).any():
raise RuntimeError("Found a NaN in A.")

# set explicit 0s in A to np.nan (op1)
A.data[A.data == 0] = np.nan

# compute rows and cols of nonzeros in A (op2)
rows, cols = A.nonzero()

# reset np.nan entries in A to 0.0 (op3)
A.data[np.isnan(A.data)] = 0.0

# eliminate explicit zeros in A, we no longer need them
A.eliminate_zeros()
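# (the NaN round-trip above preserves the positions of explicitly stored zeros:
# scipy's sparse nonzero() skips stored zeros, and rows/cols should cover A's full
# sparsity pattern for the derivative quantities computed below)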

result = solve_internal(A, b, c, cone_dict, solve_method=solve_method,
warm_start=warm_start, raise_on_error=raise_on_error, **kwargs)
x = result["x"]
y = result["y"]
s = result["s"]

# pre-compute quantities for the derivative
m, n = A.shape
N = m + n + 1
