Skip to content

Commit

Permalink
Include unconditional constant_fold rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 11, 2024
1 parent 6c15185 commit ff9eda3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
Expand Down Expand Up @@ -1057,7 +1058,7 @@ def compile_pymc(

def constant_fold(
xs: Sequence[TensorVariable], raise_not_constant: bool = True
) -> tuple[np.ndarray, ...]:
) -> tuple[np.ndarray | Variable, ...]:
"""Use constant folding to get constant values of a graph.
Parameters
Expand All @@ -1072,8 +1073,12 @@ def constant_fold(
"""
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)

# By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
folded_xs = rewrite_graph(fg).outputs
# The default rewrite_graph includes a constant_folding that is not always applied.
# We use an unconditional constant_folding as the last pass to ensure a thorough constant folding.
rewrite_graph(fg)
topo_unconditional_constant_folding.apply(fg)

folded_xs = fg.outputs

if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
raise NotConstantValueError
Expand Down
4 changes: 4 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,10 @@ def test_inputs_preserved(self):
(out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)
assert out_shape is a

def test_constant_fold_alloc(self):
    """An Alloc graph is fully folded by the unconditional constant-folding pass."""
    # pt.alloc takes the shape as separate scalar arguments (value, *shape),
    # not as a single tuple — pt.alloc(x, (2, 5)) would be invalid.
    x = pt.alloc(pt.arange(5), 2, 5)
    # constant_fold returns a tuple of folded outputs, one per input variable.
    (folded,) = constant_fold([x])
    np.testing.assert_allclose(folded, np.broadcast_to(np.arange(5), (2, 5)))


def test_replace_vars_in_graphs():
inp = shared(0.0, name="inp")
Expand Down

0 comments on commit ff9eda3

Please sign in to comment.