From ff9eda39b6783f9b1f86c4c417cf8b99d4a4f24f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 5 Nov 2024 12:09:12 +0100 Subject: [PATCH] Include unconditional constant_fold rewrite --- pymc/pytensorf.py | 11 ++++++++--- tests/test_pytensorf.py | 4 ++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 213831c9f1..d0360e0131 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -45,6 +45,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.random.var import RandomGeneratorSharedVariable +from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 @@ -1057,7 +1058,7 @@ def compile_pymc( def constant_fold( xs: Sequence[TensorVariable], raise_not_constant: bool = True -) -> tuple[np.ndarray, ...]: +) -> tuple[np.ndarray | Variable, ...]: """Use constant folding to get constant values of a graph. Parameters @@ -1072,8 +1073,12 @@ def constant_fold( """ fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True) - # By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite - folded_xs = rewrite_graph(fg).outputs + # The default rewrite_graph includes a constand_folding that is not always applied. + # We use an unconditional constant_folding as the last pass to ensure a thorough constant folding. + rewrite_graph(fg) + topo_unconditional_constant_folding.apply(fg) + + folded_xs = fg.outputs if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs): raise NotConstantValueError diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index b3564cac1f..564dd2ba15 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -696,6 +696,10 @@ def test_inputs_preserved(self): (out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False) assert out_shape is a + def test_constant_fold_alloc(self): + x = pt.alloc(pt.arange(5), (2, 5)) + np.testing.assert_allclose(constant_fold([x]), np.broadcast_to(np.arange(5), (2, 5))) + def test_replace_vars_in_graphs(): inp = shared(0.0, name="inp")