Skip to content

Commit

Permalink
stats.utils: Don't count zeros as nans
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jan 26, 2023
1 parent a25a9d2 commit 7af9325
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 31 deletions.
Empty file.
48 changes: 48 additions & 0 deletions Orange/statistics/unittests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import unittest

import numpy as np
import scipy as sp

from Orange.statistics.util import stats


class TestStats(unittest.TestCase):
    """Exercise `stats` on dense, CSC and CSR inputs, with and without weights."""

    def test_stats(self):
        """Explicit zeros are regular values; only nans count as missing.

        Every row of an expected result describes one column of the data as
        (min, max, mean, variance, number of nans, number of non-nans).
        """
        data = np.array([
            [1, 4, 9],
            [-2, 10, 0],
            [0, np.nan, np.nan],
            [0, np.nan, 0],
        ])

        def layouts():
            # Lazily produce the same data as dense, CSC and CSR, so that each
            # sparse matrix is only built right before it is checked.
            yield data
            csc = sp.sparse.csc_matrix(data)
            yield csc
            yield csc.tocsr()

        # --- unweighted -----------------------------------------------------
        var0 = (1.25 ** 2 + 1.75 ** 2 + .25 ** 2 + .25 ** 2) / 4
        var1 = 3 ** 2
        var2 = (6 ** 2 + 3 ** 2 + 3 ** 2) / 3
        expected = [
            [-2, 1, -0.25, var0, 0, 4],
            [4, 10, 7, var1, 2, 2],
            [0, 9, 3, var2, 1, 3],
        ]
        for matrix in layouts():
            np.testing.assert_almost_equal(
                stats(matrix, compute_variance=True), expected)

        # --- weighted -------------------------------------------------------
        # The third row has weight 0; nan entries drop out of the sums, which
        # is why the second column is normalised by 3 rather than by 6.
        weights = np.array([1, 2, 0, 3])
        mean0 = (1 * 1 - 2 * 2 + 0 * 0 + 3 * 0) / (1 + 2 + 0 + 3)
        mean1 = (1 * 4 + 2 * 10) / 3
        mean2 = (1 * 9 + 2 * 0 + 3 * 0) / 6
        wvar0 = ((mean0 - 1) ** 2 + 2 * (mean0 + 2) ** 2 + 3 * mean0 ** 2) / 6
        wvar1 = ((mean1 - 4) ** 2 + 2 * (mean1 - 10) ** 2) / 3
        wvar2 = ((mean2 - 9) ** 2 + 2 * mean2 ** 2 + 3 * mean2 ** 2) / 6
        expected = [
            [-2, 1, mean0, wvar0, 0, 4],
            [4, 10, mean1, wvar1, 2, 2],
            [0, 9, mean2, wvar2, 1, 3],
        ]
        for matrix in layouts():
            np.testing.assert_almost_equal(
                stats(matrix, weights=weights, compute_variance=True),
                expected)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
19 changes: 5 additions & 14 deletions Orange/statistics/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,31 +344,22 @@ def stats(X, weights=None, compute_variance=False):

if X.size and is_numeric:
if is_sparse:
nans = countnans(X, axis=0)
X = X.tocsc()
else:
nans = np.isnan(X).sum(axis=0)
if compute_variance:
means, vars = nan_mean_var(X, axis=0, weights=weights)
else:
means = nanmean(X, axis=0, weights=weights)
vars = np.zeros(X.shape[1] if X.ndim == 2 else 1)

if X.size and is_numeric and not is_sparse:
nans = np.isnan(X).sum(axis=0)
return np.column_stack((
np.nanmin(X, axis=0),
np.nanmax(X, axis=0),
means,
vars,
nans,
X.shape[0] - nans))
elif is_sparse and X.size:
non_zero = np.bincount(X.nonzero()[1], minlength=X.shape[1])
return np.column_stack((
nanmin(X, axis=0),
nanmax(X, axis=0),
means,
vars,
X.shape[0] - non_zero,
non_zero))
nans,
X.shape[0] - nans))
else:
if X.ndim == 1:
X = X[:, None]
Expand Down
34 changes: 17 additions & 17 deletions Orange/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,26 @@ def test_stats(self):

def test_stats_sparse(self):
X = csr_matrix(np.identity(5))
np.testing.assert_equal(stats(X), [[0, 1, .2, 0, 4, 1],
[0, 1, .2, 0, 4, 1],
[0, 1, .2, 0, 4, 1],
[0, 1, .2, 0, 4, 1],
[0, 1, .2, 0, 4, 1]])
np.testing.assert_equal(stats(X), [[0, 1, .2, 0, 0, 5],
[0, 1, .2, 0, 0, 5],
[0, 1, .2, 0, 0, 5],
[0, 1, .2, 0, 0, 5],
[0, 1, .2, 0, 0, 5]])

# assure last two columns have just zero elements
X = X[:3]
np.testing.assert_equal(stats(X), [[0, 1, 1/3, 0, 2, 1],
[0, 1, 1/3, 0, 2, 1],
[0, 1, 1/3, 0, 2, 1],
[0, 0, 0, 0, 3, 0],
[0, 0, 0, 0, 3, 0]])
np.testing.assert_equal(stats(X), [[0, 1, 1/3, 0, 0, 3],
[0, 1, 1/3, 0, 0, 3],
[0, 1, 1/3, 0, 0, 3],
[0, 0, 0, 0, 0, 3],
[0, 0, 0, 0, 0, 3]])

r = stats(X, compute_variance=True)
np.testing.assert_almost_equal(r, [[0, 1, 1/3, 2/9, 2, 1],
[0, 1, 1/3, 2/9, 2, 1],
[0, 1, 1/3, 2/9, 2, 1],
[0, 0, 0, 0, 3, 0],
[0, 0, 0, 0, 3, 0]])
np.testing.assert_almost_equal(r, [[0, 1, 1/3, 2/9, 0, 3],
[0, 1, 1/3, 2/9, 0, 3],
[0, 1, 1/3, 2/9, 0, 3],
[0, 0, 0, 0, 0, 3],
[0, 0, 0, 0, 0, 3]])

def test_stats_weights(self):
X = np.arange(4).reshape(2, 2).astype(float)
Expand All @@ -152,11 +152,11 @@ def test_stats_weights_sparse(self):
X = np.arange(4).reshape(2, 2).astype(float)
X = csr_matrix(X)
weights = np.array([1, 3])
np.testing.assert_equal(stats(X, weights), [[0, 2, 1.5, 0, 1, 1],
np.testing.assert_equal(stats(X, weights), [[0, 2, 1.5, 0, 0, 2],
[1, 3, 2.5, 0, 0, 2]])

np.testing.assert_equal(stats(X, weights, compute_variance=True),
[[0, 2, 1.5, 0.75, 1, 1],
[[0, 2, 1.5, 0.75, 0, 2],
[1, 3, 2.5, 0.75, 0, 2]])

def test_stats_non_numeric(self):
Expand Down

0 comments on commit 7af9325

Please sign in to comment.