ALS: allow lazy iterative solve, custom cg
jcmgray committed Dec 15, 2024
1 parent 464bb36 commit 7c70fa1
Showing 1 changed file with 136 additions and 23 deletions.
159 changes: 136 additions & 23 deletions quimb/tensor/fitting.py
@@ -1,8 +1,9 @@
"""Tools for computing distances between and fitting tensor networks."""

from autoray import backend_like, dag, do
from autoray import backend_like, compose, dag, do

from ..utils import check_opt
from .tensor_core import TNLinearOperator
from .contraction import contract_strategy


@@ -108,7 +109,8 @@ def tensor_network_distance(
return (
do("abs", xAA + xBB - 2 * do("real", xAB))
# divide by average norm-squared of A and B
* 2 / (do("abs", xAA) + do("abs", xBB))
* 2
/ (do("abs", xAA) + do("abs", xBB))
) ** 0.5

dAB = do("abs", xAA + xBB - 2 * do("real", xAB)) ** 0.5
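
For intuition, a minimal dense-array sketch of the normalized distance this branch computes (numpy assumed; not part of the diff):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 4))
    B = rng.normal(size=(4, 4))

    xAA = np.vdot(A, A)   # <A|A>
    xBB = np.vdot(B, B)   # <B|B>
    xAB = np.vdot(A, B)   # <A|B>

    # same quantity as the normalized branch above: the raw distance |A - B|
    # rescaled by the average norm-squared, so the result lies between 0 and 2
    d = (abs(xAA + xBB - 2 * xAB.real) * 2 / (abs(xAA) + abs(xBB))) ** 0.5
    assert np.isclose(d, np.linalg.norm(A - B) * (2 / (xAA + xBB)) ** 0.5)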
@@ -129,6 +131,7 @@ def tensor_network_fit_autodiff(
contract_optimize="auto-hq",
distance_method="auto",
normalized="squared",
xBB=None,
inplace=False,
progbar=False,
**kwargs,
@@ -155,6 +158,17 @@ def tensor_network_fit_autodiff(
distance_method : {'auto', 'dense', 'overlap'}, optional
Supplied to :func:`~quimb.tensor.tensor_core.tensor_network_distance`,
controls how the distance is computed.
normalized : bool or str, optional
If ``True``, then normalize the distance by the norm of the two
operators, i.e. ``D(A, B) * 2 / (|A| + |B|)``. The resulting distance
lies between 0 and 2 and is more useful for assessing convergence.
If ``'infidelity'``, compute the normalized infidelity
``1 - |<A|B>|^2 / (|A| |B|)``, which can be faster to optimize, but does
not take the overall normalization into account.
xBB : float, optional
If you have already computed the target self-overlap
``tn_target.H @ tn_target``, or don't care about the overall scale of
the norm distance, you can supply its value here to avoid recomputing it.
inplace : bool, optional
Update ``tn`` in place.
progbar : bool, optional
@@ -169,10 +183,11 @@ def tensor_network_fit_autodiff(
from .optimize import TNOptimizer
from .tensor_core import tensor_network_distance

xBB = (tn_target | tn_target.H).contract(
output_inds=(),
optimize=contract_optimize,
)
if xBB is None:
xBB = (tn_target | tn_target.H).contract(
output_inds=(),
optimize=contract_optimize,
)
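
As a usage illustration (not part of the diff), the new ``xBB`` argument lets the target self-overlap be contracted once and shared between several fits. A hedged sketch, assuming quimb's ``MPS_rand_state`` constructor and the existing ``steps`` keyword:

    import quimb.tensor as qtn
    from quimb.tensor.fitting import tensor_network_fit_autodiff

    target = qtn.MPS_rand_state(10, bond_dim=8)
    guess = qtn.MPS_rand_state(10, bond_dim=4)

    # contract <target|target> once, exactly as the guard above would do
    xBB = (target | target.H).contract(output_inds=(), optimize="auto-hq")

    # reuse it, skipping the repeated self-overlap contraction
    fit_a = tensor_network_fit_autodiff(guess, target, steps=100, xBB=xBB)
    fit_b = tensor_network_fit_autodiff(guess.copy(), target, steps=200, xBB=xBB)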

tnopt = TNOptimizer(
tn=tn,
@@ -199,6 +214,55 @@ def _tn_fit_als_core(
return tn


@compose
def vdot_broadcast(x, y):
return do("sum", x * do("conj", y), axis=0)
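
A quick numpy check (not part of the diff) of what this helper returns: one inner product per column, summed down axis 0, which is what lets the solver below treat a matrix right-hand side as a batch of independent vectors:

    import numpy as np
    from quimb.tensor.fitting import vdot_broadcast  # added in this commit

    x = np.array([[1.0 + 1j, 2.0], [3.0, 4.0 - 2j]])
    y = np.array([[1.0, 1j], [1.0, 1.0]])

    print(vdot_broadcast(x, y))          # [4.+1.j  4.-4.j]
    print(np.sum(x * y.conj(), axis=0))  # identical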


def conjugate_gradient(A, b, x0=None, tol=1e-10, maxiter=1000):
"""
Conjugate Gradient solver for complex matrices/linear operators.
Parameters
----------
A : operator_like
The matrix or linear operator.
B : array_like
The right-hand side vector.
x0 : array_like, optional
Initial guess for the solution.
tol : float, optional
Tolerance for convergence.
maxiter : int, optional
Maximum number of iterations.
Returns:
--------
x : array_like
The solution vector.
"""
if x0 is None:
x0 = do("zeros_like", b)

x = x0
r = p = b - A @ x

rsold = vdot_broadcast(r, r)

for _ in range(maxiter):
Ap = A @ p
alpha = rsold / vdot_broadcast(p, Ap)
x = x + alpha * p
r = r - alpha * Ap
rsnew = vdot_broadcast(r, r)
if do("all", do("sqrt", rsnew)) < tol:
break
p = r + (rsnew / rsold) * p
rsold = rsnew

return x
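
For reference, a small self-contained sketch (numpy assumed) solving a Hermitian positive-definite system with two right-hand sides stacked as columns, which is how the ALS routine below batches its local solves:

    import numpy as np
    from quimb.tensor.fitting import conjugate_gradient  # added in this commit

    rng = np.random.default_rng(42)
    M = rng.normal(size=(6, 6)) + 1j * rng.normal(size=(6, 6))
    A = M @ M.conj().T + 6 * np.eye(6)   # Hermitian positive definite
    b = rng.normal(size=(6, 2)) + 1j * rng.normal(size=(6, 2))

    x = conjugate_gradient(A, b, tol=1e-10, maxiter=100)
    print(np.allclose(A @ x, b))         # expect True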


def _tn_fit_als_core(
var_tags,
tnAA,
@@ -210,6 +274,9 @@ def _tn_fit_als_core(
enforce_pos,
pos_smudge,
solver="solve",
iterative=False,
iterative_solver="cg",
iterative_maxiter=2,
progbar=False,
):
from .tensor_core import group_inds
@@ -249,22 +316,60 @@ def _tn_fit_als_core(
# the main iterative sweep on each tensor, locally optimizing
for _ in pbar:
for tk, tb, lix, bix, rix, A_tn, y_tn in env_contractions:
# form local normalization and local overlap
Ni = A_tn.to_dense(rix, lix)
bi = y_tn.to_dense(rix, bix)

if enforce_pos:
el, V = do("linalg.eigh", Ni)
elmax = do("max", el)
el = do("clip", el, elmax * pos_smudge, None)
# can solve directly using eigendecomposition
x = V @ ((dag(V) @ bi) / do("reshape", el, (-1, 1)))
if iterative:
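# represent the local normalization lazily as a linear operator,
# rather than forming it as a dense matrix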
Ni = TNLinearOperator(
A_tn,
left_inds=rix + bix,
right_inds=lix + bix,
ldims=tb.shape,
rdims=tk.shape,
)
bi = y_tn.to_dense((*rix, *bix))
x0 = tk.to_dense((*lix, *bix))

if iterative_solver is None:
x = conjugate_gradient(
Ni, bi, x0=x0, tol=tol, maxiter=iterative_maxiter
)
else:
x = do(
f"scipy.sparse.linalg.{iterative_solver}",
Ni,
bi,
x0=x0,
rtol=tol,
maxiter=iterative_maxiter,
)[0]

else:
Ni_p = Ni
if solver == "solve":
x = do("linalg.solve", Ni_p, bi)
elif solver == "lstsq":
x = do("linalg.lstsq", Ni_p, bi, rcond=pos_smudge)[0]
# form local normalization and local overlap
Ni = A_tn.to_dense(rix, lix)
bi = y_tn.to_dense(rix, bix)

if enforce_pos:
el, V = do("linalg.eigh", Ni)
elmax = do("max", el)
el = do("clip", el, elmax * pos_smudge, None)
# can solve directly using eigendecomposition
x = V @ ((dag(V) @ bi) / do("reshape", el, (-1, 1)))
else:
Ni_p = Ni

if solver is None:
x0 = tk.to_dense(lix, bix)
x = conjugate_gradient(
Ni_p,
bi,
x0=x0,
tol=tol,
maxiter=iterative_maxiter,
)
if solver == "solve":
x = do("linalg.solve", Ni_p, bi)
elif solver == "lstsq":
x = do("linalg.lstsq", Ni_p, bi, rcond=pos_smudge)[
0
]

x_r = do("reshape", x, tk.shape)
# n.b. because we are using virtual TNs -> updates propagate
@@ -274,8 +379,14 @@ def _tn_fit_als_core(
# assess | A - B | (normalized) for convergence or printing
if (tol != 0.0) or progbar:
dagx = dag(x)
xAA = do("trace", do("real", dagx @ (Ni @ x))) # <A|A>
xAB = do("trace", do("real", dagx @ bi)) # <A|B>

if x.ndim == 2:
xAA = do("trace", do("real", dagx @ (Ni @ x))) # <A|A>
xAB = do("trace", do("real", dagx @ bi)) # <A|B>
else:
xAA = do("real", dagx @ (Ni @ x))
xAB = do("real", dagx @ bi)

d = abs(xAA + xBB - 2 * xAB) ** 0.5 * 2 / (xAA**0.5 + xBB**0.5)
if abs(d - old_d) < tol:
break
@@ -300,6 +411,7 @@ def tensor_network_fit_als(
contract_optimize="greedy",
inplace=False,
progbar=False,
**kwargs,
):
"""Optimize the fit of ``tn`` with respect to ``tn_target`` using
alternating least squares (ALS). This minimizes the norm of the difference
@@ -399,6 +511,7 @@ def tensor_network_fit_als(
pos_smudge=pos_smudge,
solver=solver,
progbar=progbar,
**kwargs,
)
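
For context (not part of the diff): because the new options are simply forwarded through ``**kwargs`` to ``_tn_fit_als_core``, the lazy iterative path can be requested directly from this entry point. A hedged usage sketch, assuming quimb's ``MPS_rand_state`` constructor and that scipy is installed for the named solver:

    import quimb.tensor as qtn
    from quimb.tensor.fitting import tensor_network_fit_als

    target = qtn.MPS_rand_state(20, bond_dim=16)
    guess = qtn.MPS_rand_state(20, bond_dim=8)

    # local solves done lazily with scipy's cg instead of dense linalg.solve
    fitted = tensor_network_fit_als(
        guess,
        target,
        iterative=True,
        iterative_solver="cg",
        iterative_maxiter=4,
    )

    # or with the backend-agnostic conjugate_gradient defined in this file
    fitted_cg = tensor_network_fit_als(
        guess,
        target,
        iterative=True,
        iterative_solver=None,
    )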

if not inplace: