Skip to content

Commit

Permalink
addressed review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
healeyq3 committed Jul 2, 2024
1 parent 439229b commit d9a7a07
Showing 1 changed file with 25 additions and 33 deletions.
58 changes: 25 additions & 33 deletions diffcp/cone_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,13 @@ def DTi(i):
return xs, ys, ss, D_batch, DT_batch


def solve_only_wrapper(A, b, c, cone_dict, warm_start, mode, kwargs):
def solve_only_wrapper(A, b, c, cone_dict, warm_start, kwargs):
"""A wrapper around solve_only for the batch function"""
return solve_only(
A, b, c, cone_dict, warm_start=warm_start, mode=mode, **kwargs)
A, b, c, cone_dict, warm_start=warm_start, **kwargs)


def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1, mode="lsqr",
def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1,
warm_starts=None, **kwargs):
"""
Solves a batch of cone programs.
Expand All @@ -169,7 +169,7 @@ def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1, mode="lsqr",
to solve an optimization problem and populate the backward function (in PyTorch dialect)
during a forward pass through a cvxpylayer. However, because at inference time
gradient information is no longer desired, the limited functionality provided
by `solve_only_batch` was wanted for computational enhancements.
by `solve_only_batch` was wanted for computational efficiency.
"""
batch_size = len(As)
if warm_starts is None:
Expand All @@ -183,14 +183,14 @@ def solve_only_batch(As, bs, cs, cone_dicts, n_jobs_forward=-1, mode="lsqr",
xs, ys, ss = [], [], []
for i in range(batch_size):
x, y, s = solve_only(As[i], bs[i], cs[i], cone_dicts[i],
warm_starts[i], mode=mode, **kwargs)
warm_starts[i], **kwargs)
xs += [x]
ys += [y]
ss += [s]
else:
# thread pool
pool = ThreadPool(processes=n_jobs_forward)
args = [(A, b, c, cone_dict, warm_start, mode, kwargs) for A, b, c, cone_dict, warm_start in
args = [(A, b, c, cone_dict, warm_start, kwargs) for A, b, c, cone_dict, warm_start in
zip(As, bs, cs, cone_dicts, warm_starts)]
with threadpool_limits(limits=1):
results = pool.starmap(solve_only_wrapper, args)
Expand Down Expand Up @@ -281,8 +281,8 @@ def solve_and_derivative(A, b, c, cone_dict, warm_start=None, mode='lsqr',
return x, y, s, D, DT


def solve_only(A, b, c, cone_dict, warm_start=None, mode='lsqr',
solve_method='SCS', **kwargs):
def solve_only(A, b, c, cone_dict, warm_start=None,
solve_method='SCS', **kwargs):
"""
Solves a cone program and returns its solution.
Expand All @@ -292,8 +292,12 @@ def solve_only(A, b, c, cone_dict, warm_start=None, mode='lsqr',
This is another function which was created for the benefit of cvxpylayers.
"""
if np.isnan(A.data).any():
raise RuntimeError("Found a NaN in A.")
A.eliminate_zeros()

result = solve_internal(
A, b, c, cone_dict, warm_start=warm_start, mode=mode,
A, b, c, cone_dict, warm_start=warm_start,
solve_method=solve_method, **kwargs)
x = result["x"]
y = result["y"]
Expand All @@ -302,24 +306,7 @@ def solve_only(A, b, c, cone_dict, warm_start=None, mode='lsqr',


def solve_internal(A, b, c, cone_dict, solve_method=None,
warm_start=None, mode='lsqr', raise_on_error=True, **kwargs):
if mode not in ["dense", "lsqr", "lsmr"]:
raise ValueError("Unsupported mode {}; the supported modes are "
"'dense', 'lsqr' and 'lsmr'".format(mode))

if np.isnan(A.data).any():
raise RuntimeError("Found a NaN in A.")

'''
TODO(quill): in solve_and_derivative_internal (sdi) there are more operations
on A to compute "rows" and "columns" variables. Furthermore, when
sdi is called, A.eliminate_zeros() is performed 2x.
An alternative design would be to performs op1, op2, op3 (labeled in sdi)
before calling solve_internal, and then return the A computed here.
Perhaps this is all a non-factor, but I wanted to call attention to it
in case this was worth changing.
'''
A.eliminate_zeros()
warm_start=None, raise_on_error=True, **kwargs):

if solve_method is None:
psd_cone = ('s' in cone_dict) and (cone_dict['s'] != [])
Expand Down Expand Up @@ -524,12 +511,11 @@ def solve_internal(A, b, c, cone_dict, solve_method=None,

def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,
warm_start=None, mode='lsqr', raise_on_error=True, **kwargs):

result = solve_internal(A, b, c, cone_dict, solve_method=solve_method,
warm_start=warm_start, mode=mode, raise_on_error=raise_on_error, **kwargs)
x = result["x"]
y = result["y"]
s = result["s"]
if mode not in ["dense", "lsqr", "lsmr"]:
raise ValueError("Unsupported mode {}; the supported modes are "
"'dense', 'lsqr' and 'lsmr'".format(mode))
if np.isnan(A.data).any():
raise RuntimeError("Found a NaN in A.")

# set explicit 0s in A to np.nan (op1)
A.data[A.data == 0] = np.nan
Expand All @@ -542,6 +528,12 @@ def solve_and_derivative_internal(A, b, c, cone_dict, solve_method=None,

# eliminate explicit zeros in A, we no longer need them
A.eliminate_zeros()

result = solve_internal(A, b, c, cone_dict, solve_method=solve_method,
warm_start=warm_start, raise_on_error=raise_on_error, **kwargs)
x = result["x"]
y = result["y"]
s = result["s"]

# pre-compute quantities for the derivative
m, n = A.shape
Expand Down

0 comments on commit d9a7a07

Please sign in to comment.