Skip to content

Commit

Permalink
BUG: expand_dims: handle positive/negative duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Sep 25, 2024
1 parent 0a0bed1 commit 6503181
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def expand_dims(
a : array
axis : int or tuple of ints
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
If multiple positions are provided, they should be unique.
If multiple positions are provided, they should be unique (note that a position
given by a positive index could also be referred to by a negative index -
that will also result in an error).
Default: ``(0,)``.
xp : array_namespace
The standard-compatible namespace for `a`.
Expand Down Expand Up @@ -114,16 +116,19 @@ def expand_dims(
"""
if not isinstance(axis, tuple):
axis = (axis,)
if len(set(axis)) != len(axis):
err_msg = "Duplicate dimensions specified in `axis`."
raise ValueError(err_msg)
ndim = a.ndim + len(axis)
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
err_msg = (
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
)
raise IndexError(err_msg)
axis = tuple(dim % ndim for dim in axis)
if len(set(axis)) != len(axis):
err_msg = "Duplicate dimensions specified in `axis`."
raise ValueError(err_msg)
if len(set(axis)) != len(axis):
err_msg = "Duplicate dimensions specified in `axis`."
raise ValueError(err_msg)
for i in sorted(axis):
a = xp.expand_dims(a, axis=i)
return a
Expand Down
6 changes: 6 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,9 @@ def test_repeated_axis(self):
a = xp.empty((3, 3, 3))
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(1, 1), xp=xp)

def test_positive_negative_repeated(self):
# https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
a = xp.empty((2, 3, 4, 5))
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(3, -3), xp=xp)

0 comments on commit 6503181

Please sign in to comment.