Skip to content

Commit

Permalink
Cheat on CSC/CSR conversion for TruncScaler
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Apr 15, 2024
1 parent baac1fd commit b2d38ec
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 2 deletions.
19 changes: 17 additions & 2 deletions inferelator_velocity/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,24 @@ def copy_count_layer(data, layer, counts_layer=None):
class TruncRobustScaler(RobustScaler):

def fit(self, X, y=None):
super().fit(X, y)

# Use StandardScaler to deal with sparse & dense easily
if isinstance(X, (sps.csr_matrix, sps.csc_array)):
# Use custom extractor to turn X into a CSC with no
# indices array; RobustScaler makes an undesirable
# CSR->CSC conversion
from .sparse_math import sparse_csr_extract_columns
super().fit(
sparse_csr_extract_columns(X, fake_csc_matrix=True),
y
)
else:
super().fit(
X,
y
)

# Use StandardScaler to deal with sparse & dense
# There are C extensions for CSR variance without copy
_std_scale = StandardScaler(with_mean=False).fit(X)

_post_robust_var = _std_scale.var_ / (self.scale_ ** 2)
Expand Down
68 changes: 68 additions & 0 deletions inferelator_velocity/utils/sparse_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,42 @@ def sparse_normalize_total(sparse_array, target_sum=10_000):
)


def sparse_csr_extract_columns(
    sparse_array,
    fake_csc_matrix
):
    """
    Reorder the data of a CSR array into column-major (CSC) order
    without building the row-index array a real CSR->CSC conversion
    would require.

    :param sparse_array: CSR sparse array/matrix whose `.data`,
        `.indices`, and `.shape` attributes are read
    :param fake_csc_matrix: If True, return a `csc_matrix` whose
        `.data` and `.indptr` are valid but whose `.indices` is a
        single-element placeholder. It is NOT a usable CSC matrix;
        only consumers that ignore `.indices` may accept it.
        If False, return the raw (data, indptr) pair instead.
    :return: Placeholder `csc_matrix`, or a (data, col_indptr) tuple
    """

    # Column pointer array (CSC indptr) from the CSR column indices
    col_indptr = _csr_to_csc_indptr(
        sparse_array.indices,
        sparse_array.shape[1]
    )

    # Data values reordered into column-major order
    new_data = _csr_extract_columns(
        sparse_array.data,
        sparse_array.indices,
        col_indptr
    )

    if fake_csc_matrix:
        # Build an empty CSC shell and graft the reordered data and
        # column pointers onto it. The indices array is a length-1
        # zero placeholder (hence "fake") - this skips allocating
        # and filling the per-value row index array
        arr = sps.csc_matrix(
            sparse_array.shape,
            dtype=sparse_array.dtype
        )

        arr.data = new_data
        arr.indices = np.zeros((1,), dtype=col_indptr.dtype)
        arr.indptr = col_indptr

        return arr

    else:
        return new_data, col_indptr


@numba.njit(parallel=False)
def _mse_rowwise(
a_data,
Expand Down Expand Up @@ -231,6 +267,38 @@ def _csr_column_divide(data, indices, column_normalization_vec):
data[i] /= column_normalization_vec[idx]


def _csr_column_nnz(indices, n_col):

return np.bincount(indices, minlength=n_col)


def _csr_to_csc_indptr(indices, n_col):
    """
    Build the CSC indptr array implied by a CSR array's column indices.

    :param indices: CSR column index array
    :param n_col: Total number of columns
    :return: Length n_col + 1 integer array; entry j is the offset of
        column j's first value in column-major order
    """

    indptr = np.zeros(n_col + 1, dtype=int)

    # Cumulative per-column nonzero counts; indptr[0] stays zero
    indptr[1:] = np.cumsum(np.bincount(indices, minlength=n_col))

    return indptr


@numba.njit(parallel=False)
def _csr_extract_columns(data, col_indices, new_col_indptr):
    """
    Reorder CSR data values into column-major (CSC) order.

    Single stable counting-sort pass: values that share a column keep
    their original row-major relative order.

    :param data: CSR data array
    :param col_indices: CSR column index array (parallel to data)
    :param new_col_indptr: CSC indptr giving each column's start offset
    :return: New data array in column-major order
    """

    reordered = np.zeros_like(data)

    # Next free slot within each column's output segment
    write_offset = np.zeros_like(new_col_indptr)

    for k in range(data.shape[0]):
        c = col_indices[k]
        reordered[new_col_indptr[c] + write_offset[c]] = data[k]
        write_offset[c] += 1

    return reordered


def is_csr(x):
    """
    Check whether x is a scipy CSR sparse matrix or sparse array.

    :param x: Object to test
    :return: True for csr_matrix / csr_array, False otherwise
    """

    return isinstance(x, sps.csr_array) or sps.isspmatrix_csr(x)

Expand Down

0 comments on commit b2d38ec

Please sign in to comment.