From 38dca865a7d07b014ba545d17b3af4fb6c2f2702 Mon Sep 17 00:00:00 2001 From: Saiyam26 Date: Mon, 16 Sep 2024 13:07:47 +0530 Subject: [PATCH] comments addressed --- scalr/data_ingestion_pipeline.py | 19 ++++++++++--------- scalr/utils/test_file_utils.py | 3 ++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/scalr/data_ingestion_pipeline.py b/scalr/data_ingestion_pipeline.py index ee6c23e..2aad9dd 100644 --- a/scalr/data_ingestion_pipeline.py +++ b/scalr/data_ingestion_pipeline.py @@ -123,8 +123,8 @@ def preprocess_data(self): self.sample_chunksize) # Transform on train, val & test split. for split in ['train', 'val', 'test']: - preprocessor.process_data(read_data(path.join(datapath, split)), - self.sample_chunksize, + split_data = read_data(path.join(datapath, split)) + preprocessor.process_data(split_data, self.sample_chunksize, path.join(processed_datapath, split)) datapath = processed_datapath @@ -142,27 +142,28 @@ def generate_mappings(self): path.join(self.data_config['train_val_test']['final_datapaths'], 'val')).obs.columns - datas = [] + data_obs = [] for split in ['train', 'val', 'test']: datapath = path.join( self.data_config['train_val_test']['final_datapaths'], split) - datas.append(read_data(datapath).obs) - data = pd.concat(datas) + split_data_obs = read_data(datapath).obs + data_obs.append(split_data_obs) + full_data_obs = pd.concat(data_obs) label_mappings = {} for column_name in column_names: label_mappings[column_name] = {} - id2label = sorted( - data[column_name].astype('category').cat.categories.tolist()) + id2label = sorted(full_data_obs[column_name].astype( + 'category').cat.categories.tolist()) label2id = {id2label[i]: i for i in range(len(id2label))} label_mappings[column_name]['id2label'] = id2label label_mappings[column_name]['label2id'] = label2id # Garbage collection - del datas - del data + del data_obs + del full_data_obs write_data(label_mappings, path.join(self.datadir, 'label_mappings.json')) diff --git 
a/scalr/utils/test_file_utils.py b/scalr/utils/test_file_utils.py index 45fb397..3fa8dd6 100644 --- a/scalr/utils/test_file_utils.py +++ b/scalr/utils/test_file_utils.py @@ -31,7 +31,8 @@ def test_write_chunkwise_data(): dirpath = './tmp/chunked_data/' # Writing fulldata in chunks. - write_chunkwise_data(read_data(fulldata_path), + full_data = read_data(fulldata_path) + write_chunkwise_data(full_data, sample_chunksize=sample_chunksize, dirpath=dirpath)