From 349c0381f856953b68c8f28a554c525096e462ef Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 31 Aug 2023 13:34:56 -0400 Subject: [PATCH] Scale factors across plate dims in `partial_sum_product` (#606) --- .github/workflows/ci.yml | 2 +- funsor/ops/builtin.py | 4 ++ funsor/ops/op.py | 2 + funsor/sum_product.py | 49 +++++++++++++++++--- funsor/testing.py | 4 +- test/test_sum_product.py | 96 +++++++++++++++++++++++++++++++++++++++- test/test_terms.py | 4 +- 7 files changed, 150 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4c39cb1d..a9accdcd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,7 +67,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: [3.9] env: CI: 1 FUNSOR_BACKEND: jax diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index 0f3d0c03..77fbfe6d 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -8,6 +8,7 @@ from .op import ( BINARY_INVERSES, DISTRIBUTIVE_OPS, + PRODUCT_TO_POWER, SAFE_BINARY_INVERSES, UNARY_INVERSES, UNITS, @@ -287,6 +288,9 @@ def sigmoid_log_abs_det_jacobian(x, y): UNARY_INVERSES[mul] = reciprocal UNARY_INVERSES[add] = neg +PRODUCT_TO_POWER[add] = mul +PRODUCT_TO_POWER[mul] = pow + __all__ = [ "AssociativeOp", "ComparisonOp", diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 5c5312e8..f7540c63 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -421,6 +421,7 @@ def log_abs_det_jacobian(x, y, fn): BINARY_INVERSES = {} # binary op -> inverse binary op SAFE_BINARY_INVERSES = {} # binary op -> numerically safe inverse binary op UNARY_INVERSES = {} # binary op -> inverse unary op +PRODUCT_TO_POWER = {} # product op -> power op __all__ = [ "BINARY_INVERSES", @@ -430,6 +431,7 @@ def log_abs_det_jacobian(x, y, fn): "LogAbsDetJacobianOp", "NullaryOp", "Op", + "PRODUCT_TO_POWER", "SAFE_BINARY_INVERSES", "TernaryOp", "TransformOp", diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 751fbc37..31a4d2fa 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -11,7 +11,7 @@ from funsor.cnf import Contraction from funsor.domains import Bint, Reals from funsor.interpreter import gensym -from funsor.ops import UNITS, AssociativeOp +from funsor.ops import PRODUCT_TO_POWER, UNITS, AssociativeOp from funsor.terms import ( Cat, Funsor, @@ -203,7 +203,14 @@ def partial_unroll(factors, eliminate=frozenset(), plate_to_step=dict()): def partial_sum_product( - sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False + sum_op, + prod_op, + factors, + eliminate=frozenset(), + plates=frozenset(), + pedantic=False, + pow_op=None, + plate_to_scale=None, # dict ): """ Performs partial sum-product contraction of a collection of factors. @@ -218,6 +225,10 @@ def partial_sum_product( assert isinstance(eliminate, frozenset) assert isinstance(plates, frozenset) + if plate_to_scale: + if pow_op is None: + pow_op = PRODUCT_TO_POWER[prod_op] + if pedantic: var_to_errors = defaultdict(lambda: eliminate) for f in factors: @@ -256,7 +267,17 @@ def partial_sum_product( f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate) remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: - results.append(f.reduce(prod_op, leaf & eliminate)) + f = f.reduce(prod_op, leaf & eliminate) + if plate_to_scale: + f_scales = [ + plate_to_scale[plate] + for plate in leaf & eliminate + if plate in plate_to_scale + ] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) + results.append(f) else: new_plates = frozenset().union( *(var_to_ordinal[v] for v in remaining_sum_vars) @@ -306,6 +327,15 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) + if plate_to_scale: + f_scales = [ + plate_to_scale[plate] + for plate in reduced_plates + if plate in plate_to_scale + ] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) ordinal_to_factors[new_plates].append(f) return results @@ -571,7 +601,14 @@ def modified_partial_sum_product( def sum_product( - sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False + sum_op, + prod_op, + factors, + eliminate=frozenset(), + plates=frozenset(), + pedantic=False, + pow_op=None, + plate_to_scale=None, # dict ): """ Performs sum-product contraction of a collection of factors. @@ -579,7 +616,9 @@ def sum_product( :return: a single contracted Funsor. :rtype: :class:`~funsor.terms.Funsor` """ - factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates, pedantic) + factors = partial_sum_product( + sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, plate_to_scale + ) return reduce(prod_op, factors, Number(UNITS[prod_op])) diff --git a/funsor/testing.py b/funsor/testing.py index 91336a52..dbf61245 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -81,7 +81,7 @@ def id_from_inputs(inputs): @dispatch(object, object, Variadic[float]) def allclose(a, b, rtol=1e-05, atol=1e-08): - if type(a) != type(b): + if type(a) is not type(b): return False return ops.abs(a - b) < rtol + atol * ops.abs(b) @@ -125,7 +125,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): elif isinstance(actual, Gaussian): assert isinstance(expected, Gaussian) else: - assert type(actual) == type(expected), msg + assert type(actual) is type(expected), msg if isinstance(actual, Funsor): assert isinstance(expected, Funsor), msg diff --git a/test/test_sum_product.py b/test/test_sum_product.py index b288cb45..ea55d4c5 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -35,7 +35,7 @@ sum_product, ) from funsor.tensor import Tensor, get_default_prototype -from funsor.terms import Variable +from funsor.terms import Cat, Variable from funsor.testing import assert_close, random_gaussian, random_tensor from funsor.util import get_backend @@ -2899,3 +2899,97 @@ def test_mixed_sequential_sum_product(duration, num_segments): ) assert_close(actual, expected) + + +@pytest.mark.parametrize( + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], +) +@pytest.mark.parametrize("scale", [1, 2]) +def test_partial_sum_product_scale_1(sum_op, prod_op, scale): + f1 = random_tensor(OrderedDict(a=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) + + eliminate = frozenset("ai") + plates = frozenset("i") + + # Actual result based on applying scaling + factors = [f1, f2] + scales = {"i": scale} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales + ) + + # Expected result based on concatenating factors + f3 = Cat("i", (f2,) * scale) + factors = [f1, f3] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + assert_close(actual, expected, atol=1e-4, rtol=1e-4) + + +@pytest.mark.parametrize( + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], +) +@pytest.mark.parametrize("scale_i", [1, 2]) +@pytest.mark.parametrize("scale_j", [1, 3]) +def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j): + f1 = random_tensor(OrderedDict(a=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) + f3 = random_tensor(OrderedDict(a=Bint[2], j=Bint[4])) + + eliminate = frozenset("aij") + plates = frozenset("ij") + + # Actual result based on applying scaling + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales + ) + + # Expected result based on concatenating factors + f4 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f3,) * scale_j) + factors = [f1, f4, f5] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + assert_close(actual, expected, atol=1e-4, rtol=1e-4) + + +@pytest.mark.parametrize( + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], +) +@pytest.mark.parametrize("scale_i", [1, 2]) +@pytest.mark.parametrize("scale_j", [1, 3]) +@pytest.mark.parametrize("scale_k", [1, 4]) +def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k): + f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3])) + f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3])) + + eliminate = frozenset("aijk") + plates = frozenset("ijk") + + # Actual result based on applying scaling + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j, "k": scale_k} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales + ) + + # Expected result based on concatenating factors + f4 = Cat("i", (f1,) * scale_i) + # concatenate across multiple dims + f5 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f5,) * scale_j) + # concatenate across multiple dims + f6 = Cat("i", (f3,) * scale_i) + f6 = Cat("j", (f6,) * scale_j) + f6 = Cat("k", (f6,) * scale_k) + factors = [f4, f5, f6] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + assert_close(actual, expected, atol=1e-4, rtol=1e-4) diff --git a/test/test_terms.py b/test/test_terms.py index db7e586b..af18dfee 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -72,7 +72,7 @@ def test_to_funsor_error(x): def test_to_data(): actual = to_data(Number(0.0)) expected = 0.0 - assert type(actual) == type(expected) + assert type(actual) is type(expected) assert actual == expected @@ -569,7 +569,7 @@ def test_stack_slice(start, stop, step): xs = tuple(map(Number, range(10))) actual = Stack("i", xs)(i=Slice("j", start, stop, step, dtype=10)) expected = Stack("j", xs[start:stop:step]) - assert type(actual) == type(expected) + assert type(actual) is type(expected) assert actual.name == expected.name assert actual.parts == expected.parts