
Commit

optimized chunkwise writing + garbage collection + handling NaN writing bug
Saiyam26 committed Sep 13, 2024
1 parent 406ce2a commit 1390a7d
Showing 6 changed files with 50 additions and 36 deletions.
7 changes: 4 additions & 3 deletions scalr/data/preprocess/_preprocess.py
@@ -50,12 +50,13 @@ def fit(
"""
pass

def process_data(self, datapath: dict, sample_chunksize: int, dirpath: str):
def process_data(self, full_data: Union[AnnData, AnnCollection],
sample_chunksize: int, dirpath: str):
"""A function to process the entire data chunkwise and write the processed data
to disk.
Args:
datapath (str): Path to read the data from for transformation.
full_data (Union[AnnData, AnnCollection]): Full data for transformation.
sample_chunksize (int): Number of samples in one chunk.
dirpath (str): Path to write the data to.
"""
@@ -64,7 +65,7 @@ def process_data(self, datapath: dict, sample_chunksize: int, dirpath: str):
raise NotImplementedError(
'Preprocessing does not work without sample chunk size')

write_chunkwise_data(datapath,
write_chunkwise_data(full_data,
sample_chunksize,
dirpath,
transform=self.transform)
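With this change, process_data no longer reads from a path itself; the caller loads the data once and passes the in-memory AnnData/AnnCollection. A minimal, self-contained sketch of that flow (the log1p transform, paths, and chunk size are illustrative stand-ins, not scalr code):

```python
# Illustrative sketch of the new "pass the loaded object" pattern;
# not the scalr implementation. Transform, paths, and sizes are made up.
import os

import anndata as ad
import numpy as np


def log1p_transform(chunk: ad.AnnData) -> ad.AnnData:
    chunk.X = np.log1p(chunk.X)  # stand-in for a scalr preprocessor transform
    return chunk


def process_data(full_data: ad.AnnData, sample_chunksize: int, dirpath: str):
    """Chunk the already-loaded data, transform each chunk, write to disk."""
    os.makedirs(dirpath, exist_ok=True)
    for i, start in enumerate(range(0, full_data.n_obs, sample_chunksize)):
        chunk = full_data[start:start + sample_chunksize].copy()
        chunk = log1p_transform(chunk)
        chunk.write(os.path.join(dirpath, f'{i}.h5ad'))


full_data = ad.AnnData(X=np.random.rand(10, 4).astype(np.float32))
process_data(full_data, sample_chunksize=4, dirpath='./tmp_processed')
```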
16 changes: 10 additions & 6 deletions scalr/data/split/_split.py
@@ -2,6 +2,10 @@

import os
from os import path
from typing import Union

from anndata import AnnData
from anndata.experimental import AnnCollection

import scalr
from scalr.utils import build_object
@@ -88,12 +92,13 @@ def check_splits(self, datapath: str, data_splits: dict, target: str):
self.event_logger.info(
f'{metadata[target].iloc[test_inds].value_counts()}\n')

def write_splits(self, full_datapath: str, data_split_indices: dict,
sample_chunksize: int, dirpath: int):
def write_splits(self, full_data: Union[AnnData, AnnCollection],
data_split_indices: dict, sample_chunksize: int,
dirpath: int):
"""THis function writes the train validation and test splits to the disk.
Args:
full_datapath (str): Full datapath of data to be split.
full_data (Union[AnnData, AnnCollection]): Full data to be split.
data_split_indices (dict): Indices of each split.
sample_chunksize (int): Number of samples to be written in one file.
dirpath (int): Path to write data into.
@@ -106,10 +111,9 @@ def write_splits(self, full_datapath: str, data_split_indices: dict,
if sample_chunksize:
split_dirpath = path.join(dirpath, split)
os.makedirs(split_dirpath, exist_ok=True)
write_chunkwise_data(full_datapath, sample_chunksize,
split_dirpath, data_split_indices[split])
write_chunkwise_data(full_data, sample_chunksize, split_dirpath,
data_split_indices[split])
else:
full_data = read_data(full_datapath)
filepath = path.join(dirpath, f'{split}.h5ad')
write_data(full_data[data_split_indices[split]], filepath)

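write_splits now receives the same in-memory object for every split instead of re-reading full_datapath per split. A rough sketch of chunked split writing under that contract (synthetic data, indices, and chunk size; not the scalr code):

```python
# Rough sketch under assumed behaviour: one loaded AnnData is sliced per
# split by index and written in fixed-size chunks.
import os

import anndata as ad
import numpy as np

full_data = ad.AnnData(X=np.random.rand(12, 5).astype(np.float32))
data_split_indices = {
    'train': list(range(0, 8)),
    'val': list(range(8, 10)),
    'test': list(range(10, 12)),
}
sample_chunksize = 4

for split, inds in data_split_indices.items():
    split_dirpath = os.path.join('./tmp_splits', split)
    os.makedirs(split_dirpath, exist_ok=True)
    for i, start in enumerate(range(0, len(inds), sample_chunksize)):
        chunk = full_data[inds[start:start + sample_chunksize]].copy()
        chunk.write(os.path.join(split_dirpath, f'{i}.h5ad'))
```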
24 changes: 16 additions & 8 deletions scalr/data_ingestion_pipeline.py
@@ -4,6 +4,8 @@
import os
from os import path

import pandas as pd

from scalr.data.preprocess import build_preprocessor
from scalr.data.split import build_splitter
from scalr.utils import FlowLogger
@@ -51,6 +53,7 @@ def generate_train_val_test_split(self):
''')

full_datapath = self.data_config['train_val_test']['full_datapath']
self.full_data = read_data(full_datapath)
splitter_config = deepcopy(
self.data_config['train_val_test']['splitter_config'])
splitter, splitter_config = build_splitter(splitter_config)
@@ -74,10 +77,13 @@ def generate_train_val_test_split(self):
'train_val_test_split')
os.makedirs(train_val_test_split_dirpath, exist_ok=True)

splitter.write_splits(full_datapath, train_val_test_split_indices,
splitter.write_splits(self.full_data, train_val_test_split_indices,
self.sample_chunksize,
train_val_test_split_dirpath)

# Garbage collection
del self.full_data

self.data_config['train_val_test'][
'split_datapaths'] = train_val_test_split_dirpath

@@ -117,7 +123,7 @@ def preprocess_data(self):
self.sample_chunksize)
# Transform on train, val & test split.
for split in ['train', 'val', 'test']:
preprocessor.process_data(path.join(datapath, split),
preprocessor.process_data(read_data(path.join(datapath, split)),
self.sample_chunksize,
path.join(processed_datapath, split))

@@ -140,22 +146,24 @@ def generate_mappings(self):
for split in ['train', 'val', 'test']:
datapath = path.join(
self.data_config['train_val_test']['final_datapaths'], split)
datas.append(read_data(datapath))
datas.append(read_data(datapath).obs)
data = pd.concat(datas)

label_mappings = {}
for column_name in column_names:
label_mappings[column_name] = {}

id2label = []
for data in datas:
id2label += data.obs[column_name].astype(
'category').cat.categories.tolist()
id2label = sorted(
data[column_name].astype('category').cat.categories.tolist())

id2label = sorted(list(set(id2label)))
label2id = {id2label[i]: i for i in range(len(id2label))}
label_mappings[column_name]['id2label'] = id2label
label_mappings[column_name]['label2id'] = label2id

# Garbage collection
del datas
del data

write_data(label_mappings, path.join(self.datadir,
'label_mappings.json'))

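In generate_mappings, only the .obs frames are now collected and concatenated once with pandas, so each column's categories come from one combined DataFrame before the temporaries are deleted. A small sketch of that logic with made-up splits and labels:

```python
# Small sketch with made-up splits and labels; mirrors the concat-then-
# categorise logic in the diff, not the actual pipeline wiring.
import pandas as pd

datas = [
    pd.DataFrame({'cell_type': ['B', 'T']}),    # train .obs
    pd.DataFrame({'cell_type': ['T', 'NK']}),   # val .obs
    pd.DataFrame({'cell_type': ['B', 'NK']}),   # test .obs
]
data = pd.concat(datas)

label_mappings = {}
for column_name in ['cell_type']:
    id2label = sorted(
        data[column_name].astype('category').cat.categories.tolist())
    label2id = {label: i for i, label in enumerate(id2label)}
    label_mappings[column_name] = {'id2label': id2label, 'label2id': label2id}

# Garbage collection, as in the commit: drop the concatenated metadata.
del datas
del data

print(label_mappings)
# {'cell_type': {'id2label': ['B', 'NK', 'T'],
#                'label2id': {'B': 0, 'NK': 1, 'T': 2}}}
```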
12 changes: 9 additions & 3 deletions scalr/feature_extraction_pipeline.py
@@ -179,13 +179,19 @@ def write_top_features_subset_data(self, data_config: dict) -> dict:
feature_subset_datapath = path.join(self.dirpath, 'feature_subset_data')
os.makedirs(feature_subset_datapath, exist_ok=True)

for split in ['train', 'val', 'test']:
test_data = read_data(path.join(datapath, 'test'))
splits = {
'train': self.train_data,
'val': self.val_data,
'test': test_data
}

for split, split_data in splits.items():

split_datapath = path.join(datapath, split)
split_feature_subset_datapath = path.join(feature_subset_datapath,
split)
sample_chunksize = data_config.get('sample_chunksize')
write_chunkwise_data(split_datapath,
write_chunkwise_data(split_data,
sample_chunksize,
split_feature_subset_datapath,
feature_inds=self.top_features)
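Here the pipeline reuses the train and validation data already held in memory and only reads the test split from disk before writing the top-feature subset chunkwise. A sketch of that reuse pattern (synthetic AnnData objects and a hypothetical top_features index list):

```python
# Sketch only: synthetic data and a hypothetical top_features list.
# Shows the splits-dict reuse pattern from the diff, not scalr itself.
import anndata as ad
import numpy as np


def make_dummy(n_obs: int) -> ad.AnnData:
    return ad.AnnData(X=np.random.rand(n_obs, 6).astype(np.float32))


train_data = make_dummy(8)   # kept in memory from an earlier stage
val_data = make_dummy(4)     # kept in memory from an earlier stage
test_data = make_dummy(4)    # stands in for read_data(path.join(datapath, 'test'))

splits = {'train': train_data, 'val': val_data, 'test': test_data}
top_features = [0, 2, 5]     # hypothetical selected feature indices

for split, split_data in splits.items():
    subset = split_data[:, top_features].copy()   # column subset per split
    subset.write(f'./tmp_feature_subset_{split}.h5ad')
```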
25 changes: 10 additions & 15 deletions scalr/utils/file_utils.py
@@ -69,7 +69,7 @@ def write_data(data: Union[dict, AnnData, pd.DataFrame], filepath: str):
'`filepath` does not contain `json`, `yaml`, or `h5ad` file')


def write_chunkwise_data(datapath: str,
def write_chunkwise_data(full_data: Union[AnnData, AnnCollection],
sample_chunksize: int,
dirpath: str,
sample_inds: list[int] = None,
@@ -81,7 +81,7 @@ def write_chunkwise_data(datapath: str,
This function can also apply transformation on each chunk.
Args:
datapath (str): path/to/data to be written in chunks.
full_data (Union[AnnData, AnnCollection]): data to be written in chunks.
sample_chunksize (int): number of samples to be loaded at a time.
dirpath (str): path/to/directory to write the chunks of data.
sample_inds (list[int], optional): To be used in case of chunking
@@ -95,27 +95,22 @@ def write_chunkwise_data(datapath: str,
if not path.exists(dirpath):
os.makedirs(dirpath)

data = read_data(datapath)
if isinstance(data, AnnData) and feature_inds:
raise ValueError(
'TrainValTestSplit data for FeatureSubsetting must be AnnCollection'
)

if not sample_inds:
sample_inds = list(range(len(data)))
sample_inds = list(range(len(full_data)))

# Hacky fix for an AnnCollection working/bug.
# Hacky fixes for an AnnCollection working/bug.
if sample_chunksize >= len(sample_inds):
sample_chunksize = len(sample_inds) - 1

for i, (start) in enumerate(range(0, len(sample_inds), sample_chunksize)):
data = read_data(datapath)
for col in full_data.obs.columns:
full_data.obs[col] = full_data.obs[col].astype('category')

for i, (start) in enumerate(range(0, len(sample_inds), sample_chunksize)):
if feature_inds:
data = data[sample_inds[start:start + sample_chunksize],
feature_inds]
data = full_data[sample_inds[start:start + sample_chunksize],
feature_inds]
else:
data = data[sample_inds[start:start + sample_chunksize]]
data = full_data[sample_inds[start:start + sample_chunksize]]

if not isinstance(data, AnnData):
data = data.to_adata()
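The rewritten write_chunkwise_data takes the loaded object, caps sample_chunksize, and casts every obs column to category before slicing; the categorical cast appears to be the NaN-writing fix named in the commit message, though that motivation is an inference from the diff rather than a documented statement. A minimal standalone sketch with synthetic data:

```python
# Minimal standalone sketch with synthetic data. The claim that the
# category cast is what avoids the NaN writing problem is an assumption
# drawn from the diff, not a documented statement.
import os

import anndata as ad
import numpy as np
import pandas as pd

obs = pd.DataFrame({'batch': ['a', 'b', np.nan, 'a']})
full_data = ad.AnnData(X=np.random.rand(4, 3).astype(np.float32), obs=obs)

# Cast obs columns once up front: NaNs become missing values in a
# categorical column rather than entries in a mixed-type object column.
for col in full_data.obs.columns:
    full_data.obs[col] = full_data.obs[col].astype('category')

sample_inds = list(range(full_data.n_obs))
sample_chunksize = 2
dirpath = './tmp_chunkwise'
os.makedirs(dirpath, exist_ok=True)

for i, start in enumerate(range(0, len(sample_inds), sample_chunksize)):
    chunk = full_data[sample_inds[start:start + sample_chunksize]].copy()
    chunk.write(os.path.join(dirpath, f'{i}.h5ad'))
```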
2 changes: 1 addition & 1 deletion scalr/utils/test_file_utils.py
@@ -31,7 +31,7 @@ def test_write_chunkwise_data():
dirpath = './tmp/chunked_data/'

# Writing fulldata in chunks.
write_chunkwise_data(fulldata_path,
write_chunkwise_data(read_data(fulldata_path),
sample_chunksize=sample_chunksize,
dirpath=dirpath)

