Skip to content

Commit

Permalink
adding segmentation stratification
Browse files Browse the repository at this point in the history
  • Loading branch information
alishibli97 committed Oct 10, 2024
1 parent 33053c2 commit 927eec6
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions pangaea/utils/subset_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from tqdm import tqdm
import numpy as np


# Function to calculate class distributions for classification with a progress bar
def calculate_class_distributions(dataset, num_classes):
class_distributions = []
Expand Down Expand Up @@ -50,26 +49,36 @@ def bin_regression_distributions(regression_distributions, num_bins=3, logger=No
return binned_distributions


# Function to perform stratification for classification and return only the indices
# Updated function to perform stratification for classification and return only the indices, with even bin selection
def stratify_classification_dataset_indices(dataset, num_classes, label_fraction=1.0, num_bins=3, logger=None):
# Step 1: Calculate class distributions with progress tracking
class_distributions = calculate_class_distributions(dataset, num_classes)

# Step 2: Bin the class distributions
binned_distributions = bin_class_distributions(class_distributions, num_bins=num_bins, logger=logger)

# Step 3: Combine the bins to use for stratification

# Step 3: Prep a dictionary to hold indices for each bin combination
indices_per_bin = {}

# Combine the bins for each class to create unique bin identifiers
combined_bins = np.apply_along_axis(lambda row: ''.join(map(str, row)), axis=1, arr=binned_distributions)

# Step 4: Select a subset of labeled data with progress tracking
num_labeled = int(len(dataset) * label_fraction)
# Populate the dictionary with indices based on combined bin identifiers
for idx, bin_id in enumerate(combined_bins):
if bin_id not in indices_per_bin:
indices_per_bin[bin_id] = []
indices_per_bin[bin_id].append(idx)

# Sort the indices based on combined bins to preserve class distribution
sorted_indices = np.argsort(combined_bins)
labeled_idx = sorted_indices[:num_labeled]
unlabeled_idx = sorted_indices[num_labeled:]
# Step 4: Select a proportion of indices from each bin
selected_idx = []
for bin_id, indices in indices_per_bin.items():
num_to_select = int(max(1, len(indices) * label_fraction)) # Ensure at least one index is selected
selected_idx.extend(np.random.choice(indices, num_to_select, replace=False))

return labeled_idx, unlabeled_idx
# Step 5: Determine the remaining indices not selected
other_idx = list(set(range(len(dataset))) - set(selected_idx))

return selected_idx, other_idx


# Function to perform stratification for regression and return only the indices
Expand All @@ -87,25 +96,16 @@ def stratify_regression_dataset_indices(dataset, label_fraction=1.0, num_bins=3,
for index, bin_index in enumerate(binned_distributions):
if bin_index in indices_per_bin:
indices_per_bin[bin_index].append(index)

# Step 5: Select fraction of indices from each bin
selected_idx = []
for bin_index, indices in indices_per_bin.items():
num_to_select = int(max(1, len(indices)*label_fraction) ) # To ensure at least one index is selected
selected_idx.extend(np.random.choice(indices, num_to_select, replace=False))
num_to_select = int(max(1, len(indices) * label_fraction)) # Ensure at least one index is selected
selected_idx.extend(np.random.choice(indices, num_to_select, replace=False))

other_idx = list(set(range(len(dataset))) - set(selected_idx))

return selected_idx, other_idx

# # Step 3: Sort the indices based on binned distributions for stratification
# sorted_indices = np.argsort(binned_distributions)

# # Step 4: Select a subset of labeled data with progress tracking
# num_labeled = int(len(dataset) * label_fraction)
# labeled_idx = sorted_indices[:num_labeled]
# unlabeled_idx = sorted_indices[num_labeled:]

# return labeled_idx, unlabeled_idx
return selected_idx, other_idx


# Function to get subset indices based on the strategy, supporting both classification and regression
Expand All @@ -128,3 +128,5 @@ def get_subset_indices(dataset, strategy="random", label_fraction=0.5, num_bins=
)

return indices


0 comments on commit 927eec6

Please sign in to comment.