Skip to content

Commit

Permalink
ENH: add setdiff1d
Browse files Browse the repository at this point in the history
Co-authored-by: Omar Salman <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
  • Loading branch information
3 people committed Nov 23, 2024
1 parent f8a2a90 commit d007f4e
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 4 deletions.
165 changes: 165 additions & 0 deletions src/array_api_extra/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
### Helpers borrowed from array-api-compat

from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import inspect
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ._typing import Array, Device

__all__ = ["device"]


# Placeholder object to represent the dask device
# when the array backend is not the CPU.
# (since it is not easy to tell which device a dask array is on)
class _dask_device: # pylint: disable=invalid-name
def __repr__(self) -> str:
return "DASK_DEVICE"


_DASK_DEVICE = _dask_device()


# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: Array, /) -> Device:
    """
    Return the hardware device on which the data of `x` resides.

    Mirrors `x.device` as defined by the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__,
    for array libraries that either lack the `device` attribute or expose
    it with an incompatible API.

    Parameters
    ----------
    x: array
        array instance from an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    Notes
    -----
    NumPy arrays always report `"cpu"`. Dask arrays report `"cpu"` when
    their metadata is a NumPy array and the special `DASK_DEVICE` object
    otherwise.

    See Also
    --------
    to_device : Move array data to a different device.
    """
    if _is_numpy_array(x):
        return "cpu"

    if _is_dask_array(x):
        # Inspect the metadata of the dask array to learn which backend
        # holds its chunks.
        try:
            import numpy as np  # pylint: disable=import-outside-toplevel

            # pylint: disable-next=protected-access
            numpy_backed = isinstance(x._meta, np.ndarray)
        except ImportError:
            numpy_backed = False
        # NumPy-backed chunks can only live on the CPU; for any other
        # backend the device cannot be determined.
        return "cpu" if numpy_backed else _DASK_DEVICE

    if _is_jax_array(x):
        # JAX is turning `.device()` (a method) into `.device` (a property)
        # to follow the standard; support both spellings so that the
        # switch does not break this helper.
        dev = x.device
        return dev() if inspect.ismethod(dev) else dev

    if _is_pydata_sparse_array(x):
        # Prefer a native `.device` once `sparse` provides one.
        native = getattr(x, "device", None)
        if native is not None:
            return native
        # Every format except DOK stores its values in a `.data` array.
        try:
            payload = x.data
        except AttributeError:
            return "cpu"
        # Report the device of the constituent array.
        return device(payload)

    return x.device


def _is_numpy_array(x: Array) -> bool:
"""Return True if `x` is a NumPy array."""
# Avoid importing NumPy if it isn't already
if "numpy" not in sys.modules:
return False

import numpy as np # pylint: disable=import-outside-toplevel

# TODO: Should we reject ndarray subclasses?
return isinstance(x, (np.ndarray, np.generic)) and not _is_jax_zero_gradient_array(
x
)


def _is_dask_array(x: Array) -> bool:
"""Return True if `x` is a dask.array Array."""
# Avoid importing dask if it isn't already
if "dask.array" not in sys.modules:
return False

# pylint: disable=import-error, import-outside-toplevel
import dask.array # type: ignore[import-not-found]

return isinstance(x, dask.array.Array)


def _is_jax_zero_gradient_array(x: Array) -> bool:
"""Return True if `x` is a zero-gradient array.
These arrays are a design quirk of Jax that may one day be removed.
See https://github.com/google/jax/issues/20620.
"""
if "numpy" not in sys.modules or "jax" not in sys.modules:
return False

# pylint: disable=import-error, import-outside-toplevel
import jax # type: ignore[import-not-found]
import numpy as np # pylint: disable=import-outside-toplevel

return isinstance(x, np.ndarray) and x.dtype == jax.float0


def _is_jax_array(x: Array) -> bool:
"""Return True if `x` is a JAX array."""
# Avoid importing jax if it isn't already
if "jax" not in sys.modules:
return False

# pylint: disable=import-error, import-outside-toplevel
import jax

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def _is_pydata_sparse_array(x: Array) -> bool:
"""Return True if `x` is an array from the `sparse` package."""

# Avoid importing jax if it isn't already
if "sparse" not in sys.modules:
return False

# pylint: disable=import-error, import-outside-toplevel
import sparse # type: ignore[import-not-found]

# TODO: Account for other backends.
return isinstance(x, sparse.SparseArray)
30 changes: 28 additions & 2 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from __future__ import annotations
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import warnings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]
from . import _utils

__all__ = [
"atleast_nd",
"cov",
"create_diagonal",
"expand_dims",
"kron",
"setdiff1d",
"sinc",
]


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
Expand Down Expand Up @@ -399,6 +409,22 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))


def setdiff1d(
    x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType
) -> Array:
    """
    Find the set difference of two arrays.

    Return the unique values in `x1` that are not in `x2`.

    Parameters
    ----------
    x1 : array
        Input array.
    x2 : array
        Input comparison array.
    assume_unique : bool
        If ``True``, both inputs are taken to contain no duplicates and are
        only flattened, skipping the deduplication step.
        Default is ``False``.
    xp : array_namespace
        The standard-compatible namespace for `x1` and `x2`.

    Returns
    -------
    res : array
        1-D array of the values of `x1` that do not occur in `x2`. When
        `assume_unique` is ``False`` this is the output of
        ``xp.unique_values(x1)`` filtered by membership; otherwise it is the
        flattened `x1` filtered by membership.
    """
    if assume_unique:
        x1 = xp.reshape(x1, (-1,))
        # `in1d` reads `x2.shape[0]`, iterates over `x2` and concatenates it
        # with the 1-D `x1`, so `x2` has to be flattened as well.
        x2 = xp.reshape(x2, (-1,))
    else:
        x1 = xp.unique_values(x1)
        x2 = xp.unique_values(x2)
    # Both operands are 1-D and duplicate-free (or assumed to be) at this
    # point, hence `assume_unique=True`; `invert=True` keeps the elements
    # of `x1` that are NOT found in `x2`.
    return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType) -> Array:
r"""
Return the normalized sinc function.
Expand Down
5 changes: 3 additions & 2 deletions src/array_api_extra/_typing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

from types import ModuleType
from typing import Any

Array = Any # To be changed to a Protocol later (see array-api#589)
Device = Any

__all__ = ["Array", "ModuleType"]
__all__ = ["Array", "Device", "ModuleType"]
63 changes: 63 additions & 0 deletions src/array_api_extra/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ._typing import Array, ModuleType

from . import _compat

__all__ = ["in1d"]


def in1d(
    x1: Array,
    x2: Array,
    /,
    *,
    assume_unique: bool = False,
    invert: bool = False,
    xp: ModuleType,
) -> Array:
    """
    Check whether each element of an array is also present in a second array.

    Returns a boolean array the same length as `x1` that is True
    where an element of `x1` is in `x2` and False otherwise.

    This function has been adapted using the original implementation
    present in numpy:
    https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758

    Parameters
    ----------
    x1 : array
        1-D array whose elements are tested for membership.
    x2 : array
        1-D array of values to test against.
    assume_unique : bool
        If ``True``, the sort-based path skips deduplicating the inputs.
        Default is ``False``.
    invert : bool
        If ``True``, the result is negated: True where an element of `x1`
        is *not* in `x2`. Default is ``False``.
    xp : array_namespace
        The standard-compatible namespace for `x1` and `x2`.

    Returns
    -------
    res : array
        Boolean array with one entry per element of `x1`.
    """
    n1 = x1.shape[0]

    # This code is run to make the code significantly faster: when `x2` is
    # small relative to `x1`, comparing `x1` against each element of `x2`
    # in turn is used instead of the sort-based algorithm below.
    if x2.shape[0] < 10 * n1**0.145:
        # Go through the compat helper (as the sort-based path does) rather
        # than reading `x1.device`, which not every array type provides.
        device_ = _compat.device(x1)
        if invert:
            mask = xp.ones(n1, dtype=xp.bool, device=device_)
            for a in x2:
                mask &= x1 != a
        else:
            mask = xp.zeros(n1, dtype=xp.bool, device=device_)
            for a in x2:
                mask |= x1 == a
        return mask

    if not assume_unique:
        # `rev_idx` maps every original element of `x1` to its position in
        # the deduplicated `x1`, so the result can be expanded back later.
        x1, rev_idx = xp.unique_inverse(x1)
        x2 = xp.unique_values(x2)

    ar = xp.concat((x1, x2))
    device_ = _compat.device(ar)
    # We need this to be a stable sort: for a value present in both inputs,
    # the copy coming from `x1` must end up directly before the copy coming
    # from `x2`.
    order = xp.argsort(ar, stable=True)
    # Inverse permutation of `order`, used to undo the sort below.
    reverse_order = xp.argsort(order, stable=True)
    sar = xp.take(ar, order, axis=0)
    # With duplicate-free inputs, an element equal to its successor in the
    # sorted concatenation occurs in both `x1` and `x2`.
    bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
    # The last sorted element has no successor to match, so it is flagged
    # as "not found", which is `invert` in the output convention.
    flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
    ret = xp.take(flag, reverse_order, axis=0)

    if assume_unique:
        # Only the leading entries belong to `x1`; the rest belong to `x2`.
        return ret[: x1.shape[0]]
    # https://github.com/pylint-dev/pylint/issues/10095
    # pylint: disable=possibly-used-before-assignment
    return xp.take(ret, rev_idx, axis=0)

0 comments on commit d007f4e

Please sign in to comment.