From ad15c54271390796f97669f1a7f14e1b60848629 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 4 Oct 2024 14:02:14 +0200
Subject: [PATCH] Add Numba implementation of Blockwise

---
 pytensor/link/numba/dispatch/__init__.py  |   9 +-
 pytensor/link/numba/dispatch/blockwise.py |  88 ++++++++++++++++++
 pytensor/link/numba/dispatch/random.py    |   2 +-
 pytensor/tensor/blockwise.py              |   8 ++
 pytensor/tensor/rewriting/__init__.py     |   1 +
 pytensor/tensor/rewriting/numba.py        | 108 ++++++++++++++++++++++
 tests/link/numba/test_basic.py            |   2 +-
 tests/link/numba/test_blockwise.py        |  59 ++++++++++++
 8 files changed, 271 insertions(+), 6 deletions(-)
 create mode 100644 pytensor/link/numba/dispatch/blockwise.py
 create mode 100644 pytensor/tensor/rewriting/numba.py
 create mode 100644 tests/link/numba/test_blockwise.py

diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py
index 6dd0e8211b..56a3e2c9b2 100644
--- a/pytensor/link/numba/dispatch/__init__.py
+++ b/pytensor/link/numba/dispatch/__init__.py
@@ -2,15 +2,16 @@
 from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
 
 # Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.slinalg
+import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
+import pytensor.link.numba.dispatch.tensor_basic
 
 # isort: on
diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py
new file mode 100644
index 0000000000..041d9b5c25
--- /dev/null
+++ b/pytensor/link/numba/dispatch/blockwise.py
@@ -0,0 +1,88 @@
+from typing import cast
+
+from numba.core.extending import overload
+from numba.np.unsafe.ndarray import to_fixed_tuple
+
+from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
+from pytensor.link.numba.dispatch.vectorize_codegen import (
+    _jit_options,
+    _vectorized,
+    encode_literals,
+    store_core_outputs,
+)
+from pytensor.link.utils import compile_function_src
+from pytensor.tensor import TensorVariable, get_vector_length
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
+
+
+@numba_funcify.register
+def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
+    [blockwise_node] = op.fgraph.apply_nodes
+    blockwise_op: Blockwise = blockwise_node.op
+    core_op = blockwise_op.core_op
+    nin = len(blockwise_node.inputs)
+    nout = len(blockwise_node.outputs)
+    core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])
+
+    core_node = blockwise_op._create_dummy_core_node(
+        cast(tuple[TensorVariable], blockwise_node.inputs)
+    )
+    core_op_fn = numba_funcify(
+        core_op,
+        node=core_node,
+        parent_node=node,
+        fastmath=_jit_options["fastmath"],
+        **kwargs,
+    )
+    core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
+
+    batch_ndim = blockwise_op.batch_ndim(node)
+
+    # numba doesn't support nested literals right now...
+    input_bc_patterns = encode_literals(
+        tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs[:nin])
+    )
+    output_bc_patterns = encode_literals(
+        tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs)
+    )
+    output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs))
+    inplace_pattern = encode_literals(())
+
+    # Numba does not allow a tuple generator in the Jitted function so we have to compile a helper to convert core_shapes into tuples
+    # Alternatively, add an Op that converts shape vectors into tuples, like we did for JAX
+    src = "def to_tuple(core_shapes): return ("
+    for i in range(nout):
+        src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
+    src += ")"
+
+    to_tuple = numba_njit(
+        compile_function_src(
+            src,
+            "to_tuple",
+            global_env={"to_fixed_tuple": to_fixed_tuple},
+        )
+    )
+
+    def blockwise_wrapper(*inputs_and_core_shapes):
+        inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
+        tuple_core_shapes = to_tuple(core_shapes)
+        return _vectorized(
+            core_op_fn,
+            input_bc_patterns,
+            output_bc_patterns,
+            output_dtypes,
+            inplace_pattern,
+            (),  # constant_inputs
+            inputs,
+            tuple_core_shapes,
+            None,  # size
+        )
+
+    def blockwise(*inputs_and_core_shapes):
+        raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")
+
+    @overload(blockwise, jit_options=_jit_options)
+    def ov_blockwise(*inputs_and_core_shapes):
+        return blockwise_wrapper
+
+    return blockwise
diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index 29584daa5f..04181e8335 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -388,7 +388,7 @@ def random_wrapper(core_shape, rng, size, *dist_params):
         return rng, draws
 
     def random(core_shape, rng, size, *dist_params):
-        pass
+        raise NotImplementedError("Non-jitted random variable not implemented")
 
     @overload(random, jit_options=_jit_options)
     def ov_random(core_shape, rng, size, *dist_params):
diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 8d27636536..92fdc9927e 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -441,3 +441,11 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
 
 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
+
+
+class BlockwiseWithCoreShape(OpWithCoreShape):
+    """Generalizes a Blockwise `Op` to include a core shape parameter."""
+
+    def __str__(self):
+        [blockwise_node] = self.fgraph.apply_nodes
+        return f"[{blockwise_node.op!s}]"
diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py
index fc5c528f2d..4e75140ceb 100644
--- a/pytensor/tensor/rewriting/__init__.py
+++ b/pytensor/tensor/rewriting/__init__.py
@@ -9,6 +9,7 @@
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.numba
 import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special
diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py
new file mode 100644
index 0000000000..91ab131424
--- /dev/null
+++ b/pytensor/tensor/rewriting/numba.py
@@ -0,0 +1,108 @@
+from pytensor.compile import optdb
+from pytensor.graph import node_rewriter
+from pytensor.graph.basic import applys_between
+from pytensor.graph.rewriting.basic import out2in
+from pytensor.tensor.basic import as_tensor, constant
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
+from pytensor.tensor.rewriting.shape import ShapeFeature
+
+
+@node_rewriter([Blockwise])
+def introduce_explicit_core_shape_blockwise(fgraph, node):
+    """Introduce the core shape of a Blockwise.
+
+    We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph
+    that has an extra "non-functional" input that represents the core shape of the Blockwise variable.
+    This core_shape is used by the numba backend to pre-allocate the output array.
+
+    If available, the core shape is extracted from the shape feature of the graph,
+    which has a higher chance of having been simplified, optimized, and constant-folded.
+    If missing, we fall back to the op's infer_shape method.
+
+    This rewrite is required for the numba backend implementation of Blockwise.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(5, None, None))
+        outs = pt.linalg.svd(x, compute_uv=True)
+        pytensor.dprint(outs)
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A]
+        # └─ x [id B]
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A]
+        # └─ ···
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A]
+        # └─ ···
+
+        # After the rewrite, note the 3 new core shape inputs
+        fn = pytensor.function([x], outs, mode="NUMBA")
+        fn.dprint(print_type=False)
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6
+        #  ├─ x [id B]
+        #  ├─ MakeVector{dtype='int64'} [id C] 5
+        #  │  ├─ Shape_i{1} [id D] 2
+        #  │  │  └─ x [id B]
+        #  │  └─ Shape_i{1} [id D] 2
+        #  │     └─ ···
+        #  ├─ MakeVector{dtype='int64'} [id E] 4
+        #  │  └─ Minimum [id F] 3
+        #  │     ├─ Shape_i{1} [id D] 2
+        #  │     │  └─ ···
+        #  │     └─ Shape_i{2} [id G] 0
+        #  │        └─ x [id B]
+        #  └─ MakeVector{dtype='int64'} [id H] 1
+        #     ├─ Shape_i{2} [id G] 0
+        #     │  └─ ···
+        #     └─ Shape_i{2} [id G] 0
+        #        └─ ···
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6
+        #  └─ ···
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
+        #  └─ ···
+    """
+    op: Blockwise = node.op  # type: ignore[annotation-unchecked]
+    batch_ndim = op.batch_ndim(node)
+
+    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
+    if shape_feature:
+        core_shapes = [
+            [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
+            for out in node.outputs
+        ]
+    else:
+        input_shapes = [tuple(inp.shape) for inp in node.inputs]
+        core_shapes = [
+            out_shape[batch_ndim:]
+            for out_shape in op.infer_shape(None, node, input_shapes)
+        ]
+
+    core_shapes = [
+        as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64")
+        for core_shape in core_shapes
+    ]
+
+    if any(
+        isinstance(node.op, Blockwise)
+        for node in applys_between(node.inputs, core_shapes)
+    ):
+        # If Blockwise shows up in the shape graph we can't introduce the core shape
+        return None
+
+    return BlockwiseWithCoreShape(
+        [*node.inputs, *core_shapes],
+        node.outputs,
+        destroy_map=op.destroy_map,
+    )(*node.inputs, *core_shapes, return_list=True)
+
+
+optdb.register(
+    introduce_explicit_core_shape_blockwise.__name__,
+    out2in(introduce_explicit_core_shape_blockwise),
+    "numba",
+    position=100,
+)
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index dfadc58a69..c6e26ec26d 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -242,7 +242,7 @@ def compare_numba_and_py(
     Parameters
     ----------
     fgraph
-        `FunctionGraph` or inputs to compare.
+        `FunctionGraph` or tuple(inputs, outputs) to compare.
     inputs
         Numeric inputs to be passed to the compiled graphs.
     assert_fn
diff --git a/tests/link/numba/test_blockwise.py b/tests/link/numba/test_blockwise.py
new file mode 100644
index 0000000000..ced4185e14
--- /dev/null
+++ b/tests/link/numba/test_blockwise.py
@@ -0,0 +1,59 @@
+import numpy as np
+import pytest
+
+from pytensor import function
+from pytensor.tensor import tensor
+from pytensor.tensor.basic import ARange
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.nlinalg import SVD, Det
+from pytensor.tensor.slinalg import Cholesky, cholesky
+from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
+
+
+# Fails if object mode warning is issued when not expected
+pytestmark = pytest.mark.filterwarnings("error")
+
+
+@pytest.mark.parametrize("shape_opt", [True, False], ids=str)
+@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str)
+def test_blockwise(core_op, shape_opt):
+    x = tensor(shape=(5, None, None))
+    outs = Blockwise(core_op=core_op)(x, return_list=True)
+
+    mode = (
+        numba_mode.including("ShapeOpt")
+        if shape_opt
+        else numba_mode.excluding("ShapeOpt")
+    )
+    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
+    compare_numba_and_py(
+        ([x], outs),
+        [x_test],
+        numba_mode=mode,
+        eval_obj_mode=False,
+    )
+
+
+def test_non_square_blockwise():
+    """Test that an Op that cannot always be blockwised at runtime fails gracefully."""
+    x = tensor(shape=(3,), dtype="int64")
+    out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1)
+
+    with pytest.warns(UserWarning, match="Numba will use object mode"):
+        fn = function([x], out, mode="NUMBA")
+
+    np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5)))
+
+    with pytest.raises(ValueError):
+        fn([3, 4, 5])
+
+
+def test_blockwise_benchmark(benchmark):
+    x = tensor(shape=(5, 3, 3))
+    out = cholesky(x)
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = function([x], out, mode="NUMBA")
+    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
+    fn(x_test)  # JIT compile
+    benchmark(fn, x_test)
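
Usage sketch (adapted from test_blockwise_benchmark above; the closing NumPy comparison is
illustrative only, not part of the patch): compiling a batched Cholesky with mode="NUMBA" now
goes through the new Blockwise dispatch instead of falling back to object mode.

    import numpy as np

    from pytensor import function
    from pytensor.tensor import tensor
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import cholesky

    x = tensor("x", shape=(5, 3, 3))
    out = cholesky(x)
    assert isinstance(out.owner.op, Blockwise)

    # The core-shape rewrite and the Numba Blockwise dispatch kick in here
    fn = function([x], out, mode="NUMBA")

    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]  # batch of 5 SPD matrices
    np.testing.assert_allclose(fn(x_test), np.linalg.cholesky(x_test))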