Skip to content

Commit

Permalink
Support multidimensional boolean set/inc_subtensor in Numba via rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 28, 2024
1 parent 9dad122 commit f94a44c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
38 changes: 29 additions & 9 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
This is only done when there's a single vector index.
"""

if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates:
if node.op.ignore_duplicates:
# `AdvancedIncSubtensor1` does not ignore duplicate index values
return

Expand Down Expand Up @@ -1967,19 +1967,26 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
return new_out


@node_rewriter(tracks=[AdvancedSubtensor])
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
x[eye(3, dtype=bool)].set(1) -> x.ravel()[eye(3).ravel()].set(1).reshape(x.shape)
"""
x, *idxs = node.inputs
if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs
else:
x, y, *idxs = node.inputs

if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")
(
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int"))
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indexes
# Get out if there are any other advanced indexes or np.newaxis
return None

bool_idxs = [
Expand Down Expand Up @@ -2007,7 +2014,16 @@ def ravel_multidimensional_bool_idx(fgraph, node):
new_idxs = list(idxs)
new_idxs[bool_idx_pos] = raveled_bool_idx

return [raveled_x[tuple(new_idxs)]]
if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs)
else:
# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = new_out.reshape(x_shape)

return [copy_stack_trace(node.outputs[0], new_out)]


@node_rewriter(tracks=[AdvancedSubtensor])
Expand All @@ -2024,10 +2040,13 @@ def ravel_multidimensional_int_idx(fgraph, node):
x, *idxs = node.inputs

if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool")
(
(isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indexes
# Get out if there are any other advanced indexes or np.newaxis
return None

int_idxs = [
Expand Down Expand Up @@ -2059,7 +2078,8 @@ def ravel_multidimensional_int_idx(fgraph, node):
*int_idx.shape,
*raveled_shape[int_idx_pos + 1 :],
)
return [raveled_subtensor.reshape(unraveled_shape)]
new_out = raveled_subtensor.reshape(unraveled_shape)
return [copy_stack_trace(node.outputs[0], new_out)]


optdb["specialize"].register(
Expand Down
4 changes: 2 additions & 2 deletions pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,10 +1456,10 @@ def inc_subtensor(
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether or not ``x[indices] += y`` is used or
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``. When the special duplicates handling of
``np.add.at`` isn't required, setting this option to ``True``
(i.e. using ``x[indices] += y``) can resulting in faster compiled
(i.e. using ``x[indices] += y``) can result in faster compiled
graphs.
Examples
Expand Down
32 changes: 27 additions & 5 deletions tests/link/numba/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,19 @@ def test_AdvancedIncSubtensor1(x, y, indices):
-np.arange(3),
(np.eye(3).astype(bool)), # Boolean index
False,
True,
True,
False,
False,
),
(
np.arange(3 * 3 * 5).reshape((3, 3, 5)),
rng.poisson(size=(3, 2)),
(
np.eye(3).astype(bool),
slice(-2, None),
), # Boolean index, mixed with basic index
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
Expand Down Expand Up @@ -394,10 +405,18 @@ def test_AdvancedIncSubtensor1(x, y, indices):
rng.poisson(size=(2, 2)),
([[1, 2], [2, 3]]), # matrix indices
False,
False, # Gets converted to AdvancedIncSubtensor1
True, # This is actually supported with the default `ignore_duplicates=False`
),
(
np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)),
(slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index
False,
True,
True,
),
pytest.param(
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 5)),
([1, 1], [2, 2]), # Repeated indices
Expand All @@ -418,6 +437,9 @@ def test_AdvancedIncSubtensor(
inc_requires_objmode,
inplace,
):
# Need rewrite to support certain forms of advanced indexing without object mode
mode = numba_mode.including("specialize")

x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y")

Expand All @@ -432,7 +454,7 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)

if inplace:
# Test updates inplace
Expand All @@ -452,7 +474,7 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
if inplace:
# Test updates inplace
x_orig = x.copy()
Expand Down

0 comments on commit f94a44c

Please sign in to comment.