Commit

Add more in-place sparse ops
asistradition committed Apr 12, 2024
1 parent 72bb236 commit ca9fdfd
Showing 6 changed files with 475 additions and 151 deletions.
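
The changes below add `axis` and `squared` keywords to the reduction helpers and route sparse inputs through a new `sparse_math` module (only two of the six changed files are shown here). As a rough, standalone illustration of the behaviour the reworked `array_sum(array, axis=None, squared=False)` is presumably meant to have, computed against plain numpy/scipy rather than this package's own code (the commented-out import path is inferred from the file names in this diff and may not be exact):

import numpy as np
import scipy.sparse as sps

# Toy CSR matrix standing in for a counts matrix
X = sps.csr_matrix(np.arange(12, dtype=float).reshape(4, 3))

# Squared row sums computed with plain scipy, as a dense-equivalent reference
reference = np.asarray(X.power(2).sum(axis=1)).ravel()

# The refactored helper should reproduce this without densifying X
# (signature taken from this diff; the import path is an assumption)
# from inferelator_velocity.utils.math import array_sum
# assert np.allclose(array_sum(X, axis=1, squared=True), reference)
print(reference)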
125 changes: 124 additions & 1 deletion inferelator_velocity/tests/test_mcv.py
@@ -1,14 +1,44 @@
import unittest
import numpy as np
import numpy.testing as npt
import scipy.sparse as sps
import anndata as ad

from inferelator_velocity.utils.mcv import mcv_pcs
from inferelator_velocity.utils import TruncRobustScaler
from inferelator_velocity.utils.mcv import (
    mcv_pcs,
    mcv_mse,
    standardize_data
)
from inferelator_velocity.utils.misc import (
    _normalize_for_pca
)

from ._stubs import (
    COUNTS
)


def _safe_sum(x, axis):

    sums = x.sum(axis)

    try:
        sums = sums.A1
    except AttributeError:
        pass

    return sums


def _safe_dense(x):

    try:
        return x.A
    except AttributeError:
        return x


class TestMCV(unittest.TestCase):

    def test_sparse_log(self):
@@ -29,3 +59,96 @@ def test_sparse_log_scale(self):
            ),
            0
        )


class TestMCVMetrics(unittest.TestCase):

    def testMSErow(self):
        data = sps.csr_matrix(COUNTS)

        mse = mcv_mse(
            data,
            data @ np.zeros((10, 10)),
            np.eye(10),
            by_row=True
        )

        npt.assert_almost_equal(
            data.power(2).sum(axis=1).A1 / 10,
            mse
        )


class TestMCVStandardization(unittest.TestCase):

    tol = 6

    def setUp(self):
        super().setUp()
        self.data = ad.AnnData(sps.csr_matrix(COUNTS))

    def test_depth(self):

        _normalize_for_pca(self.data, target_sum=100, log=False, scale=False)

        rowsums = _safe_sum(self.data.X, 1)

        npt.assert_almost_equal(
            rowsums,
            np.full_like(rowsums, 100.),
            decimal=self.tol
        )

    def test_depth_log(self):

        _normalize_for_pca(self.data, target_sum=100, log=True, scale=False)

        rowsums = _safe_sum(self.data.X, 1)

        npt.assert_almost_equal(
            rowsums,
            np.log1p(100 * COUNTS / np.sum(COUNTS, axis=1)[:, None]).sum(1),
            decimal=self.tol
        )

    def test_depth_scale(self):

        _normalize_for_pca(self.data, target_sum=100, log=False, scale=True)

        npt.assert_almost_equal(
            _safe_dense(self.data.X),
            TruncRobustScaler(with_centering=False).fit_transform(
                100 * COUNTS / np.sum(COUNTS, axis=1)[:, None]
            ),
            decimal=self.tol
        )

    def test_depth_log_scale(self):

        _normalize_for_pca(self.data, target_sum=100, log=True, scale=True)

        npt.assert_almost_equal(
            _safe_dense(self.data.X),
            TruncRobustScaler(with_centering=False).fit_transform(
                np.log1p(100 * COUNTS / np.sum(COUNTS, axis=1)[:, None])
            ),
            decimal=self.tol
        )


class TestMCVStandardizationDense(TestMCVStandardization):

    tol = 4

    def setUp(self):
        super().setUp()
        self.data = ad.AnnData(COUNTS.copy())


class TestMCVStandardizationCSC(TestMCVStandardization):

    tol = 4

    def setUp(self):
        super().setUp()
        self.data = ad.AnnData(sps.csc_matrix(COUNTS))
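
The new standardization tests above check `_normalize_for_pca` against a reference built by hand: each row is scaled to a fixed depth (`target_sum=100`) and then optionally `log1p`-transformed before scaling. A standalone numpy sketch of that reference value, matching what `test_depth_log` constructs rather than the package's `_normalize_for_pca` itself:

import numpy as np

# Toy counts matrix standing in for the COUNTS stub
counts = np.array([[1., 3., 6.], [2., 2., 4.]])

# Scale each cell (row) to a total depth of 100, then log1p, as in test_depth_log
depth = counts.sum(axis=1, keepdims=True)
expected = np.log1p(100 * counts / depth)

# test_depth_log compares per-row sums of this matrix to the normalized AnnData.X
print(expected.sum(axis=1))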
153 changes: 47 additions & 106 deletions inferelator_velocity/utils/math.py
@@ -1,5 +1,4 @@
import numpy as np
import numba
import scipy.sparse as sps
from .misc import (
    make_vector_2D,
@@ -140,7 +139,7 @@ def _calc_se(x, y, slope):
    return se_y / mse_x


def _log_loss(x, y):
def _log_loss(x, y, axis=1):

    if y is None:
        raise ValueError(
@@ -168,12 +167,12 @@ def _log_loss(x, y):
        x,
        np.log(y)
    )
    err = err.sum(axis=1)
    err = err.sum(axis=axis)
    err *= -1
    return err


def _mse(x, y):
def _mse(x, y, axis=1):

    if y is not None:
        ssr = x - y
@@ -188,24 +187,25 @@
    else:
        ssr **= 2

    return ssr.sum(axis=1)
    return ssr.sum(axis=axis)


def _mae(x, y):
def _mae(x, y, axis=1):

    if y is not None:
        ssr = x - y
    else:
        ssr = x

    return ssr.sum(axis=1)
    return ssr.sum(axis=axis)


def pairwise_metric(
    x,
    y,
    metric='mse',
    by_row=False,
    axis=1,
    **kwargs
):
    """
@@ -233,6 +233,7 @@ def pairwise_metric(
    loss = metric(
        x,
        y,
        axis=axis,
        **kwargs
    )

@@ -241,7 +242,7 @@
    except AttributeError:
        pass

    loss = loss / x.shape[1]
    loss = loss / x.shape[axis]

    if by_row:
        return loss
@@ -343,77 +344,50 @@ def coefficient_of_variation(
    )


@numba.njit(parallel=False)
def _csr_row_divide(
    a_data,
    a_indptr,
    row_vec
):

    n_row = row_vec.shape[0]

    for i in numba.prange(n_row):
        a_data[a_indptr[i]:a_indptr[i + 1]] /= row_vec[i]


@numba.njit(parallel=False)
def _mse_rowwise(
    a_data,
    a_indices,
    a_indptr,
    b_pcs,
    b_rotation
):

    n_row = b_pcs.shape[0]

    output = np.zeros(n_row, dtype=float)

    for i in numba.prange(n_row):

        _idx_a = a_indices[a_indptr[i]:a_indptr[i + 1]]
        _nnz_a = _idx_a.shape[0]

        row = b_pcs[i, :] @ b_rotation

        if _nnz_a == 0:
            continue

        row[_idx_a] -= a_data[a_indptr[i]:a_indptr[i + 1]]

        output[i] = np.mean(row ** 2)

    return output


@numba.njit(parallel=False)
def _sum_columns(data, indices, n_col):

    output = np.zeros(n_col, dtype=data.dtype)

    for i in numba.prange(data.shape[0]):
        output[indices[i]] += data[i]

    return output


@numba.njit(parallel=False)
def _sum_rows(data, indptr):

    output = np.zeros(indptr.shape[0] - 1, dtype=data.dtype)

    for i in numba.prange(output.shape[0]):
        output[i] = np.sum(data[indptr[i]:indptr[i + 1]])

    return output


def mcv_mse(
    x,
    pc,
    rotation,
    by_row=False,
    axis=1,
    **metric_kwargs
):

    if sps.issparse(x):

        from .sparse_math import mcv_mse_sparse

        return mcv_mse_sparse(
            x,
            pc,
            rotation,
            by_row=by_row,
            axis=axis,
            **metric_kwargs
        )

    else:

        return pairwise_metric(
            x,
            pc @ rotation,
            metric='mse',
            by_row=by_row,
            axis=axis,
            **metric_kwargs
        )


def array_sum(array, axis=None):
def array_sum(array, axis=None, squared=False):

    if not is_csr(array):
        _sums = array.sum(axis=axis)

        if squared and not sps.issparse(array):
            _sums = (array ** 2).sum(axis=axis)
        elif squared:
            _sums = array.power(2).sum(axis=axis)
        else:
            _sums = array.sum(axis=axis)
        try:
            _sums = _sums.A1
        except AttributeError:
@@ -423,44 +397,11 @@ def array_sum(array, axis=None):
    if axis is None:
        return np.sum(array.data)

    elif axis == 0:
        return _sum_columns(
            array.data,
            array.indices,
            array.shape[1]
        )

    elif axis == 1:
        return _sum_rows(
            array.data,
            array.indptr
        )


def mcv_mse(x, pc, rotation, by_row=False, **metric_kwargs):

    if sps.issparse(x):

        y = _mse_rowwise(
            x.data,
            x.indices,
            x.indptr,
            np.ascontiguousarray(pc),
            np.ascontiguousarray(rotation, dtype=pc.dtype)
        )

        if by_row:
            return y

        else:
            return np.mean(y)

    else:

        return pairwise_metric(
            x,
            pc @ rotation,
            metric='mse',
            by_row=by_row,
            **metric_kwargs
        )

    else:

        from .sparse_math import sparse_sum

        return sparse_sum(
            array,
            axis=axis,
            squared=squared
        )
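
With this change `array_sum` keeps the dense/non-CSR path local and hands CSR inputs to `sparse_sum(array, axis=axis, squared=squared)` from the new `sparse_math` module, replacing the `_sum_rows`/`_sum_columns` numba kernels deleted above. A minimal sketch of a row-wise kernel in the same style, to show the kind of loop that dispatch implies; the actual `sparse_math.sparse_sum` added in this commit is not shown in this diff and may be organised differently:

import numpy as np
import numba
import scipy.sparse as sps


@numba.njit(parallel=False)
def _row_sums_csr(data, indptr, squared=False):

    # One output value per row of the CSR matrix
    output = np.zeros(indptr.shape[0] - 1, dtype=data.dtype)

    for i in numba.prange(output.shape[0]):
        row = data[indptr[i]:indptr[i + 1]]
        # Stored entries are the only nonzeros, so summing them gives the row sum
        output[i] = np.sum(row ** 2) if squared else np.sum(row)

    return output


# Example: squared row sums of a CSR matrix without densifying it;
# matches x.power(2).sum(axis=1).A1
x = sps.csr_matrix(np.arange(12, dtype=float).reshape(4, 3))
print(_row_sums_csr(x.data, x.indptr, squared=True))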
