Commit

fix bug in combine datasets balanced
sophiamaedler committed Sep 27, 2023
1 parent 51b0fa3 commit 954f0ee
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions src/sparcscore/ml/utils.py
@@ -1,5 +1,6 @@
 from scipy.sparse import csr_matrix
 import numpy as np
+import pandas as pd
 import torch
 from math import floor
 
@@ -39,27 +40,47 @@ def combine_datasets_balanced(list_of_datasets, class_labels, train_per_class, v
     cells_per_class = np.sum(mat, axis=0)
     normalized = mat / cells_per_class
     dataset_fraction = np.sum(normalized, axis=1)
+    print(dataset_fraction)
 
     # Initialize empty lists to store the combined train, validation, and test datasets
     train_dataset = []
     test_dataset = []
     val_dataset = []
 
-    for dataset, label, fraction in zip(list_of_datasets, class_labels, dataset_fraction):
-        print(dataset, label, fraction)
-        train_size = floor(train_per_class * fraction)
-        test_size = floor(test_per_class * fraction)
-        val_size = floor(val_per_class * fraction)
-
-        residual_size = len(dataset) - train_size - test_size - val_size
-
-        if residual_size < 0:
-            raise ValueError(f"Dataset with length {len(dataset)} is too small to be split into a test set of size {test_size}, a train set of size {train_size}, and a validation set of size {val_size}. Use smaller test and train sets.")
-
-        train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size])
-        train_dataset.append(train)
-        test_dataset.append(test)
-        val_dataset.append(val)
+    # check that at least one class occurs more than once (otherwise it will throw an error)
+    if np.sum(pd.Series(class_list).value_counts() > 1) == 0:
+        for dataset, label, fraction in zip(list_of_datasets, class_labels, dataset_fraction):
+            print(dataset, label, 1)
+            train_size = floor(train_per_class)
+            test_size = floor(test_per_class)
+            val_size = floor(val_per_class)
+
+            residual_size = len(dataset) - train_size - test_size - val_size
+
+            if residual_size < 0:
+                raise ValueError(f"Dataset with length {len(dataset)} is too small to be split into a test set of size {test_size}, a train set of size {train_size}, and a validation set of size {val_size}. Use smaller test and train sets.")
+
+            train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size])
+            train_dataset.append(train)
+            test_dataset.append(test)
+            val_dataset.append(val)
+    else:
+
+        for dataset, label, fraction in zip(list_of_datasets, class_labels, dataset_fraction):
+            print(dataset, label, fraction)
+            train_size = floor(train_per_class * fraction)
+            test_size = floor(test_per_class * fraction)
+            val_size = floor(val_per_class * fraction)
+
+            residual_size = len(dataset) - train_size - test_size - val_size
+
+            if residual_size < 0:
+                raise ValueError(f"Dataset with length {len(dataset)} is too small to be split into a test set of size {test_size}, a train set of size {train_size}, and a validation set of size {val_size}. Use smaller test and train sets.")
+
+            train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size])
+            train_dataset.append(train)
+            test_dataset.append(test)
+            val_dataset.append(val)
 
     # Convert the combined datasets into torch.utils.data.Dataset objects
     train_dataset = torch.utils.data.ConcatDataset(train_dataset)
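
A note on the fraction arithmetic this change builds on (a minimal sketch; mat is built earlier in the function, outside this hunk, and is assumed here to be a datasets-by-classes count matrix):

import numpy as np

# Hypothetical counts: rows = datasets, columns = classes.
# Datasets 0 and 1 both carry class 0 (600 and 400 cells);
# dataset 2 is the only one carrying class 1 (500 cells).
mat = np.array([[600,   0],
                [400,   0],
                [  0, 500]])

cells_per_class = np.sum(mat, axis=0)          # [1000, 500]
normalized = mat / cells_per_class             # each dataset's share of each class
dataset_fraction = np.sum(normalized, axis=1)  # [0.6, 0.4, 1.0]

# With train_per_class = 100, the else branch above draws
# floor(100 * 0.6) = 60 and floor(100 * 0.4) = 40 training cells from
# datasets 0 and 1, so class 0 still totals 100 across both.

The newly added if branch bypasses this scaling and uses the per-class sizes directly when no class label occurs more than once.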

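A usage sketch of the fixed code path. The function signature is truncated in the hunk header above, so val_per_class and test_per_class as the trailing parameters, and the train/val/test return order, are assumptions, not confirmed by the diff:

import torch
from torch.utils.data import TensorDataset
from sparcscore.ml.utils import combine_datasets_balanced

# Three single-class datasets of 200 cells each; every class label occurs
# exactly once, so the newly added branch (unscaled sizes) is exercised.
datasets = [TensorDataset(torch.randn(200, 10), torch.full((200,), label))
            for label in (0, 1, 2)]

# Parameter names after train_per_class and the return order are assumed.
train, val, test = combine_datasets_balanced(datasets,
                                             class_labels=[0, 1, 2],
                                             train_per_class=100,
                                             val_per_class=20,
                                             test_per_class=20)

# Each dataset leaves residual_size = 200 - 100 - 20 - 20 = 60 >= 0,
# so the ValueError guard is not triggered.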