Skip to content

Commit

Permalink
Merge branch 'dev' of github.com:HiDiHlabs/sainsc into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasmueboe committed Oct 24, 2024
2 parents 398ad8c + 7c39875 commit 289b49d
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 16 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ polars-arrow = { version = "0.41" }
pyo3 = { version = "0.21", features = ["extension-module"] }
pyo3-polars = { version = "0.15" }
rayon = { version = "1.8" }
serde_json = { version = "1" }
sprs = { version = "= 0.11.1", features = ["serde"] }
zarrs = { version = "0.16.4", features = ["ndarray", "gzip"] }
5 changes: 5 additions & 0 deletions sainsc/_utils_rust.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Self

import numpy as np
Expand Down Expand Up @@ -42,10 +43,12 @@ def coordinate_as_string(
def cosinef32_and_celltypei8(
counts: GridCounts,
genes: list[str],
celltypes: list[str],
signatures: NDArray[np.float32],
kernel: NDArray[np.float32],
*,
log: bool = False,
zarr_path: Path | None = None,
chunk_size: tuple[int, int] = (500, 500),
n_threads: int | None = None,
) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.int8]]:
Expand All @@ -57,10 +60,12 @@ def cosinef32_and_celltypei8(
def cosinef32_and_celltypei16(
counts: GridCounts,
genes: list[str],
celltypes: list[str],
signatures: NDArray[np.float32],
kernel: NDArray[np.float32],
*,
log: bool = False,
zarr_path: Path | None = None,
chunk_size: tuple[int, int] = (500, 500),
n_threads: int | None = None,
) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.int16]]:
Expand Down
24 changes: 23 additions & 1 deletion sainsc/lazykde/_LazyKDE.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Iterable
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

import matplotlib.pyplot as plt
Expand All @@ -19,7 +20,7 @@
from scipy.sparse import coo_array, csc_array, csr_array
from skimage.feature import peak_local_max

from .._typealias import _Cmap, _Csx, _CsxArray, _Local_Max, _RangeTuple2D
from .._typealias import _Cmap, _Csx, _CsxArray, _Local_Max, _PathLike, _RangeTuple2D
from .._utils import _raise_module_load_error, _validate_n_threads, validate_threads
from .._utils_rust import (
GridCounts,
Expand Down Expand Up @@ -542,6 +543,7 @@ def assign_celltype(
signatures: pd.DataFrame,
*,
log: bool = False,
zarr_path: _PathLike | None = None,
chunk: tuple[int, int] = (500, 500),
):
"""
Expand All @@ -557,6 +559,9 @@ def assign_celltype(
log : bool
Whether to log transform the KDE when calculating the cosine similarity.
This is useful if the gene signatures are derived from log-transformed data.
zarr_path : os.PathLike, str, or None
If not `None` the cosine similarities for all cell types will be written to
the specified path as zarr storage.
chunk : tuple[int, int]
Size of the chunks for processing. Larger chunks require more memory but
have less duplicated computation.
Expand All @@ -569,8 +574,13 @@ def assign_celltype(
If `self.kernel` is not set.
ValueError
If `chunk` is smaller than the shape of `self.kernel`.
ValueError
If `zarr_path` is not None and the celltype names contain
illegal characters for file names.
"""

ILLEGAL_CHARS = ["/", "\\"]

if not all(signatures.index.isin(self.genes)):
raise ValueError(
"Not all genes in the gene signature are part of this KDE."
Expand All @@ -587,6 +597,16 @@ def assign_celltype(
celltypes = signatures.columns.tolist()
ct_dtype = _get_cell_dtype(len(celltypes))

zarr_path = None if zarr_path is None else Path(zarr_path)

if zarr_path is not None and any(
char in ct for char in ILLEGAL_CHARS for ct in celltypes
):
raise ValueError(
"Celltype names contain at least one of the illegal characters: "
f"{ILLEGAL_CHARS}"
)

# scale signatures to unit norm
signatures_mat = signatures.to_numpy()
signatures_mat = (
Expand All @@ -600,9 +620,11 @@ def assign_celltype(
self._cosine_similarity, self._assignment_score, self._celltype_map = fn(
self.counts,
genes,
celltypes,
signatures_mat,
self.kernel,
log=log,
zarr_path=zarr_path,
chunk_size=chunk,
n_threads=self.n_threads,
)
Expand Down
3 changes: 3 additions & 0 deletions sainsc/lazykde/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
from pathlib import Path
from typing import Protocol, TypeVar

import numpy as np
Expand Down Expand Up @@ -71,10 +72,12 @@ def __call__(
self,
counts: GridCounts,
genes: list[str],
celltypes: list[str],
signatures: NDArray[np.float32],
kernel: NDArray[np.float32],
*,
log: bool = ...,
zarr_path: Path | None = None,
chunk_size: tuple[int, int] = ...,
n_threads: int | None = ...,
) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.signedinteger]]: ...
73 changes: 58 additions & 15 deletions src/cosine.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::cosine_zarr::{initialize_cosine_zarrstore, write_cosine_to_zarr, ZarrChunkInfo};
use crate::gridcounts::GridCounts;
use crate::sparsekde::sparse_kde_csx_;
use crate::utils::create_pool;
Expand All @@ -12,20 +13,23 @@ use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use pyo3::{exceptions::PyValueError, prelude::*};
use rayon::prelude::*;
use sprs::{CompressedStorage::CSR, CsMatI, CsMatViewI, SpIndex};
use std::{cmp::min, error::Error, ops::Range};
use std::{cmp::min, error::Error, ops::Range, path::PathBuf};
use zarrs::array::Element;

macro_rules! build_cos_ct_fn {
($name:tt, $t_cos:ty, $t_ct:ty) => {
#[pyfunction]
#[pyo3(signature = (counts, genes, signatures, kernel, *, log=false, chunk_size=(500, 500), n_threads=None))]
#[pyo3(signature = (counts, genes, celltypes, signatures, kernel, *, log=false, zarr_path=None, chunk_size=(500, 500), n_threads=None))]
/// calculate cosine similarity and assign celltype
pub fn $name<'py>(
py: Python<'py>,
counts: &mut GridCounts,
genes: Vec<String>,
celltypes: Vec<String>,
signatures: PyReadonlyArray2<'py, $t_cos>,
kernel: PyReadonlyArray2<'py, $t_cos>,
log: bool,
zarr_path:Option<PathBuf>,
chunk_size: (usize, usize),
n_threads: Option<usize>,
) -> PyResult<(
Expand All @@ -46,10 +50,12 @@ macro_rules! build_cos_ct_fn {

let cos_ct = chunk_and_calculate_cosine(
&gene_counts,
celltypes,
signatures.as_array(),
kernel.as_array(),
counts.shape,
log,
zarr_path,
chunk_size,
n_threads
);
Expand All @@ -71,17 +77,19 @@ build_cos_ct_fn!(cosinef32_and_celltypei16, f32, i16);

fn chunk_and_calculate_cosine<C, I, F, U>(
counts: &[CsMatViewI<C, I>],
celltypes: Vec<String>,
signatures: ArrayView2<F>,
kernel: ArrayView2<F>,
shape: (usize, usize),
log: bool,
zarr_path: Option<PathBuf>,
chunk_size: (usize, usize),
n_threads: Option<usize>,
) -> Result<(Array2<F>, Array2<F>, Array2<U>), Box<dyn Error>>
) -> Result<(Array2<F>, Array2<F>, Array2<U>), Box<dyn Error + Send + Sync>>
where
C: NumCast + Copy + Sync + Send + Default,
I: SpIndex + Signed + Sync + Send,
F: NdFloat,
F: NdFloat + Element,
U: PrimInt + Signed + Sync + Send,
Slice: From<Range<I>>,
{
Expand Down Expand Up @@ -115,7 +123,16 @@ where
}
});

let ((cosine, score), celltype): ((Vec<_>, Vec<_>), Vec<_>) = pool.install(|| {
// init zarr store for celltypes with chunksize and all zero arrays
let zarr_store = match zarr_path
.map(|path| initialize_cosine_zarrstore(path, &celltypes, shape, chunk_size))
{
Some(Err(e)) => return Err(e),
Some(Ok(store)) => Some(store),
None => None,
};

let celltyping_results = pool.install(|| {
// generate all chunk indices
let chunk_indices: Vec<_> = (0..m).cartesian_product(0..n).collect();

Expand All @@ -125,18 +142,28 @@ where
.map(|idx| {
let (chunk, unpad) = get_chunk(counts, idx, shape, chunk_size, pad);

let zarr_info = zarr_store.clone().map(|store| ZarrChunkInfo {
store,
celltypes: { celltypes.clone() },
chunk_idx: vec![idx.0 as u64, idx.1 as u64],
});

cosine_and_celltype_(
chunk,
signatures,
&signature_similarity_correction,
kernel,
unpad,
log,
zarr_info,
)
})
.unzip()
.collect::<Vec<_>>()
});

let ((cosine, score), celltype): ((Vec<_>, Vec<_>), Vec<_>) =
itertools::process_results(celltyping_results, |iter| iter.unzip())?;

// concatenate all chunks back to original shape
Ok((
concat_2d(&cosine, n)?,
Expand Down Expand Up @@ -207,10 +234,11 @@ fn cosine_and_celltype_<C, I, F, U>(
kernel: ArrayView2<F>,
unpad: (Range<usize>, Range<usize>),
log: bool,
) -> ((Array2<F>, Array2<F>), Array2<U>)
zarr_info: Option<ZarrChunkInfo>,
) -> Result<((Array2<F>, Array2<F>), Array2<U>), Box<dyn Error + Send + Sync>>
where
C: NumCast + Copy,
F: NdFloat,
F: NdFloat + Element,
U: PrimInt + Signed,
I: SpIndex + Signed,
Slice: From<Range<I>>,
Expand All @@ -225,10 +253,10 @@ where
// fastpath if all csx are empty
None => {
let shape = (unpad_r.end - unpad_r.start, unpad_c.end - unpad_c.start);
(
Ok((
(Array2::zeros(shape), Array2::zeros(shape)),
Array2::from_elem(shape, -one::<U>()),
)
))
}
Some((csx, weights)) => {
let shape = csx.shape();
Expand Down Expand Up @@ -262,8 +290,24 @@ where
.filter(|(_, &w)| w != zero::<F>())
.for_each(|(mut cos, &w)| cos += &kde_unpadded.map(|&x| x * w));
}
// TODO: write to zarr
get_max_cosine_and_celltype(cosine, kde_norm, pairwise_correction)

kde_norm.mapv_inplace(F::sqrt);

if let Some(zarr_info) = zarr_info {
write_cosine_to_zarr(
zarr_info.store,
&cosine,
&kde_norm,
&zarr_info.celltypes,
&zarr_info.chunk_idx,
)?
};

Ok(get_max_cosine_and_celltype(
cosine,
kde_norm,
pairwise_correction,
))
}
}
}
Expand Down Expand Up @@ -291,9 +335,8 @@ where
*ct = -one::<I>();
*s = zero();
} else {
let norm_sqrt = norm.sqrt();
*cos /= norm_sqrt;
*s /= norm_sqrt;
*cos /= norm;
*s /= norm;
};
});

Expand Down
Loading

0 comments on commit 289b49d

Please sign in to comment.