Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] option to return csr tensors in datapipe #1062

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import timedelta
from math import ceil
from time import time
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence, Tuple

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -226,7 +226,7 @@ def __init__(
batch_size: int,
encoders: Dict[str, LabelEncoder],
stats: Stats,
return_sparse_X: bool,
X_format: Literal["dense", "csr", "coo"],
use_eager_fetch: bool,
shuffle_rng: Optional[Generator] = None,
) -> None:
Expand All @@ -238,7 +238,7 @@ def __init__(
self.soma_chunk = None
self.var_joinids = var_joinids
self.batch_size = batch_size
self.return_sparse_X = return_sparse_X
self.X_format = X_format
self.encoders = encoders
self.stats = stats
self.max_process_mem_usage_bytes = 0
Expand Down Expand Up @@ -272,8 +272,15 @@ def __next__(self) -> ObsAndXDatum:
# `to_numpy()` avoids copying the numpy array data
obs_tensor = torch.from_numpy(obs_encoded.to_numpy())

if not self.return_sparse_X:
if self.X_format == "dense":
X_tensor = torch.from_numpy(X.todense())
elif self.X_format == "csr":
X_tensor = torch.sparse_csr_tensor(
crow_indices=torch.as_tensor(X.indptr),
col_indices=torch.as_tensor(X.indices),
values=torch.as_tensor(X.data),
size=X.shape,
)
else:
coo = X.tocoo()

Expand Down Expand Up @@ -350,7 +357,7 @@ class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ig
[2416, 0, 4],
[2417, 0, 3]], dtype=torch.int64))

The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or sparse
The ``X_format`` parameter controls whether the ``X`` data is returned as a dense (``"dense"``) or sparse (``"coo"`` or ``"csr"``)
:class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, this will reduce memory usage.

The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` Tensor. The first
Expand Down Expand Up @@ -390,7 +397,7 @@ def __init__(
batch_size: int = 1,
shuffle: bool = False,
seed: Optional[int] = None,
return_sparse_X: bool = False,
X_format: Literal["dense", "csr", "coo"] = "dense",
soma_chunk_size: Optional[int] = None,
use_eager_fetch: bool = True,
) -> None:
Expand Down Expand Up @@ -433,11 +440,11 @@ def __init__(
The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be specified when using
:class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
processes.
return_sparse_X:
Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. As ``X`` data is
very sparse, setting this to ``True`` will reduce memory usage, if the model supports use of sparse
:class:`torch.Tensor`\ s. Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still
experimental in PyTorch.
X_format:
Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. Must be one of
``"dense"``, ``"csr"``, or ``"coo"``. As ``X`` data is very sparse, setting this to ``"coo"`` or
``"csr"`` will reduce memory usage, if the model supports use of sparse :class:`torch.Tensor`\ s.
Defaults to ``"dense"``, since sparse :class:`torch.Tensor`\ s are still experimental in PyTorch.
soma_chunk_size:
The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
this class's behavior: 1) The maximum memory utilization, with larger values providing
Expand All @@ -463,7 +470,7 @@ def __init__(
self.var_query = var_query
self.obs_column_names = obs_column_names
self.batch_size = batch_size
self.return_sparse_X = return_sparse_X
self.X_format = X_format
self.soma_chunk_size = soma_chunk_size
self.use_eager_fetch = use_eager_fetch
self._stats = Stats()
Expand Down Expand Up @@ -545,7 +552,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:
pytorch_logger.debug(f"Using {self.soma_chunk_size=}")

if (
self.return_sparse_X
self.X_format != "dense"
and torch.utils.data.get_worker_info()
and torch.utils.data.get_worker_info().num_workers > 0
):
Expand Down Expand Up @@ -583,7 +590,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:
batch_size=self.batch_size,
encoders=self.obs_encoders,
stats=self._stats,
return_sparse_X=self.return_sparse_X,
X_format=self.X_format,
use_eager_fetch=self.use_eager_fetch,
shuffle_rng=self._shuffle_rng,
)
Expand Down
16 changes: 11 additions & 5 deletions api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
import sys
from typing import Callable, List, Optional, Sequence, Union
from typing import Callable, List, Literal, Optional, Sequence, Union
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -278,13 +278,16 @@ def test_batching__empty_query_result(soma_experiment: Experiment, use_eager_fet
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None:
@pytest.mark.parametrize("X_format", ("coo", "csr"))
def test_sparse_output__non_batched(
soma_experiment: Experiment, use_eager_fetch: bool, X_format: Literal["dense", "csr", "coo"]
) -> None:
exp_data_pipe = ExperimentDataPipe(
soma_experiment,
measurement_name="RNA",
X_name="raw",
obs_column_names=["label"],
return_sparse_X=True,
X_format=X_format,
use_eager_fetch=use_eager_fetch,
)
batch_iter = iter(exp_data_pipe)
Expand All @@ -300,14 +303,17 @@ def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
def test_sparse_output__batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None:
@pytest.mark.parametrize("X_format", ("coo", "csr"))
def test_sparse_output__batched(
soma_experiment: Experiment, use_eager_fetch: bool, X_format: Literal["dense", "csr", "coo"]
) -> None:
exp_data_pipe = ExperimentDataPipe(
soma_experiment,
measurement_name="RNA",
X_name="raw",
obs_column_names=["label"],
batch_size=3,
return_sparse_X=True,
X_format=X_format,
use_eager_fetch=use_eager_fetch,
)
batch_iter = iter(exp_data_pipe)
Expand Down
Loading