From a53cfb81f840888dbef2a8df2b6d6dd1b2ec9cf1 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 29 Sep 2024 15:34:32 +0100 Subject: [PATCH] API: make `mean` private --- src/array_api_extra/_funcs.py | 43 +++-------------------------------- 1 file changed, 3 insertions(+), 40 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index c1268e9..917e46c 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -116,7 +116,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - avg = mean(m, axis=1, xp=xp) + avg = _mean(m, axis=1, xp=xp) fact = m.shape[1] - 1 if fact <= 0: @@ -133,7 +133,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: return xp.squeeze(c, axis=axes) -def mean( +def _mean( x: Array, /, *, @@ -142,44 +142,7 @@ def mean( xp: ModuleType, ) -> Array: """ - Calculates the arithmetic mean of the input array ``x``. - - In addition to the standard ``mean``, this function supports complex-valued input. - - Parameters - ---------- - x: array - input array. Should have a floating-point data type. - axis: int or tuple of ints, optional - axis or axes along which arithmetic means must be computed. - By default, the mean must be computed over the entire array. - If a tuple of integers, arithmetic means must be computed over multiple axes. - Default: ``None``. - keepdims: bool, optional - if ``True``, the reduced axes (dimensions) must be included in the result as - singleton dimensions, and, accordingly, the result must be compatible with - the input array (see :ref:`broadcasting`). - Otherwise, if ``False``, the reduced axes (dimensions) must not be included - in the result. Default: ``False``. - - Returns - ------- - out: array - if the arithmetic mean was computed over the entire array, - a zero-dimensional array containing the arithmetic mean; - otherwise, a non-zero-dimensional array containing the arithmetic means. - The returned array must have the same data type as ``x``. - - Notes - ----- - - **Special Cases** - - Let ``N`` equal the number of elements over which to compute the arithmetic mean. - - - If ``N`` is ``0``, the arithmetic mean is ``NaN``. - - If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN`` - (i.e., ``NaN`` values propagate). + Complex mean, https://github.com/data-apis/array-api/issues/846. """ if xp.isdtype(x.dtype, "complex floating"): x_real = xp.real(x)