Skip to content

Commit

Permalink
Inline _read and AxisQueryResult to to_anndata (#3552)
Browse files Browse the repository at this point in the history
The `_read` function and `AxisQueryResult` appear to exist to prevent
future code duplication. However, this adds indirection and line noise
to the current codebase. We can add these back in the future if/when
they are needed.
  • Loading branch information
jp-dark authored Jan 10, 2025
1 parent 4d3e4b6 commit 1c8b1bd
Showing 1 changed file with 98 additions and 151 deletions.
249 changes: 98 additions & 151 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,40 +154,6 @@ def _to_numpy(it: Numpyable) -> npt.NDArray[np.int64]:
return it.to_numpy()


@attrs.define(frozen=True)
class AxisQueryResult:
"""The result of running :meth:`ExperimentAxisQuery.read`. Private."""

obs: pd.DataFrame
"""Experiment.obs query slice, as a pandas DataFrame"""
var: pd.DataFrame
"""Experiment.ms[...].var query slice, as a pandas DataFrame"""
X: sp.csr_matrix
"""Experiment.ms[...].X[...] query slice, as a SciPy sparse.csr_matrix """
X_layers: Dict[str, sp.csr_matrix] = attrs.field(factory=dict)
"""Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
obsm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
"""Experiment.obsm query slice, as a numpy ndarray"""
obsp: Dict[str, sp.csr_matrix] = attrs.field(factory=dict)
"""Experiment.obsp query slice, as SciPy sparse.csr_matrix(s)"""
varm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
"""Experiment.varm query slice, as a numpy ndarray"""
varp: Dict[str, sp.csr_matrix] = attrs.field(factory=dict)
"""Experiment.varp query slice, as SciPy sparse.csr_matrix(s)"""

def to_anndata(self) -> AnnData:
return AnnData(
X=self.X,
obs=self.obs,
var=self.var,
obsm=(self.obsm or None),
obsp=(self.obsp or None),
varm=(self.varm or None),
varp=(self.varp or None),
layers=(self.X_layers or None),
)


class ExperimentAxisQuery(query.ExperimentAxisQuery):
"""Axis-based query against a SOMA Experiment.
Expand Down Expand Up @@ -467,26 +433,108 @@ def to_anndata(
varp_layers: Sequence[str] = (),
drop_levels: bool = False,
) -> AnnData:
ad = self._read(
X_name,
column_names=column_names or AxisColumnNames(obs=None, var=None),
X_layers=X_layers,
obsm_layers=obsm_layers,
obsp_layers=obsp_layers,
varm_layers=varm_layers,
varp_layers=varp_layers,
).to_anndata()
"""Exports the query to an in-memory ``AnnData`` object.
Args:
X_name:
The X layer to read and return in the ``X`` slot.
column_names:
The columns in the ``var`` and ``obs`` dataframes to read.
X_layers:
Additional X layers to read and return in the ``layers`` slot.
obsm_layers:
Additional obsm layers to read and return in the obsm slot.
obsp_layers:
Additional obsp layers to read and return in the obsp slot.
varm_layers:
Additional varm layers to read and return in the varm slot.
varp_layers:
Additional varp layers to read and return in the varp slot.
drop_levels:
If true, drop unused categories from the ``obs`` and ``var`` dataframes.
Defaults to ``False``.
"""

if column_names is None:
column_names = AxisColumnNames(obs=None, var=None)

tp = self._threadpool
x_collection = self._ms.X
all_x_names = [X_name] + list(X_layers)
all_x_arrays: Dict[str, SparseNDArray] = {}
for _xname in all_x_names:
if not isinstance(_xname, str) or not _xname:
raise ValueError("X layer names must be specified as a string.")
if _xname not in x_collection:
raise ValueError("Unknown X layer name")
x_array = x_collection[_xname]
if not isinstance(x_array, SparseNDArray):
raise NotImplementedError("Dense array unsupported")
all_x_arrays[_xname] = x_array

obs_table, var_table = tp.map(
self._read_axis_dataframe,
(AxisName.OBS, AxisName.VAR),
(column_names, column_names),
)
obs_joinids = self.obs_joinids()
var_joinids = self.var_joinids()

x_matrices = {
_xname: tp.submit(
_read_as_csr,
layer,
obs_joinids,
var_joinids,
self._indexer.by_obs,
self._indexer.by_var,
)
for _xname, layer in all_x_arrays.items()
}
x_future = x_matrices.pop(X_name)

obsm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key)
for key in obsm_layers
}
varm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key)
for key in varm_layers
}
obsp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key)
for key in obsp_layers
}
varp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key)
for key in varp_layers
}

obs = obs_table.to_pandas()
obs.index = obs.index.astype(str)

var = var_table.to_pandas()
var.index = var.index.astype(str)

# Drop unused categories on axis dataframes if requested
if drop_levels:
for name in ad.obs:
if ad.obs[name].dtype.name == "category":
ad.obs[name] = ad.obs[name].cat.remove_unused_categories()
for name in ad.var:
if ad.var[name].dtype.name == "category":
ad.var[name] = ad.var[name].cat.remove_unused_categories()
for name in obs:
if obs[name].dtype.name == "category":
obs[name] = obs[name].cat.remove_unused_categories()
for name in var:
if var[name].dtype.name == "category":
var[name] = var[name].cat.remove_unused_categories()

return ad
return AnnData(
X=x_future.result(),
obs=obs,
var=var,
obsm=(_resolve_futures(obsm_future) or None),
obsp=(_resolve_futures(obsp_future) or None),
varm=(_resolve_futures(varm_future) or None),
varp=(_resolve_futures(varp_future) or None),
layers=(_resolve_futures(x_matrices) or None),
)

def to_spatialdata( # type: ignore[no-untyped-def]
self,
Expand Down Expand Up @@ -580,107 +628,6 @@ def __exit__(self, *_: Any) -> None:

# Internals

def _read(
self,
X_name: str,
*,
column_names: AxisColumnNames,
X_layers: Sequence[str],
obsm_layers: Sequence[str] = (),
obsp_layers: Sequence[str] = (),
varm_layers: Sequence[str] = (),
varp_layers: Sequence[str] = (),
) -> AxisQueryResult:
"""Reads the entire query result in memory.
This is a low-level routine intended to be used by loaders for other
in-core formats, such as AnnData, which can be created from the
resulting objects.
Args:
X_name: The X layer to read and return in the ``X`` slot.
column_names: The columns in the ``var`` and ``obs`` dataframes
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_layers:
Additional obsm layers to read and return in the obsm slot.
obsp_layers:
Additional obsp layers to read and return in the obsp slot.
varm_layers:
Additional varm layers to read and return in the varm slot.
varp_layers:
Additional varp layers to read and return in the varp slot.
"""
tp = self._threadpool
x_collection = self._ms.X
all_x_names = [X_name] + list(X_layers)
all_x_arrays: Dict[str, SparseNDArray] = {}
for _xname in all_x_names:
if not isinstance(_xname, str) or not _xname:
raise ValueError("X layer names must be specified as a string.")
if _xname not in x_collection:
raise ValueError("Unknown X layer name")
x_array = x_collection[_xname]
if not isinstance(x_array, SparseNDArray):
raise NotImplementedError("Dense array unsupported")
all_x_arrays[_xname] = x_array

obs_table, var_table = tp.map(
self._read_axis_dataframe,
(AxisName.OBS, AxisName.VAR),
(column_names, column_names),
)
obs_joinids = self.obs_joinids()
var_joinids = self.var_joinids()

x_matrices = {
_xname: tp.submit(
_read_as_csr,
layer,
obs_joinids,
var_joinids,
self._indexer.by_obs,
self._indexer.by_var,
)
for _xname, layer in all_x_arrays.items()
}
x_future = x_matrices.pop(X_name)

obsm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key)
for key in obsm_layers
}
varm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key)
for key in varm_layers
}
obsp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key)
for key in obsp_layers
}
varp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key)
for key in varp_layers
}

obs = obs_table.to_pandas()
obs.index = obs.index.astype(str)

var = var_table.to_pandas()
var.index = var.index.astype(str)

return AxisQueryResult(
obs=obs,
var=var,
X=x_future.result(),
obsm=_resolve_futures(obsm_future),
obsp=_resolve_futures(obsp_future),
varm=_resolve_futures(varm_future),
varp=_resolve_futures(varp_future),
X_layers=_resolve_futures(x_matrices),
)

def _read_axis_dataframe(
self,
axis: AxisName,
Expand Down

0 comments on commit 1c8b1bd

Please sign in to comment.