From 6503181b5f1c72ed3c0f1c19c498710029ff41cd Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Wed, 25 Sep 2024 18:21:34 +0100 Subject: [PATCH] BUG: expand_dims: handle positive/negative duplicates --- src/array_api_extra/_funcs.py | 13 +++++++++---- tests/test_funcs.py | 6 ++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 1b0030b..a67f778 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -65,7 +65,9 @@ def expand_dims( a : array axis : int or tuple of ints Position(s) in the expanded axes where the new axis (or axes) is/are placed. - If multiple positions are provided, they should be unique. + If multiple positions are provided, they should be unique (note that a position + given by a positive index could also be referred to by a negative index - + that will also result in an error). Default: ``(0,)``. xp : array_namespace The standard-compatible namespace for `a`. @@ -114,9 +116,6 @@ def expand_dims( """ if not isinstance(axis, tuple): axis = (axis,) - if len(set(axis)) != len(axis): - err_msg = "Duplicate dimensions specified in `axis`." - raise ValueError(err_msg) ndim = a.ndim + len(axis) if axis != () and (min(axis) < -ndim or max(axis) >= ndim): err_msg = ( @@ -124,6 +123,12 @@ def expand_dims( ) raise IndexError(err_msg) axis = tuple(dim % ndim for dim in axis) + if len(set(axis)) != len(axis): + err_msg = "Duplicate dimensions specified in `axis`." + raise ValueError(err_msg) + if len(set(axis)) != len(axis): + err_msg = "Duplicate dimensions specified in `axis`." + raise ValueError(err_msg) for i in sorted(axis): a = xp.expand_dims(a, axis=i) return a diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 1d3264d..6ed3278 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -182,3 +182,9 @@ def test_repeated_axis(self): a = xp.empty((3, 3, 3)) with pytest.raises(ValueError, match="Duplicate dimensions"): expand_dims(a, axis=(1, 1), xp=xp) + + def test_positive_negative_repeated(self): + # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817 + a = xp.empty((2, 3, 4, 5)) + with pytest.raises(ValueError, match="Duplicate dimensions"): + expand_dims(a, axis=(3, -3), xp=xp)