From f627b4e838741158f3d8e965df2385d4938c1177 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Wed, 9 Oct 2024 13:28:51 -0700 Subject: [PATCH] tensor_split: use fuse --- quimb/tensor/tensor_core.py | 92 ++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index aec0197f..e3bcf8ef 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -462,33 +462,31 @@ def tensor_split( method : str, optional How to split the tensor, only some methods allow bond truncation: - - ``'svd'``: full SVD, allows truncation. - - ``'eig'``: full SVD via eigendecomp, allows truncation. - - ``'lu'``: full LU decomposition, allows truncation. This method - favors tensor sparsity but is not rank optimal. - - ``'svds'``: iterative svd, allows truncation. - - ``'isvd'``: iterative svd using interpolative methods, allows - truncation. - - ``'rsvd'`` : randomized iterative svd with truncation. - - ``'eigh'``: full eigen-decomposition, tensor must he hermitian. - - ``'eigsh'``: iterative eigen-decomposition, tensor must be - hermitian. - - ``'qr'``: full QR decomposition. - - ``'lq'``: full LR decomposition. - - ``'polar_right'``: full polar decomposition as ``A = UP``. - - ``'polar_left'``: full polar decomposition as ``A = PU``. - - ``'cholesky'``: full cholesky decomposition, tensor must be - positive. + - ``'svd'``: full SVD, allows truncation. + - ``'eig'``: full SVD via eigendecomp, allows truncation. + - ``'lu'``: full LU decomposition, allows truncation. This method + favors tensor sparsity but is not rank optimal. + - ``'svds'``: iterative svd, allows truncation. + - ``'isvd'``: iterative svd using interpolative methods, allows + truncation. + - ``'rsvd'`` : randomized iterative svd with truncation. + - ``'eigh'``: full eigen-decomposition, tensor must he hermitian. + - ``'eigsh'``: iterative eigen-decomposition, tensor must be hermitian. + - ``'qr'``: full QR decomposition. + - ``'lq'``: full LR decomposition. + - ``'polar_right'``: full polar decomposition as ``A = UP``. + - ``'polar_left'``: full polar decomposition as ``A = PU``. + - ``'cholesky'``: full cholesky decomposition, tensor must be positive. get : {None, 'arrays', 'tensors', 'values'} If given, what to return instead of a TN describing the split: - - ``None``: a tensor network of the two (or three) tensors. - - ``'arrays'``: the raw data arrays as a tuple ``(l, r)`` or - ``(l, s, r)`` depending on ``absorb``. - - ``'tensors '``: the new tensors as a tuple ``(Tl, Tr)`` or - ``(Tl, Ts, Tr)`` depending on ``absorb``. - - ``'values'``: only compute and return the singular values ``s``. + - ``None``: a tensor network of the two (or three) tensors. + - ``'arrays'``: the raw data arrays as a tuple ``(l, r)`` or + ``(l, s, r)`` depending on ``absorb``. + - ``'tensors '``: the new tensors as a tuple ``(Tl, Tr)`` or + ``(Tl, Ts, Tr)`` depending on ``absorb``. + - ``'values'``: only compute and return the singular values ``s``. absorb : {'both', 'left', 'right', None}, optional Whether to absorb the singular values into both, the left, or the right @@ -507,14 +505,14 @@ def tensor_split( cutoff_mode : {'sum2', 'rel', 'abs', 'rsum2'} Method with which to apply the cutoff threshold: - - ``'rel'``: values less than ``cutoff * s[0]`` discarded. - - ``'abs'``: values less than ``cutoff`` discarded. - - ``'sum2'``: sum squared of values discarded must be ``< cutoff``. - - ``'rsum2'``: sum squared of values discarded must be less than - ``cutoff`` times the total sum of squared values. - - ``'sum1'``: sum values discarded must be ``< cutoff``. - - ``'rsum1'``: sum of values discarded must be less than - ``cutoff`` times the total sum of values. + - ``'rel'``: values less than ``cutoff * s[0]`` discarded. + - ``'abs'``: values less than ``cutoff`` discarded. + - ``'sum2'``: sum squared of values discarded must be ``< cutoff``. + - ``'rsum2'``: sum squared of values discarded must be less than + ``cutoff`` times the total sum of squared values. + - ``'sum1'``: sum values discarded must be ``< cutoff``. + - ``'rsum1'``: sum of values discarded must be less than ``cutoff`` + times the total sum of values. renorm : {None, bool, or int}, optional Whether to renormalize the kept singular values, assuming the bond has @@ -555,6 +553,9 @@ def tensor_split( else: right_inds = tags_to_oset(right_inds) + nleft = len(left_inds) + nright = len(right_inds) + if isinstance(T, spla.LinearOperator): left_dims = T.ldims right_dims = T.rdims @@ -564,11 +565,14 @@ def tensor_split( array = T else: TT = T.transpose(*left_inds, *right_inds) - left_dims = TT.shape[: len(left_inds)] - right_dims = TT.shape[len(left_inds) :] + left_dims = TT.shape[:nleft] + right_dims = TT.shape[nleft:] - if (len(left_dims), len(right_dims)) != (1, 1): - array = do("reshape", TT.data, (prod(left_dims), prod(right_dims))) + if (nleft, nright) != (1, 1): + # need to fuse into matrix + array = do( + "fuse", TT.data, range(nleft), range(nleft, nleft + nright) + ) else: array = TT.data @@ -579,12 +583,14 @@ def tensor_split( method, cutoff, absorb, max_bond, cutoff_mode, renorm ) - # ``s`` itself will be None unless ``absorb=None`` is specified + # `s` itself will be None unless `absorb=None` is specified left, s, right = _SPLIT_FNS[method](array, **opts) - if len(left_dims) != 1: + if nleft != 1: + # unfuse dangling left indices left = do("reshape", left, (*left_dims, shape(left)[-1])) - if len(right_dims) != 1: + if nright != 1: + # unfuse dangling right indices right = do("reshape", right, (shape(right)[0], *right_dims)) if get == "arrays": @@ -3277,10 +3283,12 @@ def _tensor_network_gate_inds_basic( tl, tr = tn._inds_get(ixl, ixr) bnds_l, (bix,), bnds_r = group_inds(tl, tr) - if (len(bnds_l) <= 2) or (len(bnds_r) <= 2): - # reduce split is likely redundant (i.e. contracting pair and splitting - # just as cheap as performing QR reductions) - contract = "split" + # XXX: disable this for symmray, where reduced split is always important + # for keeping charge distributions across tensors stable + # if (len(bnds_l) <= 2) or (len(bnds_r) <= 2): + # # reduce split is likely redundant (i.e. contracting pair + # # and splitting just as cheap as performing QR reductions) + # contract = "split" if contract == "split": #