diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index f3b62038..97c584be 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -5,7 +5,11 @@ from typing import Literal, Optional, Sequence, Tuple, Union from ._typing import ndarray -from numpy.core.numeric import normalize_axis_tuple +import numpy as np +if np.__version__[0] == "2": + from numpy.lib.array_utils import normalize_axis_tuple +else: + from numpy.core.numeric import normalize_axis_tuple from ._aliases import matmul, matrix_transpose, tensordot, vecdot from .._internal import get_xp