Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multidimensional boolean set/inc_subtensor in Numba via rewrite #1108

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
This is only done when there's a single vector index.
"""

if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates:
if node.op.ignore_duplicates:
# `AdvancedIncSubtensor1` does not ignore duplicate index values
return

Expand Down Expand Up @@ -1967,19 +1967,26 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
return new_out


@node_rewriter(tracks=[AdvancedSubtensor])
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba

x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
"""
x, *idxs = node.inputs
if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs
else:
x, y, *idxs = node.inputs

if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")
(
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int"))
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indexes
# Get out if there are any other advanced indexes or np.newaxis
return None

bool_idxs = [
Expand Down Expand Up @@ -2007,7 +2014,16 @@ def ravel_multidimensional_bool_idx(fgraph, node):
new_idxs = list(idxs)
new_idxs[bool_idx_pos] = raveled_bool_idx

return [raveled_x[tuple(new_idxs)]]
if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs)
else:
# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = new_out.reshape(x_shape)

return [copy_stack_trace(node.outputs[0], new_out)]


@node_rewriter(tracks=[AdvancedSubtensor])
Expand All @@ -2024,10 +2040,13 @@ def ravel_multidimensional_int_idx(fgraph, node):
x, *idxs = node.inputs

if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool")
(
(isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indexes
# Get out if there are any other advanced indexes or np.newaxis
return None

int_idxs = [
Expand Down Expand Up @@ -2059,7 +2078,8 @@ def ravel_multidimensional_int_idx(fgraph, node):
*int_idx.shape,
*raveled_shape[int_idx_pos + 1 :],
)
return [raveled_subtensor.reshape(unraveled_shape)]
new_out = raveled_subtensor.reshape(unraveled_shape)
return [copy_stack_trace(node.outputs[0], new_out)]


optdb["specialize"].register(
Expand Down
7 changes: 2 additions & 5 deletions pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,11 +1456,8 @@ def inc_subtensor(
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether or not ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``. When the special duplicates handling of
``np.add.at`` isn't required, setting this option to ``True``
(i.e. using ``x[indices] += y``) can resulting in faster compiled
graphs.
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``.

Examples
--------
Expand Down
32 changes: 27 additions & 5 deletions tests/link/numba/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,19 @@ def test_AdvancedIncSubtensor1(x, y, indices):
-np.arange(3),
(np.eye(3).astype(bool)), # Boolean index
False,
True,
True,
False,
False,
),
(
np.arange(3 * 3 * 5).reshape((3, 3, 5)),
rng.poisson(size=(3, 2)),
(
np.eye(3).astype(bool),
slice(-2, None),
), # Boolean index, mixed with basic index
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
Expand Down Expand Up @@ -394,10 +405,18 @@ def test_AdvancedIncSubtensor1(x, y, indices):
rng.poisson(size=(2, 2)),
([[1, 2], [2, 3]]), # matrix indices
False,
False, # Gets converted to AdvancedIncSubtensor1
True, # This is actually supported with the default `ignore_duplicates=False`
),
(
np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)),
(slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index
False,
True,
True,
),
pytest.param(
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 5)),
([1, 1], [2, 2]), # Repeated indices
Expand All @@ -418,6 +437,9 @@ def test_AdvancedIncSubtensor(
inc_requires_objmode,
inplace,
):
# Need rewrite to support certain forms of advanced indexing without object mode
mode = numba_mode.including("specialize")

x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y")

Expand All @@ -432,7 +454,7 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)

if inplace:
# Test updates inplace
Expand All @@ -452,7 +474,7 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
if inplace:
# Test updates inplace
x_orig = x.copy()
Expand Down
Loading