Skip to content

Commit

Permalink
comments addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
Saiyam26 committed Sep 16, 2024
1 parent 9606fbe commit 38dca86
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
19 changes: 10 additions & 9 deletions scalr/data_ingestion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def preprocess_data(self):
self.sample_chunksize)
# Transform on train, val & test split.
for split in ['train', 'val', 'test']:
preprocessor.process_data(read_data(path.join(datapath, split)),
self.sample_chunksize,
split_data = read_data(path.join(datapath, split))
preprocessor.process_data(split_data, self.sample_chunksize,
path.join(processed_datapath, split))

datapath = processed_datapath
Expand All @@ -142,27 +142,28 @@ def generate_mappings(self):
path.join(self.data_config['train_val_test']['final_datapaths'],
'val')).obs.columns

datas = []
data_obs = []
for split in ['train', 'val', 'test']:
datapath = path.join(
self.data_config['train_val_test']['final_datapaths'], split)
datas.append(read_data(datapath).obs)
data = pd.concat(datas)
split_data_obs = read_data(datapath).obs
data_obs.append(split_data_obs)
full_data_obs = pd.concat(data_obs)

label_mappings = {}
for column_name in column_names:
label_mappings[column_name] = {}

id2label = sorted(
data[column_name].astype('category').cat.categories.tolist())
id2label = sorted(full_data_obs[column_name].astype(
'category').cat.categories.tolist())

label2id = {id2label[i]: i for i in range(len(id2label))}
label_mappings[column_name]['id2label'] = id2label
label_mappings[column_name]['label2id'] = label2id

# Garbage collection
del datas
del data
del data_obs
del full_data_obs

write_data(label_mappings, path.join(self.datadir,
'label_mappings.json'))
Expand Down
3 changes: 2 additions & 1 deletion scalr/utils/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def test_write_chunkwise_data():
dirpath = './tmp/chunked_data/'

# Writing fulldata in chunks.
write_chunkwise_data(read_data(fulldata_path),
full_data = read_data(fulldata_path)
write_chunkwise_data(full_data,
sample_chunksize=sample_chunksize,
dirpath=dirpath)

Expand Down

0 comments on commit 38dca86

Please sign in to comment.