From e2bc73e78d2497ed38fd7462f1da2b3967a6847a Mon Sep 17 00:00:00 2001 From: Julia Dark Date: Fri, 10 Jan 2025 11:51:50 -0500 Subject: [PATCH] Inline `_read` and `AxisQueryResult` to `to_anndata` The `_read` function and `AxisQueryResult` appear to exist to prevent future code duplication. However, this adds indirection and line noise to the current codebase. We can add these back in the future if/when they are needed. --- apis/python/src/tiledbsoma/_query.py | 249 +++++++++++---------------- 1 file changed, 98 insertions(+), 151 deletions(-) diff --git a/apis/python/src/tiledbsoma/_query.py b/apis/python/src/tiledbsoma/_query.py index e7270172b0..37ac9aacfb 100644 --- a/apis/python/src/tiledbsoma/_query.py +++ b/apis/python/src/tiledbsoma/_query.py @@ -154,40 +154,6 @@ def _to_numpy(it: Numpyable) -> npt.NDArray[np.int64]: return it.to_numpy() -@attrs.define(frozen=True) -class AxisQueryResult: - """The result of running :meth:`ExperimentAxisQuery.read`. Private.""" - - obs: pd.DataFrame - """Experiment.obs query slice, as a pandas DataFrame""" - var: pd.DataFrame - """Experiment.ms[...].var query slice, as a pandas DataFrame""" - X: sp.csr_matrix - """Experiment.ms[...].X[...] query slice, as a SciPy sparse.csr_matrix """ - X_layers: Dict[str, sp.csr_matrix] = attrs.field(factory=dict) - """Any additional X layers requested, as SciPy sparse.csr_matrix(s)""" - obsm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict) - """Experiment.obsm query slice, as a numpy ndarray""" - obsp: Dict[str, sp.csr_matrix] = attrs.field(factory=dict) - """Experiment.obsp query slice, as SciPy sparse.csr_matrix(s)""" - varm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict) - """Experiment.varm query slice, as a numpy ndarray""" - varp: Dict[str, sp.csr_matrix] = attrs.field(factory=dict) - """Experiment.varp query slice, as SciPy sparse.csr_matrix(s)""" - - def to_anndata(self) -> AnnData: - return AnnData( - X=self.X, - obs=self.obs, - var=self.var, - obsm=(self.obsm or None), - obsp=(self.obsp or None), - varm=(self.varm or None), - varp=(self.varp or None), - layers=(self.X_layers or None), - ) - - class ExperimentAxisQuery(query.ExperimentAxisQuery): """Axis-based query against a SOMA Experiment. @@ -467,26 +433,108 @@ def to_anndata( varp_layers: Sequence[str] = (), drop_levels: bool = False, ) -> AnnData: - ad = self._read( - X_name, - column_names=column_names or AxisColumnNames(obs=None, var=None), - X_layers=X_layers, - obsm_layers=obsm_layers, - obsp_layers=obsp_layers, - varm_layers=varm_layers, - varp_layers=varp_layers, - ).to_anndata() + """Exports the query to an in-memory ``AnnData`` object. + + Args: + X_name: + The X layer to read and return in the ``X`` slot. + column_names: + The columns in the ``var`` and ``obs`` dataframes to read. + X_layers: + Additional X layers to read and return in the ``layers`` slot. + obsm_layers: + Additional obsm layers to read and return in the obsm slot. + obsp_layers: + Additional obsp layers to read and return in the obsp slot. + varm_layers: + Additional varm layers to read and return in the varm slot. + varp_layers: + Additional varp layers to read and return in the varp slot. + drop_levels: + If true, drop unused categories from the ``obs`` and ``var`` dataframes. + Defaults to ``False``. + """ + + if column_names is None: + column_names = AxisColumnNames(obs=None, var=None) + + tp = self._threadpool + x_collection = self._ms.X + all_x_names = [X_name] + list(X_layers) + all_x_arrays: Dict[str, SparseNDArray] = {} + for _xname in all_x_names: + if not isinstance(_xname, str) or not _xname: + raise ValueError("X layer names must be specified as a string.") + if _xname not in x_collection: + raise ValueError("Unknown X layer name") + x_array = x_collection[_xname] + if not isinstance(x_array, SparseNDArray): + raise NotImplementedError("Dense array unsupported") + all_x_arrays[_xname] = x_array + + obs_table, var_table = tp.map( + self._read_axis_dataframe, + (AxisName.OBS, AxisName.VAR), + (column_names, column_names), + ) + obs_joinids = self.obs_joinids() + var_joinids = self.var_joinids() + + x_matrices = { + _xname: tp.submit( + _read_as_csr, + layer, + obs_joinids, + var_joinids, + self._indexer.by_obs, + self._indexer.by_var, + ) + for _xname, layer in all_x_arrays.items() + } + x_future = x_matrices.pop(X_name) + + obsm_future = { + key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key) + for key in obsm_layers + } + varm_future = { + key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key) + for key in varm_layers + } + obsp_future = { + key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key) + for key in obsp_layers + } + varp_future = { + key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key) + for key in varp_layers + } + + obs = obs_table.to_pandas() + obs.index = obs.index.astype(str) + + var = var_table.to_pandas() + var.index = var.index.astype(str) # Drop unused categories on axis dataframes if requested if drop_levels: - for name in ad.obs: - if ad.obs[name].dtype.name == "category": - ad.obs[name] = ad.obs[name].cat.remove_unused_categories() - for name in ad.var: - if ad.var[name].dtype.name == "category": - ad.var[name] = ad.var[name].cat.remove_unused_categories() + for name in obs: + if obs[name].dtype.name == "category": + obs[name] = obs[name].cat.remove_unused_categories() + for name in var: + if var[name].dtype.name == "category": + var[name] = var[name].cat.remove_unused_categories() - return ad + return AnnData( + X=x_future.result(), + obs=obs, + var=var, + obsm=(_resolve_futures(obsm_future) or None), + obsp=(_resolve_futures(obsp_future) or None), + varm=(_resolve_futures(varm_future) or None), + varp=(_resolve_futures(varp_future) or None), + layers=(_resolve_futures(x_matrices) or None), + ) def to_spatialdata( # type: ignore[no-untyped-def] self, @@ -580,107 +628,6 @@ def __exit__(self, *_: Any) -> None: # Internals - def _read( - self, - X_name: str, - *, - column_names: AxisColumnNames, - X_layers: Sequence[str], - obsm_layers: Sequence[str] = (), - obsp_layers: Sequence[str] = (), - varm_layers: Sequence[str] = (), - varp_layers: Sequence[str] = (), - ) -> AxisQueryResult: - """Reads the entire query result in memory. - - This is a low-level routine intended to be used by loaders for other - in-core formats, such as AnnData, which can be created from the - resulting objects. - - Args: - X_name: The X layer to read and return in the ``X`` slot. - column_names: The columns in the ``var`` and ``obs`` dataframes - to read. - X_layers: Additional X layers to read and return - in the ``layers`` slot. - obsm_layers: - Additional obsm layers to read and return in the obsm slot. - obsp_layers: - Additional obsp layers to read and return in the obsp slot. - varm_layers: - Additional varm layers to read and return in the varm slot. - varp_layers: - Additional varp layers to read and return in the varp slot. - """ - tp = self._threadpool - x_collection = self._ms.X - all_x_names = [X_name] + list(X_layers) - all_x_arrays: Dict[str, SparseNDArray] = {} - for _xname in all_x_names: - if not isinstance(_xname, str) or not _xname: - raise ValueError("X layer names must be specified as a string.") - if _xname not in x_collection: - raise ValueError("Unknown X layer name") - x_array = x_collection[_xname] - if not isinstance(x_array, SparseNDArray): - raise NotImplementedError("Dense array unsupported") - all_x_arrays[_xname] = x_array - - obs_table, var_table = tp.map( - self._read_axis_dataframe, - (AxisName.OBS, AxisName.VAR), - (column_names, column_names), - ) - obs_joinids = self.obs_joinids() - var_joinids = self.var_joinids() - - x_matrices = { - _xname: tp.submit( - _read_as_csr, - layer, - obs_joinids, - var_joinids, - self._indexer.by_obs, - self._indexer.by_var, - ) - for _xname, layer in all_x_arrays.items() - } - x_future = x_matrices.pop(X_name) - - obsm_future = { - key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key) - for key in obsm_layers - } - varm_future = { - key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key) - for key in varm_layers - } - obsp_future = { - key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key) - for key in obsp_layers - } - varp_future = { - key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key) - for key in varp_layers - } - - obs = obs_table.to_pandas() - obs.index = obs.index.astype(str) - - var = var_table.to_pandas() - var.index = var.index.astype(str) - - return AxisQueryResult( - obs=obs, - var=var, - X=x_future.result(), - obsm=_resolve_futures(obsm_future), - obsp=_resolve_futures(obsp_future), - varm=_resolve_futures(varm_future), - varp=_resolve_futures(varp_future), - X_layers=_resolve_futures(x_matrices), - ) - def _read_axis_dataframe( self, axis: AxisName,