Skip to content

Commit

Permalink
tensor_split: use fuse
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 9, 2024
1 parent b6d1868 commit f627b4e
Showing 1 changed file with 50 additions and 42 deletions.
92 changes: 50 additions & 42 deletions quimb/tensor/tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,33 +462,31 @@ def tensor_split(
method : str, optional
How to split the tensor, only some methods allow bond truncation:
- ``'svd'``: full SVD, allows truncation.
- ``'eig'``: full SVD via eigendecomp, allows truncation.
- ``'lu'``: full LU decomposition, allows truncation. This method
favors tensor sparsity but is not rank optimal.
- ``'svds'``: iterative svd, allows truncation.
- ``'isvd'``: iterative svd using interpolative methods, allows
truncation.
- ``'rsvd'`` : randomized iterative svd with truncation.
- ``'eigh'``: full eigen-decomposition, tensor must he hermitian.
- ``'eigsh'``: iterative eigen-decomposition, tensor must be
hermitian.
- ``'qr'``: full QR decomposition.
- ``'lq'``: full LR decomposition.
- ``'polar_right'``: full polar decomposition as ``A = UP``.
- ``'polar_left'``: full polar decomposition as ``A = PU``.
- ``'cholesky'``: full cholesky decomposition, tensor must be
positive.
- ``'svd'``: full SVD, allows truncation.
- ``'eig'``: full SVD via eigendecomp, allows truncation.
- ``'lu'``: full LU decomposition, allows truncation. This method
favors tensor sparsity but is not rank optimal.
- ``'svds'``: iterative svd, allows truncation.
- ``'isvd'``: iterative svd using interpolative methods, allows
truncation.
- ``'rsvd'`` : randomized iterative svd with truncation.
- ``'eigh'``: full eigen-decomposition, tensor must he hermitian.
- ``'eigsh'``: iterative eigen-decomposition, tensor must be hermitian.
- ``'qr'``: full QR decomposition.
- ``'lq'``: full LR decomposition.
- ``'polar_right'``: full polar decomposition as ``A = UP``.
- ``'polar_left'``: full polar decomposition as ``A = PU``.
- ``'cholesky'``: full cholesky decomposition, tensor must be positive.
get : {None, 'arrays', 'tensors', 'values'}
If given, what to return instead of a TN describing the split:
- ``None``: a tensor network of the two (or three) tensors.
- ``'arrays'``: the raw data arrays as a tuple ``(l, r)`` or
``(l, s, r)`` depending on ``absorb``.
- ``'tensors '``: the new tensors as a tuple ``(Tl, Tr)`` or
``(Tl, Ts, Tr)`` depending on ``absorb``.
- ``'values'``: only compute and return the singular values ``s``.
- ``None``: a tensor network of the two (or three) tensors.
- ``'arrays'``: the raw data arrays as a tuple ``(l, r)`` or
``(l, s, r)`` depending on ``absorb``.
- ``'tensors '``: the new tensors as a tuple ``(Tl, Tr)`` or
``(Tl, Ts, Tr)`` depending on ``absorb``.
- ``'values'``: only compute and return the singular values ``s``.
absorb : {'both', 'left', 'right', None}, optional
Whether to absorb the singular values into both, the left, or the right
Expand All @@ -507,14 +505,14 @@ def tensor_split(
cutoff_mode : {'sum2', 'rel', 'abs', 'rsum2'}
Method with which to apply the cutoff threshold:
- ``'rel'``: values less than ``cutoff * s[0]`` discarded.
- ``'abs'``: values less than ``cutoff`` discarded.
- ``'sum2'``: sum squared of values discarded must be ``< cutoff``.
- ``'rsum2'``: sum squared of values discarded must be less than
``cutoff`` times the total sum of squared values.
- ``'sum1'``: sum values discarded must be ``< cutoff``.
- ``'rsum1'``: sum of values discarded must be less than
``cutoff`` times the total sum of values.
- ``'rel'``: values less than ``cutoff * s[0]`` discarded.
- ``'abs'``: values less than ``cutoff`` discarded.
- ``'sum2'``: sum squared of values discarded must be ``< cutoff``.
- ``'rsum2'``: sum squared of values discarded must be less than
``cutoff`` times the total sum of squared values.
- ``'sum1'``: sum values discarded must be ``< cutoff``.
- ``'rsum1'``: sum of values discarded must be less than ``cutoff``
times the total sum of values.
renorm : {None, bool, or int}, optional
Whether to renormalize the kept singular values, assuming the bond has
Expand Down Expand Up @@ -555,6 +553,9 @@ def tensor_split(
else:
right_inds = tags_to_oset(right_inds)

nleft = len(left_inds)
nright = len(right_inds)

if isinstance(T, spla.LinearOperator):
left_dims = T.ldims
right_dims = T.rdims
Expand All @@ -564,11 +565,14 @@ def tensor_split(
array = T
else:
TT = T.transpose(*left_inds, *right_inds)
left_dims = TT.shape[: len(left_inds)]
right_dims = TT.shape[len(left_inds) :]
left_dims = TT.shape[:nleft]
right_dims = TT.shape[nleft:]

if (len(left_dims), len(right_dims)) != (1, 1):
array = do("reshape", TT.data, (prod(left_dims), prod(right_dims)))
if (nleft, nright) != (1, 1):
# need to fuse into matrix
array = do(
"fuse", TT.data, range(nleft), range(nleft, nleft + nright)
)
else:
array = TT.data

Expand All @@ -579,12 +583,14 @@ def tensor_split(
method, cutoff, absorb, max_bond, cutoff_mode, renorm
)

# ``s`` itself will be None unless ``absorb=None`` is specified
# `s` itself will be None unless `absorb=None` is specified
left, s, right = _SPLIT_FNS[method](array, **opts)

if len(left_dims) != 1:
if nleft != 1:
# unfuse dangling left indices
left = do("reshape", left, (*left_dims, shape(left)[-1]))
if len(right_dims) != 1:
if nright != 1:
# unfuse dangling right indices
right = do("reshape", right, (shape(right)[0], *right_dims))

if get == "arrays":
Expand Down Expand Up @@ -3277,10 +3283,12 @@ def _tensor_network_gate_inds_basic(
tl, tr = tn._inds_get(ixl, ixr)
bnds_l, (bix,), bnds_r = group_inds(tl, tr)

if (len(bnds_l) <= 2) or (len(bnds_r) <= 2):
# reduce split is likely redundant (i.e. contracting pair and splitting
# just as cheap as performing QR reductions)
contract = "split"
# XXX: disable this for symmray, where reduced split is always important
# for keeping charge distributions across tensors stable
# if (len(bnds_l) <= 2) or (len(bnds_r) <= 2):
# # reduce split is likely redundant (i.e. contracting pair
# # and splitting just as cheap as performing QR reductions)
# contract = "split"

if contract == "split":
#
Expand Down

0 comments on commit f627b4e

Please sign in to comment.