diff --git a/quimb/experimental/belief_propagation/l1bp.py b/quimb/experimental/belief_propagation/l1bp.py
index f7458232..3c4d56c6 100644
--- a/quimb/experimental/belief_propagation/l1bp.py
+++ b/quimb/experimental/belief_propagation/l1bp.py
@@ -170,13 +170,16 @@ def _update_m(key, data):
 
         if self.update == "parallel":
             new_data = {}
+            # compute all new messages
             while self.touched:
                 key = self.touched.pop()
                 new_data[key] = _compute_m(key)
+            # insert all new messages
             for key, data in new_data.items():
                 _update_m(key, data)
 
         elif self.update == "sequential":
+            # compute each new message and immediately re-insert it
             while self.touched:
                 key = self.touched.pop()
                 data = _compute_m(key)
diff --git a/quimb/experimental/belief_propagation/l2bp.py b/quimb/experimental/belief_propagation/l2bp.py
index 4aeed0bf..3eb04e18 100644
--- a/quimb/experimental/belief_propagation/l2bp.py
+++ b/quimb/experimental/belief_propagation/l2bp.py
@@ -19,12 +19,14 @@ def __init__(
         site_tags=None,
         damping=0.0,
         local_convergence=True,
+        update="parallel",
         optimize="auto-hq",
         **contract_opts,
     ):
         self.backend = next(t.backend for t in tn)
         self.damping = damping
         self.local_convergence = local_convergence
+        self.update = update
         self.optimize = optimize
         self.contract_opts = contract_opts
 
@@ -126,10 +128,12 @@ def iterate(self, tol=5e-6):
 
         )
         ncheck = len(self.touched)
+        nconv = 0
+        max_mdiff = -1.0
+        new_touched = set()
 
-        new_data = {}
-        while self.touched:
-            i, j = self.touched.pop()
+        def _compute_m(key):
+            i, j = key
             bix = self.edges[(i, j) if i < j else (j, i)]
             cix = tuple(ix + "**" for ix in bix)
             output_inds = cix + bix
@@ -145,12 +149,11 @@ def iterate(self, tol=5e-6):
 
             )
             tm_new.modify(apply=self._symmetrize)
             tm_new.modify(apply=self._normalize)
-            # defer setting the data to do a parallel update
-            new_data[i, j] = tm_new.data
+            return tm_new.data
+
+        def _update_m(key, data):
+            nonlocal nconv, max_mdiff
 
-        nconv = 0
-        max_mdiff = -1.0
-        for key, data in new_data.items():
             tm = self.messages[key]
             if self.damping > 0.0:
@@ -160,13 +163,32 @@ def iterate(self, tol=5e-6):
 
             if mdiff > tol:
                 # mark touching messages for update
-                self.touched.update(self.touch_map[key])
+                new_touched.update(self.touch_map[key])
             else:
                 nconv += 1
 
             max_mdiff = max(max_mdiff, mdiff)
             tm.modify(data=data)
 
+        if self.update == "parallel":
+            new_data = {}
+            # compute all new messages
+            while self.touched:
+                key = self.touched.pop()
+                new_data[key] = _compute_m(key)
+            # insert all new messages
+            for key, data in new_data.items():
+                _update_m(key, data)
+
+        elif self.update == "sequential":
+            # compute each new message and immediately re-insert it
+            while self.touched:
+                key = self.touched.pop()
+                data = _compute_m(key)
+                _update_m(key, data)
+
+        self.touched = new_touched
+
         return nconv, ncheck, max_mdiff
 
     def contract(self, strip_exponent=False):
diff --git a/tests/test_tensor/test_belief_propagation/test_l1bp.py b/tests/test_tensor/test_belief_propagation/test_l1bp.py
index b9bdd229..705be708 100644
--- a/tests/test_tensor/test_belief_propagation/test_l1bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_l1bp.py
@@ -29,12 +29,15 @@ def test_contract_loopy_approx(dtype, damping):
 
 @pytest.mark.parametrize("dtype", ["float32", "complex64"])
 @pytest.mark.parametrize("damping", [0.0, 0.1])
-def test_contract_double_loopy_approx(dtype, damping):
+@pytest.mark.parametrize("update", ("parallel", "sequential"))
+def test_contract_double_loopy_approx(dtype, damping, update):
     peps = qtn.PEPS.rand(4, 3, 2, seed=42, dtype=dtype)
     tn = peps.H & peps
     Z_ex = tn.contract()
     info = {}
-    Z_bp1 = contract_l1bp(tn, damping=damping, info=info, progbar=True)
+    Z_bp1 = contract_l1bp(
+        tn, damping=damping, update=update, info=info, progbar=True
+    )
     assert info["converged"]
     assert Z_bp1 == pytest.approx(Z_ex, rel=0.3)
     # compare with 2-norm BP on the peps directly
diff --git a/tests/test_tensor/test_belief_propagation/test_l2bp.py b/tests/test_tensor/test_belief_propagation/test_l2bp.py
index 258404fd..e669d3b8 100644
--- a/tests/test_tensor/test_belief_propagation/test_l2bp.py
+++ b/tests/test_tensor/test_belief_propagation/test_l2bp.py
@@ -69,7 +69,8 @@ def test_contract_double_layer_tree_exact(dtype):
 
 @pytest.mark.parametrize("dtype", ["float32", "complex64"])
 @pytest.mark.parametrize("damping", [0.0, 0.1])
-def test_compress_double_layer_loopy(dtype, damping):
+@pytest.mark.parametrize("update", ["parallel", "sequential"])
+def test_compress_double_layer_loopy(dtype, damping, update):
     peps = qtn.PEPS.rand(3, 4, bond_dim=3, seed=42, dtype=dtype)
     pepo = qtn.PEPO.rand(3, 4, bond_dim=2, seed=42, dtype=dtype)
 
@@ -85,7 +86,12 @@ def test_compress_double_layer_loopy(dtype, damping, update):
     # compress using BP
     info = {}
     tn_bp = compress_l2bp(
-        tn_lazy, max_bond=3, damping=damping, info=info, progbar=True
+        tn_lazy,
+        max_bond=3,
+        damping=damping,
+        update=update,
+        info=info,
+        progbar=True,
     )
     assert info["converged"]
     assert tn_bp.num_tensors == 12