diff --git a/scalr/data/preprocess/_preprocess.py b/scalr/data/preprocess/_preprocess.py
index b53594c..68a60a2 100644
--- a/scalr/data/preprocess/_preprocess.py
+++ b/scalr/data/preprocess/_preprocess.py
@@ -50,12 +50,13 @@ def fit(
         """
         pass
 
-    def process_data(self, datapath: dict, sample_chunksize: int, dirpath: str):
+    def process_data(self, full_data: Union[AnnData, AnnCollection],
+                     sample_chunksize: int, dirpath: str):
         """A function to process the entire data chunkwise and write the
         processed data to disk.
 
         Args:
-            datapath (str): Path to read the data from for transformation.
+            full_data (Union[AnnData, AnnCollection]): Full data for transformation.
             sample_chunksize (int): Number of samples in one chunk.
             dirpath (str): Path to write the data to.
         """
@@ -64,7 +65,7 @@ def process_data(self, datapath: dict, sample_chunksize: int, dirpath: str):
             raise NotImplementedError(
                 'Preprocessing does not work without sample chunk size')
 
-        write_chunkwise_data(datapath,
+        write_chunkwise_data(full_data,
                              sample_chunksize,
                              dirpath,
                              transform=self.transform)
diff --git a/scalr/data/split/_split.py b/scalr/data/split/_split.py
index fe35c75..f2ffe14 100644
--- a/scalr/data/split/_split.py
+++ b/scalr/data/split/_split.py
@@ -2,6 +2,10 @@
 import os
 from os import path
+from typing import Union
+
+from anndata import AnnData
+from anndata.experimental import AnnCollection
 
 import scalr
 from scalr.utils import build_object
@@ -88,12 +92,13 @@ def check_splits(self, datapath: str, data_splits: dict, target: str):
         self.event_logger.info(
             f'{metadata[target].iloc[test_inds].value_counts()}\n')
 
-    def write_splits(self, full_datapath: str, data_split_indices: dict,
-                     sample_chunksize: int, dirpath: int):
+    def write_splits(self, full_data: Union[AnnData, AnnCollection],
+                     data_split_indices: dict, sample_chunksize: int,
+                     dirpath: str):
         """This function writes the train, validation, and test splits to the disk.
 
         Args:
-            full_datapath (str): Full datapath of data to be split.
+            full_data (Union[AnnData, AnnCollection]): Full data to be split.
             data_split_indices (dict): Indices of each split.
             sample_chunksize (int): Number of samples to be written in one file.
             dirpath (str): Path to write data into.
@@ -106,10 +111,9 @@ def write_splits(self, full_datapath: str, data_split_indices: dict,
             if sample_chunksize:
                 split_dirpath = path.join(dirpath, split)
                 os.makedirs(split_dirpath, exist_ok=True)
-                write_chunkwise_data(full_datapath, sample_chunksize,
-                                     split_dirpath, data_split_indices[split])
+                write_chunkwise_data(full_data, sample_chunksize, split_dirpath,
+                                     data_split_indices[split])
             else:
-                full_data = read_data(full_datapath)
                 filepath = path.join(dirpath, f'{split}.h5ad')
                 write_data(full_data[data_split_indices[split]], filepath)
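
Note: `write_splits` now consumes an in-memory AnnData/AnnCollection instead of
a path, and its non-chunked branch slices that object directly. A minimal,
self-contained sketch of the same pattern on a toy AnnData (names and paths are
illustrative only, not pipeline API):

    import os

    import numpy as np
    import pandas as pd
    from anndata import AnnData

    # Toy stand-in for full_data: 6 cells x 3 genes.
    full_data = AnnData(X=np.arange(18, dtype=float).reshape(6, 3),
                        obs=pd.DataFrame(index=[f'cell_{i}' for i in range(6)]))

    data_split_indices = {'train': [0, 1, 2, 3], 'val': [4], 'test': [5]}

    os.makedirs('./tmp', exist_ok=True)
    for split, inds in data_split_indices.items():
        # Mirrors the else-branch above: slice the in-memory object,
        # materialize the view, and write one .h5ad file per split.
        full_data[inds].copy().write_h5ad(f'./tmp/{split}.h5ad')
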
diff --git a/scalr/data_ingestion_pipeline.py b/scalr/data_ingestion_pipeline.py
index 0e087c7..ee6c23e 100644
--- a/scalr/data_ingestion_pipeline.py
+++ b/scalr/data_ingestion_pipeline.py
@@ -4,6 +4,8 @@
 import os
 from os import path
 
+import pandas as pd
+
 from scalr.data.preprocess import build_preprocessor
 from scalr.data.split import build_splitter
 from scalr.utils import FlowLogger
@@ -51,6 +53,7 @@ def generate_train_val_test_split(self):
         ''')
 
         full_datapath = self.data_config['train_val_test']['full_datapath']
+        self.full_data = read_data(full_datapath)
         splitter_config = deepcopy(
             self.data_config['train_val_test']['splitter_config'])
         splitter, splitter_config = build_splitter(splitter_config)
@@ -74,10 +77,13 @@ def generate_train_val_test_split(self):
                                                  'train_val_test_split')
         os.makedirs(train_val_test_split_dirpath, exist_ok=True)
 
-        splitter.write_splits(full_datapath, train_val_test_split_indices,
+        splitter.write_splits(self.full_data, train_val_test_split_indices,
                               self.sample_chunksize,
                               train_val_test_split_dirpath)
 
+        # Garbage collection
+        del self.full_data
+
         self.data_config['train_val_test'][
             'split_datapaths'] = train_val_test_split_dirpath
@@ -117,7 +123,7 @@ def preprocess_data(self):
                            self.sample_chunksize)
 
         # Transform on train, val & test split.
         for split in ['train', 'val', 'test']:
-            preprocessor.process_data(path.join(datapath, split),
+            preprocessor.process_data(read_data(path.join(datapath, split)),
                                       self.sample_chunksize,
                                       path.join(processed_datapath, split))
@@ -140,22 +146,24 @@ def generate_mappings(self):
         for split in ['train', 'val', 'test']:
             datapath = path.join(
                 self.data_config['train_val_test']['final_datapaths'], split)
-            datas.append(read_data(datapath))
+            datas.append(read_data(datapath).obs)
+        data = pd.concat(datas)
 
         label_mappings = {}
         for column_name in column_names:
             label_mappings[column_name] = {}
 
-            id2label = []
-            for data in datas:
-                id2label += data.obs[column_name].astype(
-                    'category').cat.categories.tolist()
+            id2label = sorted(
+                data[column_name].astype('category').cat.categories.tolist())
 
-            id2label = sorted(list(set(id2label)))
             label2id = {id2label[i]: i for i in range(len(id2label))}
 
             label_mappings[column_name]['id2label'] = id2label
             label_mappings[column_name]['label2id'] = label2id
 
+        # Garbage collection
+        del datas
+        del data
+
         write_data(label_mappings,
                    path.join(self.datadir, 'label_mappings.json'))
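
Note: the reworked `generate_mappings` concatenates only the per-split `obs`
frames and reads the label vocabulary off the categorical's categories. Since
categories are already unique, `sorted` over them replaces the old `set`-union
without changing the resulting mapping. A self-contained sketch of the mapping
logic on toy frames:

    import pandas as pd

    # Toy stand-ins for the per-split obs frames.
    datas = [
        pd.DataFrame({'celltype': ['B', 'T']}),
        pd.DataFrame({'celltype': ['T', 'NK']}),
    ]
    data = pd.concat(datas)

    id2label = sorted(
        data['celltype'].astype('category').cat.categories.tolist())
    label2id = {id2label[i]: i for i in range(len(id2label))}

    print(id2label)  # ['B', 'NK', 'T']
    print(label2id)  # {'B': 0, 'NK': 1, 'T': 2}
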
diff --git a/scalr/feature_extraction_pipeline.py b/scalr/feature_extraction_pipeline.py
index 91fa4c6..fa0a118 100644
--- a/scalr/feature_extraction_pipeline.py
+++ b/scalr/feature_extraction_pipeline.py
@@ -179,13 +179,19 @@ def write_top_features_subset_data(self, data_config: dict) -> dict:
         feature_subset_datapath = path.join(self.dirpath, 'feature_subset_data')
         os.makedirs(feature_subset_datapath, exist_ok=True)
 
-        for split in ['train', 'val', 'test']:
-            split_datapath = path.join(datapath, split)
+        test_data = read_data(path.join(datapath, 'test'))
+        splits = {
+            'train': self.train_data,
+            'val': self.val_data,
+            'test': test_data
+        }
+
+        for split, split_data in splits.items():
             split_feature_subset_datapath = path.join(feature_subset_datapath,
                                                       split)
 
             sample_chunksize = data_config.get('sample_chunksize')
-            write_chunkwise_data(split_datapath,
+            write_chunkwise_data(split_data,
                                  sample_chunksize,
                                  split_feature_subset_datapath,
                                  feature_inds=self.top_features)
diff --git a/scalr/utils/file_utils.py b/scalr/utils/file_utils.py
index 921b98f..926bc96 100644
--- a/scalr/utils/file_utils.py
+++ b/scalr/utils/file_utils.py
@@ -69,7 +69,7 @@ def write_data(data: Union[dict, AnnData, pd.DataFrame], filepath: str):
             '`filepath` does not contain `json`, `yaml`, or `h5ad` file')
 
 
-def write_chunkwise_data(datapath: str,
+def write_chunkwise_data(full_data: Union[AnnData, AnnCollection],
                          sample_chunksize: int,
                          dirpath: str,
                          sample_inds: list[int] = None,
@@ -81,7 +81,7 @@ def write_chunkwise_data(datapath: str,
     This function can also apply transformation on each chunk.
 
     Args:
-        datapath (str): path/to/data to be written in chunks.
+        full_data (Union[AnnData, AnnCollection]): data to be written in chunks.
         sample_chunksize (int): number of samples to be loaded at a time.
         dirpath (str): path/to/directory to write the chunks of data.
         sample_inds (list[int], optional): To be used in case of chunking
@@ -95,27 +95,22 @@ def write_chunkwise_data(datapath: str,
     if not path.exists(dirpath):
         os.makedirs(dirpath)
 
-    data = read_data(datapath)
-    if isinstance(data, AnnData) and feature_inds:
-        raise ValueError(
-            'TrainValTestSplit data for FeatureSubsetting must be AnnCollection'
-        )
-
     if not sample_inds:
-        sample_inds = list(range(len(data)))
+        sample_inds = list(range(len(full_data)))
 
-    # Hacky fix for an AnnCollection working/bug.
+    # Hacky workarounds for AnnCollection bugs.
     if sample_chunksize >= len(sample_inds):
         sample_chunksize = len(sample_inds) - 1
 
-    for i, (start) in enumerate(range(0, len(sample_inds), sample_chunksize)):
-        data = read_data(datapath)
+    for col in full_data.obs.columns:
+        full_data.obs[col] = full_data.obs[col].astype('category')
 
+    for i, (start) in enumerate(range(0, len(sample_inds), sample_chunksize)):
         if feature_inds:
-            data = data[sample_inds[start:start + sample_chunksize],
-                        feature_inds]
+            data = full_data[sample_inds[start:start + sample_chunksize],
+                             feature_inds]
         else:
-            data = data[sample_inds[start:start + sample_chunksize]]
+            data = full_data[sample_inds[start:start + sample_chunksize]]
 
         if not isinstance(data, AnnData):
             data = data.to_adata()
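
Note: with the internal `read_data` calls hoisted out, `write_chunkwise_data`
chunks whatever loaded object it is handed. The pre-existing workaround still
caps `sample_chunksize` at `len(sample_inds) - 1` when one chunk would cover
the whole dataset, so a "single chunk" request yields two files. A hedged
usage sketch on a toy AnnData, assuming `write_chunkwise_data` is importable
from `scalr.utils` (the test module's imports are not shown in this diff):

    import numpy as np
    import pandas as pd
    from anndata import AnnData

    from scalr.utils import write_chunkwise_data  # assumed import path

    adata = AnnData(X=np.random.rand(10, 4),
                    obs=pd.DataFrame({'celltype': ['a', 'b'] * 5}))

    # New contract: pass the loaded object, not a path. With 10 samples and
    # sample_chunksize=4, the chunks written under dirpath hold 4, 4, and 2
    # cells respectively.
    write_chunkwise_data(adata, sample_chunksize=4, dirpath='./tmp/chunked_data')
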
diff --git a/scalr/utils/test_file_utils.py b/scalr/utils/test_file_utils.py
index 2e70051..45fb397 100644
--- a/scalr/utils/test_file_utils.py
+++ b/scalr/utils/test_file_utils.py
@@ -31,7 +31,7 @@ def test_write_chunkwise_data():
     dirpath = './tmp/chunked_data/'
 
     # Writing fulldata in chunks.
-    write_chunkwise_data(fulldata_path,
+    write_chunkwise_data(read_data(fulldata_path),
                          sample_chunksize=sample_chunksize,
                          dirpath=dirpath)
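
Note: one behavioural side effect worth flagging: the categorical cast in
`write_chunkwise_data` now runs on the caller's object rather than on a
freshly read copy, so the caller's `obs` dtypes are mutated in place. A small
sketch (same assumed import path as above):

    import numpy as np
    import pandas as pd
    from anndata import AnnData

    from scalr.utils import write_chunkwise_data  # assumed import path

    adata = AnnData(X=np.zeros((5, 2)),
                    obs=pd.DataFrame({'batch': [1, 1, 2, 2, 2]}))
    write_chunkwise_data(adata, sample_chunksize=2, dirpath='./tmp/side_effect')

    # The cast inside write_chunkwise_data is visible after the call:
    print(adata.obs['batch'].dtype)  # category
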