Skip to content

Commit

Permalink
fix L1BP with 1D TN bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 26, 2023
1 parent 948fa66 commit c29aeea
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
3 changes: 3 additions & 0 deletions quimb/experimental/belief_propagation/l1bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _distance(x, y):

if message_init_function is None:
tm = tn_i.contract(
all,
output_inds=bix,
optimize=self.optimize,
drop_tags=True,
Expand Down Expand Up @@ -141,6 +142,7 @@ def _compute_m(key):
bix = self.edges[(i, j) if i < j else (j, i)]
tn_i_to_j = self.contraction_tns[i, j]
tm_new = tn_i_to_j.contract(
all,
output_inds=bix,
optimize=self.optimize,
**self.contract_opts,
Expand Down Expand Up @@ -196,6 +198,7 @@ def contract(self, strip_exponent=False):
else:
# site exists but has no neighbors
tval = tn_ic.contract(
all,
output_inds=(),
optimize=self.optimize,
**self.contract_opts,
Expand Down
7 changes: 5 additions & 2 deletions quimb/experimental/belief_propagation/l2bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _distance(x, y):
tn_i = self.local_tns[i]
tn_i2 = tn_i & tn_i.conj().reindex_(remapper)
tm = tn_i2.contract(
all,
output_inds=output_inds,
optimize=self.optimize,
drop_tags=True,
Expand Down Expand Up @@ -136,6 +137,7 @@ def iterate(self, tol=5e-6):
tn_i_to_j = self.contraction_tns[i, j]

tm_new = tn_i_to_j.contract(
all,
output_inds=output_inds,
drop_tags=True,
optimize=self.optimize,
Expand Down Expand Up @@ -186,13 +188,14 @@ def contract(self, strip_exponent=False):
)
)
tvals.append(
tni.contract(optimize=self.optimize, **self.contract_opts)
tni.contract(all, optimize=self.optimize, **self.contract_opts)
)

mvals = []
for i, j in self.edges:
mvals.append(
(self.messages[i, j] & self.messages[j, i]).contract(
all,
optimize=self.optimize,
**self.contract_opts,
)
Expand Down Expand Up @@ -253,7 +256,7 @@ def compress(
tn,
max_bond=None,
cutoff=5e-6,
cutoff_mode='rsum2',
cutoff_mode="rsum2",
renorm=0,
lazy=False,
):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tensor/test_belief_propagation/test_l1bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,17 @@ def test_contract_cluster_approx():
assert info["converged"]
assert f_bp == pytest.approx(f_ex, rel=0.1)
assert abs(1 - f_ex / f_bp2) < abs(1 - f_ex / f_bp)


def test_mps():
# catch bug to do with structured contract and output inds
L = 6
psi = qtn.MPS_rand_state(L=L, seed=20, bond_dim=3)
psiG = psi.copy()
psiG.gate_(qu.pauli("X"), 5, contract=True)
expec = psi.H & psiG
O = contract_l1bp(
expec,
site_tags=[f"I{i}" for i in range(L)],
)
assert O == pytest.approx(expec ^ ..., abs=1e-6)

0 comments on commit c29aeea

Please sign in to comment.