diff --git a/pangaea/datasets/mados.py b/pangaea/datasets/mados.py index 42ac6f1..3d4ff16 100644 --- a/pangaea/datasets/mados.py +++ b/pangaea/datasets/mados.py @@ -1,5 +1,6 @@ import os -import time import pathlib +import time +import pathlib import urllib.request import urllib.error import zipfile diff --git a/pangaea/utils/subset_sampler.py b/pangaea/utils/subset_sampler.py index ffbf763..a8f3796 100644 --- a/pangaea/utils/subset_sampler.py +++ b/pangaea/utils/subset_sampler.py @@ -40,14 +40,16 @@ def calculate_regression_distributions(dataset: GeoFMDataset|GeoFMSubset): return np.array(distributions) -# Function to bin class distributions +# Function to bin class distributions using ceil def bin_class_distributions(class_distributions, num_bins=3, logger=None): - logger.info(f"Class distributions are being binned into {num_bins} categories") + logger.info(f"Class distributions are being binned into {num_bins} categories using ceil") - binned_distributions = np.digitize(class_distributions, np.linspace(0, 1, num_bins+1)) - 1 + bin_edges = np.linspace(0, 1, num_bins + 1)[1] + binned_distributions = np.ceil(class_distributions / bin_edges).astype(int) - 1 return binned_distributions + # Function to bin regression distributions def bin_regression_distributions(regression_distributions, num_bins=3, logger=None): logger.info(f"Regression distributions are being binned into {num_bins} categories")