From 8dc6727b0b854353220f9da7984baa168bce7d63 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 28 Nov 2024 14:41:57 +0100 Subject: [PATCH] Setting to false can lead to slower code on C/Numba backend which don't support np.add.at natively. Support multidimensional boolean set/inc_subtensor in Numba via rewrite --- pytensor/tensor/rewriting/subtensor.py | 38 ++++++++++++++++++++------ tests/link/numba/test_subtensor.py | 32 ++++++++++++++++++---- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index fd98eaf718..7ba1908e60 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -249,7 +249,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): This is only done when there's a single vector index. """ - if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates: + if node.op.ignore_duplicates: # `AdvancedIncSubtensor1` does not ignore duplicate index values return @@ -1967,19 +1967,26 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): return new_out -@node_rewriter(tracks=[AdvancedSubtensor]) +@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor]) def ravel_multidimensional_bool_idx(fgraph, node): """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()] + x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape) """ - x, *idxs = node.inputs + if isinstance(node.op, AdvancedSubtensor): + x, *idxs = node.inputs + else: + x, y, *idxs = node.inputs if any( - isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int") + ( + (isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")) + or isinstance(idx.type, NoneTypeT) + ) for idx in idxs ): - # Get out if there are any other advanced indexes + # Get out if there are any other advanced indexes or np.newaxis return None bool_idxs = [ @@ -2007,7 +2014,16 @@ def ravel_multidimensional_bool_idx(fgraph, node): new_idxs = list(idxs) new_idxs[bool_idx_pos] = raveled_bool_idx - return [raveled_x[tuple(new_idxs)]] + if isinstance(node.op, AdvancedSubtensor): + new_out = node.op(raveled_x, *new_idxs) + else: + # The dimensions of y that correspond to the boolean indices + # must already be raveled in the original graph, so we don't need to do anything to it + new_out = node.op(raveled_x, y, *new_idxs) + # But we must reshape the output to math the original shape + new_out = new_out.reshape(x_shape) + + return [copy_stack_trace(node.outputs[0], new_out)] @node_rewriter(tracks=[AdvancedSubtensor]) @@ -2024,10 +2040,13 @@ def ravel_multidimensional_int_idx(fgraph, node): x, *idxs = node.inputs if any( - isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool") + ( + (isinstance(idx.type, TensorType) and idx.type.dtype == "bool") + or isinstance(idx.type, NoneTypeT) + ) for idx in idxs ): - # Get out if there are any other advanced indexes + # Get out if there are any other advanced indexes or np.newaxis return None int_idxs = [ @@ -2059,7 +2078,8 @@ def ravel_multidimensional_int_idx(fgraph, node): *int_idx.shape, *raveled_shape[int_idx_pos + 1 :], ) - return [raveled_subtensor.reshape(unraveled_shape)] + new_out = raveled_subtensor.reshape(unraveled_shape) + return [copy_stack_trace(node.outputs[0], new_out)] optdb["specialize"].register( diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index ea3095408b..d63445bf77 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -334,8 +334,19 @@ def test_AdvancedIncSubtensor1(x, y, indices): -np.arange(3), (np.eye(3).astype(bool)), # Boolean index False, - True, - True, + False, + False, + ), + ( + np.arange(3 * 3 * 5).reshape((3, 3, 5)), + rng.poisson(size=(3, 2)), + ( + np.eye(3).astype(bool), + slice(-2, None), + ), # Boolean index, mixed with basic index + False, + False, + False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -394,10 +405,18 @@ def test_AdvancedIncSubtensor1(x, y, indices): rng.poisson(size=(2, 2)), ([[1, 2], [2, 3]]), # matrix indices False, + False, # Gets converted to AdvancedIncSubtensor1 + True, # This is actually supported with the default `ignore_duplicates=False` + ), + ( + np.arange(3 * 5).reshape((3, 5)), + rng.poisson(size=(1, 2, 2)), + (slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index + False, True, True, ), - pytest.param( + ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), rng.poisson(size=(2, 5)), ([1, 1], [2, 2]), # Repeated indices @@ -418,6 +437,9 @@ def test_AdvancedIncSubtensor( inc_requires_objmode, inplace, ): + # Need rewrite to support certain forms of advanced indexing without object mode + mode = numba_mode.including("specialize") + x_pt = pt.as_tensor(x).type("x") y_pt = pt.as_tensor(y).type("y") @@ -432,7 +454,7 @@ def test_AdvancedIncSubtensor( if set_requires_objmode else contextlib.nullcontext() ): - fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y]) + fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) if inplace: # Test updates inplace @@ -452,7 +474,7 @@ def test_AdvancedIncSubtensor( if inc_requires_objmode else contextlib.nullcontext() ): - fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y]) + fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) if inplace: # Test updates inplace x_orig = x.copy()