Skip to content

Commit

Permalink
add tensor_1d_compress.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Mar 20, 2024
1 parent 72550bc commit cb8979c
Show file tree
Hide file tree
Showing 5 changed files with 1,867 additions and 37 deletions.
2 changes: 2 additions & 0 deletions quimb/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
NNI_ham_XY,
SpinHam,
SpinHam1D,
TN1D_matching,
TN2D_classical_ising_partition_function,
TN2D_corner_double_line,
TN2D_embedded_classical_ising_partition_function,
Expand Down Expand Up @@ -361,6 +362,7 @@
"TN_rand_from_edges",
"TN_rand_reg",
"TN_rand_tree",
"TN1D_matching",
"TN2D_classical_ising_partition_function",
"TN2D_corner_double_line",
"TN2D_embedded_classical_ising_partition_function",
Expand Down
102 changes: 73 additions & 29 deletions quimb/tensor/tensor_1d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Classes and algorithms related to 1D tensor networks.
"""
"""Classes and algorithms related to 1D tensor networks."""

import operator
import functools
Expand Down Expand Up @@ -80,27 +79,15 @@ def expec_TN_1D(*tns, compress=None, eps=1e-15):
return expec_tn ^ ...


_VALID_GATE_CONTRACT = {
False,
True,
"swap+split",
"split-gate",
"swap-split-gate",
"auto-split-gate",
}
_VALID_GATE_PROPAGATE = {"sites", "register", False, True}
_TWO_BODY_ONLY = _VALID_GATE_CONTRACT - {True, False}


def maybe_factor_gate_into_tensor(G, dp, ng, where):
def maybe_factor_gate_into_tensor(G, phys_dim, nsites, where):
# allow gate to be a matrix as long as it factorizes into tensor
shape_matches_2d = (ops.ndim(G) == 2) and (G.shape[1] == dp**ng)
shape_matches_nd = all(d == dp for d in G.shape)
shape_matches_2d = (ops.ndim(G) == 2) and (G.shape[1] == phys_dim**nsites)
shape_matches_nd = all(d == phys_dim for d in G.shape)

if shape_matches_2d:
G = ops.asarray(G)
if ng >= 2:
G = reshape(G, [dp] * 2 * ng)
if nsites >= 2:
G = reshape(G, [phys_dim] * 2 * nsites)

elif not shape_matches_nd:
raise ValueError(
Expand Down Expand Up @@ -1126,8 +1113,7 @@ def compress_site(
self.right_compress_site(i + 1, bra=bra, **compress_opts)

def bond(self, i, j):
"""Get the name of the index defining the bond between sites i and j.
"""
"""Get the name of the index defining the bond between sites i and j."""
(bond,) = self[i].bonds(self[j])
return bond

Expand Down Expand Up @@ -1481,7 +1467,12 @@ def from_fill_fn(
else:
phys_dims = itertools.cycle(phys_dim)

mps = TensorNetwork()
mps = cls.new(
L=L,
cyclic=cyclic,
site_ind_id=site_ind_id,
site_tag_id=site_tag_id,
)
global_tags = tags_to_oset(tags)
bonds = collections.defaultdict(rand_uuid)

Expand All @@ -1504,13 +1495,7 @@ def from_fill_fn(
tags = global_tags | oset((site_tag_id.format(i),))
mps |= Tensor(data, inds=inds, tags=tags)

return mps.view_as_(
cls,
L=L,
cyclic=cyclic,
site_ind_id=site_ind_id,
site_tag_id=site_tag_id,
)
return mps

@classmethod
def from_dense(
Expand Down Expand Up @@ -2887,6 +2872,65 @@ def gen_tensors():

super().__init__(gen_tensors(), virtual=True, **tn_opts)

@classmethod
def from_fill_fn(
cls,
fill_fn,
L,
bond_dim,
phys_dim=2,
cyclic=False,
shape="lrud",
site_tag_id="I{}",
tags=None,
upper_ind_id="k{}",
lower_ind_id="b{}",
):
if set(shape) - {"l", "r", "u", "d"}:
raise ValueError(f"Invalid shape string: {shape}.")

# check for site varying physical dimensions
if isinstance(phys_dim, Integral):
phys_dims = itertools.repeat(phys_dim)
else:
phys_dims = itertools.cycle(phys_dim)

mpo = cls.new(
L=L,
cyclic=cyclic,
site_tag_id=site_tag_id,
upper_ind_id=upper_ind_id,
lower_ind_id=lower_ind_id,
)

global_tags = tags_to_oset(tags)
bonds = collections.defaultdict(rand_uuid)

for i in range(L):
p = next(phys_dims)
inds = []
data_shape = []
for c in shape:
if c == "l":
if (i - 1) >= 0 or cyclic:
inds.append(bonds[frozenset([(i - 1) % L, i])])
data_shape.append(bond_dim)
elif c == "r":
if (i + 1) < L or cyclic:
inds.append(bonds[frozenset([i, (i + 1) % L])])
data_shape.append(bond_dim)
elif c == "u":
inds.append(upper_ind_id.format(i))
data_shape.append(p)
else: # c == "d"
inds.append(lower_ind_id.format(i))
data_shape.append(p)
data = fill_fn(data_shape)
tags = global_tags | oset((site_tag_id.format(i),))
mpo |= Tensor(data, inds=inds, tags=tags)

return mpo

def add_MPO(self, other, inplace=False, compress=False, **compress_opts):
"""Add another MatrixProductState to this one."""
if self.L != other.L:
Expand Down
Loading

0 comments on commit cb8979c

Please sign in to comment.