Skip to content

Commit

Permalink
API: make mean private
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Sep 29, 2024
1 parent 28e6f86 commit a53cfb8
Showing 1 changed file with 3 additions and 40 deletions.
43 changes: 3 additions & 40 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = mean(m, axis=1, xp=xp)
avg = _mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

if fact <= 0:
Expand All @@ -133,7 +133,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
return xp.squeeze(c, axis=axes)


def mean(
def _mean(
x: Array,
/,
*,
Expand All @@ -142,44 +142,7 @@ def mean(
xp: ModuleType,
) -> Array:
"""
Calculates the arithmetic mean of the input array ``x``.
In addition to the standard ``mean``, this function supports complex-valued input.
Parameters
----------
x: array
input array. Should have a floating-point data type.
axis: int or tuple of ints, optional
axis or axes along which arithmetic means must be computed.
By default, the mean must be computed over the entire array.
If a tuple of integers, arithmetic means must be computed over multiple axes.
Default: ``None``.
keepdims: bool, optional
if ``True``, the reduced axes (dimensions) must be included in the result as
singleton dimensions, and, accordingly, the result must be compatible with
the input array (see :ref:`broadcasting`).
Otherwise, if ``False``, the reduced axes (dimensions) must not be included
in the result. Default: ``False``.
Returns
-------
out: array
if the arithmetic mean was computed over the entire array,
a zero-dimensional array containing the arithmetic mean;
otherwise, a non-zero-dimensional array containing the arithmetic means.
The returned array must have the same data type as ``x``.
Notes
-----
**Special Cases**
Let ``N`` equal the number of elements over which to compute the arithmetic mean.
- If ``N`` is ``0``, the arithmetic mean is ``NaN``.
- If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN``
(i.e., ``NaN`` values propagate).
Complex mean, https://github.com/data-apis/array-api/issues/846.
"""
if xp.isdtype(x.dtype, "complex floating"):
x_real = xp.real(x)
Expand Down

0 comments on commit a53cfb8

Please sign in to comment.