From 26b2495e73c603e6e60bfe61cc8754e82e674824 Mon Sep 17 00:00:00 2001 From: Georges Le Bellier Date: Wed, 18 Sep 2024 16:49:55 +0200 Subject: [PATCH] 10 add pastis dataset to v2 (#25) * Add pastis v1 * Adapt output to v2 * Add get_splits (#10) * Add split as attribute to match other datasets format * Minor fix to deal with str dates * Add config file for PASTIS (#10) * Minor changes * Update output shapes and stats computation * Add PASTIS dataset metadata * Minor changes on PASTIS dataset and add running data stats * Clean code but facing register issue (#10) * Minor fix * Minor updates * Add pastis import * Debugging preprocessing for timeseries * Update pastis configs * Update PASTIS target * Fix collate_fn issue * Update timeserie sampling * Update datasets inits * Add single image PASTIS * remove a repetitive line * use relative dataset root path --- configs/augmentations/pastis_seg.yaml | 11 + configs/datasets/pastis.yaml | 110 ++++++++ configs/datasets/pastis_si.yaml | 110 ++++++++ datasets/__init__.py | 13 +- datasets/pastis.py | 387 ++++++++++++++++++++++++++ engine/data_preprocessor.py | 66 +++-- utils/compute_data_statistics.py | 71 +++-- 7 files changed, 712 insertions(+), 56 deletions(-) create mode 100644 configs/augmentations/pastis_seg.yaml create mode 100644 configs/datasets/pastis.yaml create mode 100644 configs/datasets/pastis_si.yaml create mode 100644 datasets/pastis.py diff --git a/configs/augmentations/pastis_seg.yaml b/configs/augmentations/pastis_seg.yaml new file mode 100644 index 00000000..8b74e954 --- /dev/null +++ b/configs/augmentations/pastis_seg.yaml @@ -0,0 +1,11 @@ +train: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + ResizeToEncoder: ~ + # RandomFlip: + # ud_probability: 0.3 + # lr_probability: 0.3 +test: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + ResizeToEncoder: ~ diff --git a/configs/datasets/pastis.yaml b/configs/datasets/pastis.yaml new file mode 100644 index 00000000..5640f252 --- /dev/null +++ b/configs/datasets/pastis.yaml @@ -0,0 +1,110 @@ +dataset_name: Pastis +root_path: ./data/PASTIS-HD +download_url: null +auto_download: False + +img_size: 128 +multi_temporal: 6 +multi_modal: True +limited_label: False + +# classes +ignore_index: 0 +num_classes: 20 +classes: + - Background + - Meadow + - Soft Winter Wheat + - Corn + - Winter Barley + - Winter Rapeseed + - Spring Barley + - Sunflower + - Grapevine + - Beet + - Winter Triticale + - Winter Durum Wheat + - Fruits, Vegetables, Flowers + - Potatoes + - Leguminous Fodder + - Soybeans + - Orchard + - Mixed Cereal + - Sorghum + - Void Label +distribution: + - 0.00000 + - 0.25675 + - 0.06733 + - 0.10767 + - 0.02269 + - 0.01451 + - 0.00745 + - 0.01111 + - 0.08730 + - 0.00715 + - 0.00991 + - 0.01398 + - 0.02149 + - 0.00452 + - 0.02604 + - 0.00994 + - 0.02460 + - 0.00696 + - 0.00580 + - 0.29476 + +bands: + optical: + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B11 + - B12 + sar: + - VV + - VH + - VV-VH +data_mean: + optical: + - 1161.6764 + - 1371.4307 + - 1423.4067 + - 1759.7251 + - 2714.5259 + - 3055.8376 + - 3197.8960 + - 3313.3577 + - 2415.9675 + - 1626.8431 + sar: + - -10.9433 + - -17.3600 + - 6.4167 +data_std: + optical: + - 2045.0698 + - 1983.1763 + - 2060.7969 + - 1968.8173 + - 1867.2159 + - 1885.1361 + - 1897.5105 + - 1885.1636 + - 1542.7665 + - 1375.2511 + sar: + - 3.3847 + - 3.3727 + - 3.3874 +data_min: + optical: [-10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.] + sar: [-39.5312, -43.1250, -20.6562] +data_max: + optical: [22256., 21891., 22626., 21814., 20134., 19282., 18957., 18482., 16935., 14668.] + sar: [34.8750, 27.3125, 51.9062] diff --git a/configs/datasets/pastis_si.yaml b/configs/datasets/pastis_si.yaml new file mode 100644 index 00000000..43d990c5 --- /dev/null +++ b/configs/datasets/pastis_si.yaml @@ -0,0 +1,110 @@ +dataset_name: Pastis +root_path: ./data/PASTIS-HD +download_url: null +auto_download: False + +img_size: 128 +multi_temporal: 1 +multi_modal: True +limited_label: False + +# classes +ignore_index: 0 +num_classes: 20 +classes: + - Background + - Meadow + - Soft Winter Wheat + - Corn + - Winter Barley + - Winter Rapeseed + - Spring Barley + - Sunflower + - Grapevine + - Beet + - Winter Triticale + - Winter Durum Wheat + - Fruits, Vegetables, Flowers + - Potatoes + - Leguminous Fodder + - Soybeans + - Orchard + - Mixed Cereal + - Sorghum + - Void Label +distribution: + - 0.00000 + - 0.25675 + - 0.06733 + - 0.10767 + - 0.02269 + - 0.01451 + - 0.00745 + - 0.01111 + - 0.08730 + - 0.00715 + - 0.00991 + - 0.01398 + - 0.02149 + - 0.00452 + - 0.02604 + - 0.00994 + - 0.02460 + - 0.00696 + - 0.00580 + - 0.29476 + +bands: + optical: + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B11 + - B12 + sar: + - VV + - VH + - VV-VH +data_mean: + optical: + - 1161.6764 + - 1371.4307 + - 1423.4067 + - 1759.7251 + - 2714.5259 + - 3055.8376 + - 3197.8960 + - 3313.3577 + - 2415.9675 + - 1626.8431 + sar: + - -10.9433 + - -17.3600 + - 6.4167 +data_std: + optical: + - 2045.0698 + - 1983.1763 + - 2060.7969 + - 1968.8173 + - 1867.2159 + - 1885.1361 + - 1897.5105 + - 1885.1636 + - 1542.7665 + - 1375.2511 + sar: + - 3.3847 + - 3.3727 + - 3.3874 +data_min: + optical: [-10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.] + sar: [-39.5312, -43.1250, -20.6562] +data_max: + optical: [22256., 21891., 22626., 21814., 20134., 19282., 18957., 18482., 16935., 14668.] + sar: [34.8750, 27.3125, 51.9062] diff --git a/datasets/__init__.py b/datasets/__init__.py index 026a3837..7d85797c 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -1,10 +1,11 @@ -from .mados import MADOS -from .hlsburnscars import HLSBurnScars -from .sen1floods11 import Sen1Floods11 -from .xview2 import xView2 +from .ai4smallfarms import AI4SmallFarms from .biomassters import BioMassters from .croptypemapping import CropTypeMappingSouthSudan -from .ai4smallfarms import AI4SmallFarms -from .spacenet7 import SN7MAPPING, SN7CD from .fivebillionpixels import FiveBillionPixels +from .hlsburnscars import HLSBurnScars +from .mados import MADOS +from .pastis import Pastis +from .sen1floods11 import Sen1Floods11 +from .spacenet7 import SN7CD, SN7MAPPING from .utae_dynamicen import DynamicEarthNet +from .xview2 import xView2 diff --git a/datasets/pastis.py b/datasets/pastis.py new file mode 100644 index 00000000..5313a951 --- /dev/null +++ b/datasets/pastis.py @@ -0,0 +1,387 @@ +### +# Modified version of the PASTIS-HD dataset +# original code https://github.com/gastruc/OmniSat/blob/main/src/data/Pastis.py +### + +import json +import os +from datetime import datetime + +import geopandas as gpd +import numpy as np +import pandas as pd +import rasterio +import torch +from einops import rearrange +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from utils.registry import DATASET_REGISTRY + + +def prepare_dates(date_dict, reference_date): + """Date formating.""" + if type(date_dict) == str: + date_dict = json.loads(date_dict) + d = pd.DataFrame().from_dict(date_dict, orient="index") + d = d[0].apply( + lambda x: ( + datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:])) + - reference_date + ).days + ) + return torch.tensor(d.values) + + +def split_image(image_tensor, nb_split, id): + """ + Split the input image tensor into four quadrants based on the integer i. + To use if Pastis data does not fit in your GPU memory. + Returns the corresponding quadrant based on the value of i + """ + if nb_split == 1: + return image_tensor + i1 = id // nb_split + i2 = id % nb_split + height, width = image_tensor.shape[-2:] + half_height = height // nb_split + half_width = width // nb_split + if image_tensor.dim() == 4: + return image_tensor[ + :, + :, + i1 * half_height : (i1 + 1) * half_height, + i2 * half_width : (i2 + 1) * half_width, + ].float() + if image_tensor.dim() == 3: + return image_tensor[ + :, + i1 * half_height : (i1 + 1) * half_height, + i2 * half_width : (i2 + 1) * half_width, + ].float() + if image_tensor.dim() == 2: + return image_tensor[ + i1 * half_height : (i1 + 1) * half_height, + i2 * half_width : (i2 + 1) * half_width, + ].float() + + +@DATASET_REGISTRY.register() +class Pastis(Dataset): + def __init__( + self, + cfg: OmegaConf, + split: str, + is_train: bool = True, + ): + """ + Initializes the dataset. + Args: + path (str): path to the dataset + modalities (list): list of modalities to use + folds (list): list of folds to use + reference_date (str date): reference date for the data + nb_split (int): number of splits from one observation + num_classes (int): number of classes + """ + super(Pastis, self).__init__() + + if split == "train": + folds = [1, 2, 3] + elif split == "val": + folds = [4] + elif split == "test": + folds = [5] + + self.split = split + self.path = cfg["root_path"] + self.data_mean = cfg["data_mean"] + self.data_std = cfg["data_std"] + self.data_min = cfg["data_min"] + self.data_max = cfg["data_max"] + self.classes = cfg["classes"] + self.class_num = len(self.classes) + self.grid_size = cfg["multi_temporal"] + self.modalities = ["s2", "aerial", "s1-asc"] + self.nb_split = 1 + + reference_date = "2018-09-01" + self.reference_date = datetime(*map(int, reference_date.split("-"))) + + self.meta_patch = gpd.read_file(os.path.join(self.path, "metadata.geojson")) + + self.num_classes = 20 + + if folds is not None: + self.meta_patch = pd.concat( + [self.meta_patch[self.meta_patch["Fold"] == f] for f in folds] + ) + + def __getitem__(self, i): + """ + Returns an item from the dataset. + Args: + i (int): index of the item + Returns: + dict: dictionary with keys "label", "name" and the other corresponding to the modalities used + """ + line = self.meta_patch.iloc[i // (self.nb_split * self.nb_split)] + name = line["ID_PATCH"] + part = i % (self.nb_split * self.nb_split) + label = torch.from_numpy( + np.load( + os.path.join(self.path, "ANNOTATIONS/TARGET_" + str(name) + ".npy") + )[0].astype(np.int32) + ) + # label = torch.unique(split_image(label, self.nb_split, part)).long() + # label = torch.sum( + # torch.nn.functional.one_hot(label, num_classes=self.num_classes), dim=0 + # ) + # label = label[1:-1] # remove Background and Void classes + output = {"label": label, "name": name} + + for modality in self.modalities: + if modality == "aerial": + with rasterio.open( + os.path.join( + self.path, + "DATA_SPOT/PASTIS_SPOT6_RVB_1M00_2019/SPOT6_RVB_1M00_2019_" + + str(name) + + ".tif", + ) + ) as f: + output["aerial"] = split_image( + torch.FloatTensor(f.read()), self.nb_split, part + ) + elif modality == "s1-median": + modality_name = "s1a" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + out, _ = torch.median(images, dim=0) + output[modality] = out + elif modality == "s2-median": + modality_name = "s2" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + out, _ = torch.median(images, dim=0) + output[modality] = out + elif modality == "s1-4season-median": + modality_name = "s1a" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + dates = prepare_dates( + line["-".join(["dates", modality_name.upper()])], + self.reference_date, + ) + l = [] + for i in range(4): + mask = (dates >= 92 * i) & (dates < 92 * (i + 1)) + if sum(mask) > 0: + r, _ = torch.median(images[mask], dim=0) + l.append(r) + else: + l.append( + torch.zeros( + (images.shape[1], images.shape[-2], images.shape[-1]) + ) + ) + output[modality] = torch.cat(l) + elif modality == "s2-4season-median": + modality_name = "s2" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + dates = prepare_dates( + line["-".join(["dates", modality_name.upper()])], + self.reference_date, + ) + l = [] + for i in range(4): + mask = (dates >= 92 * i) & (dates < 92 * (i + 1)) + if sum(mask) > 0: + r, _ = torch.median(images[mask], dim=0) + l.append(r) + else: + l.append( + torch.zeros( + (images.shape[1], images.shape[-2], images.shape[-1]) + ) + ) + output[modality] = torch.cat(l) + else: + if len(modality) > 3: + modality_name = modality[:2] + modality[3] + output[modality] = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ) + output["_".join([modality, "dates"])] = prepare_dates( + line["-".join(["dates", modality_name.upper()])], + self.reference_date, + ) + else: + output[modality] = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality.upper()), + "{}_{}.npy".format(modality.upper(), name), + ) + ) + ), + self.nb_split, + part, + ) + output["_".join([modality, "dates"])] = prepare_dates( + line["-".join(["dates", modality.upper()])], self.reference_date + ) + N = len(output[modality]) + if N > 50: + random_indices = torch.randperm(N)[:50] + output[modality] = output[modality][random_indices] + output["_".join([modality, "dates"])] = output[ + "_".join([modality, "dates"]) + ][random_indices] + + optical_ts = rearrange(output["s2"], "t c h w -> c t h w") + sar_ts = rearrange(output["s1-asc"], "t c h w -> c t h w") + + if self.grid_size == 1: + # we only take the last frame + optical_ts = optical_ts[:, -1] + sar_ts = sar_ts[:, -1] + else: + # select evenly spaced samples + optical_indexes = torch.linspace( + 0, optical_ts.shape[1] - 1, self.grid_size, dtype=torch.long + ) + sar_indexes = torch.linspace( + 0, sar_ts.shape[1] - 1, self.grid_size, dtype=torch.long + ) + + optical_ts = optical_ts[:, optical_indexes] + sar_ts = sar_ts[:, sar_indexes] + + return { + "image": { + "optical": optical_ts.to(torch.float32), + "sar": sar_ts.to(torch.float32), + }, + "target": output["label"], + "metadata": {}, + } + + def __len__(self) -> int: + return len(self.meta_patch) * self.nb_split * self.nb_split + + @staticmethod + def get_splits(dataset_config): + dataset_train = Pastis(cfg=dataset_config, split="train", is_train=True) + dataset_val = Pastis(cfg=dataset_config, split="val", is_train=False) + dataset_test = Pastis(cfg=dataset_config, split="test", is_train=False) + return dataset_train, dataset_val, dataset_test + + @staticmethod + def download(dataset_config: dict, silent=False): + pass + + +if __name__ == "__main__": + class_prob = { + "Background": 0.0, + "Meadow": 31292, + "Soft Winter Wheat": 8206, + "Corn": 13123, + "Winter Barley": 2766, + "Winter Rapeseed": 1769, + "Spring Barley": 908, + "Sunflower": 1355, + "Grapevine": 10640, + "Beet": 871, + "Winter Triticale": 1208, + "Winter Durum Wheat": 1704, + "Fruits, Vegetables, Flowers": 2619, + "Potatoes": 551, + "Leguminous Fodder": 3174, + "Soybeans": 1212, + "Orchard": 2998, + "Mixed Cereal": 848, + "Sorghum": 707, + "Void Label": 35924, + } + + # get the class weights + class_weights = np.array([class_prob[key] for key in class_prob.keys()]) + class_weights = class_weights / class_weights.sum() + print("Class weights: ") + for i, key in enumerate(class_prob.keys()): + print(key, "->", class_weights[i]) + print("_" * 100) + + cfg = { + "root_path": "/share/DEEPLEARNING/datasets/PASTIS-HD", + "data_mean": None, + "data_std": None, + "classes": { + "0": "Background", + "1": "Meadow", + }, + "data_min": 0, + "data_max": 1, + } + + dataset = Pastis(cfg, "train", is_train=True) + train_dataset, val_dataset, test_dataset = Pastis.get_splits(cfg) diff --git a/engine/data_preprocessor.py b/engine/data_preprocessor.py index 7eacfc8d..e8867003 100644 --- a/engine/data_preprocessor.py +++ b/engine/data_preprocessor.py @@ -1,17 +1,13 @@ -import random - +import logging import math - -import torch -import torch.nn.functional as F -import torchvision.transforms as T +import random from typing import Callable import numpy as np -import logging - import omegaconf - +import torch +import torch.nn.functional as F +import torchvision.transforms as T from utils.registry import AUGMENTER_REGISTRY @@ -20,7 +16,7 @@ def get_collate_fn(cfg: omegaconf.DictConfig) -> Callable: modalities = cfg.encoder.input_bands.keys() def collate_fn( - batch: dict[dict[str, torch.Tensor]] + batch: dict[dict[str, torch.Tensor]], ) -> dict[dict[str, torch.Tensor]]: """Collate function for torch DataLoader args: @@ -153,9 +149,19 @@ def __init__(self, cfg, modality): self.input_bands = getattr(cfg.encoder.input_bands, modality, []) self.encoder_name = cfg.encoder.encoder_name - self.used_bands_mask = torch.tensor([b in self.input_bands for b in self.dataset_bands], dtype=torch.bool) - self.avail_bands_mask = torch.tensor([b in self.dataset_bands for b in self.input_bands], dtype=torch.bool) - self.avail_bands_indices = torch.tensor([self.dataset_bands.index(b) if b in self.dataset_bands else -1 for b in self.input_bands], dtype=torch.long) + self.used_bands_mask = torch.tensor( + [b in self.input_bands for b in self.dataset_bands], dtype=torch.bool + ) + self.avail_bands_mask = torch.tensor( + [b in self.dataset_bands for b in self.input_bands], dtype=torch.bool + ) + self.avail_bands_indices = torch.tensor( + [ + self.dataset_bands.index(b) if b in self.dataset_bands else -1 + for b in self.input_bands + ], + dtype=torch.long, + ) self.need_padded = self.avail_bands_mask.sum() < len(self.input_bands) @@ -185,10 +191,18 @@ def __init__(self, cfg, modality): ) def preprocess_band_statistics(self, data_mean, data_std, data_min, data_max): - data_mean = [data_mean[i] if i != -1 else 0.0 for i in self.avail_bands_indices.tolist()] - data_std = [data_std[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist()] - data_min = [data_min[i] if i != -1 else -1.0 for i in self.avail_bands_indices.tolist()] - data_max = [data_max[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist()] + data_mean = [ + data_mean[i] if i != -1 else 0.0 for i in self.avail_bands_indices.tolist() + ] + data_std = [ + data_std[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() + ] + data_min = [ + data_min[i] if i != -1 else -1.0 for i in self.avail_bands_indices.tolist() + ] + data_max = [ + data_max[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() + ] return data_mean, data_std, data_min, data_max def preprocess_single_timeframe(self, image): @@ -312,19 +326,19 @@ def __getitem__(self, index): # Ignore overlapping borders if h_index != 0: - tiled_data["target"][ - ..., 0:h_label_offset, : - ] = self.dataset_cfg.ignore_index + tiled_data["target"][..., 0:h_label_offset, :] = ( + self.dataset_cfg.ignore_index + ) if w_index != 0: tiled_data["target"][..., 0:w_label_offset] = self.dataset_cfg.ignore_index if h_index != self.tiles_per_dim - 1: - tiled_data["target"][ - ..., self.output_size - h_label_offset :, : - ] = self.dataset_cfg.ignore_index + tiled_data["target"][..., self.output_size - h_label_offset :, :] = ( + self.dataset_cfg.ignore_index + ) if w_index != self.tiles_per_dim - 1: - tiled_data["target"][ - ..., self.output_size - w_label_offset : - ] = self.dataset_cfg.ignore_index + tiled_data["target"][..., self.output_size - w_label_offset :] = ( + self.dataset_cfg.ignore_index + ) return tiled_data diff --git a/utils/compute_data_statistics.py b/utils/compute_data_statistics.py index d99c439c..887130a7 100644 --- a/utils/compute_data_statistics.py +++ b/utils/compute_data_statistics.py @@ -1,10 +1,38 @@ +from utils import registry import utils.registry import omegaconf import numpy as np import tqdm -import datasets import torch -import pprint + + +class RunningStats: + def __init__(self, stats_dim): + self.n = 0 + self.sum = torch.zeros(stats_dim) + self.sum_2 = torch.zeros(stats_dim) + + self.min = 10e10 * torch.ones(stats_dim) + self.max = -10e10 * torch.ones(stats_dim) + + def update(self, x, reduce_dim): + self.n += np.prod([x.shape[i] for i in reduce_dim]) + self.sum += torch.sum(x, reduce_dim) + self.sum_2 += torch.sum(x**2, reduce_dim) + + x_min = torch.amin(x, reduce_dim) + x_max = torch.amax(x, reduce_dim) + self.min = torch.min(self.min, x_min) + self.max = torch.max(self.max, x_max) + + def finalize(self): + return { + "mean": self.sum / self.n, + "std": torch.sqrt(self.sum_2 / self.n - (self.sum / self.n) ** 2), + "min": self.min, + "max": self.max, + } + configs = [ "configs/datasets/mados.yaml", @@ -17,28 +45,23 @@ dataset = utils.registry.DATASET_REGISTRY.get(cfg.dataset_name) dataset.download(cfg, silent=False) train_dataset, val_dataset, test_dataset = dataset.get_splits(cfg) + stats = {} + data = train_dataset.__getitem__(0) - min = {} - max = {} + # STATS initialization + stats = {} + for modality, img in data["image"].items(): + n_channels = img.shape[0] + stats[modality] = RunningStats(n_channels) + # STATS computation for data in tqdm.tqdm(train_dataset, desc=cfg.dataset_name): - for modality, img in data['image'].items(): - dims = [i for i in range(len(img.shape))] - dims.pop(-3) - img = torch.nan_to_num(img) - local_max = torch.amax(img, dim=dims) - local_min = torch.amin(img, dim=dims) - - if min.get(modality, None) is None: - print(modality, local_min.shape) - min[modality] = torch.full_like(local_min, 10e10) - max[modality] = torch.full_like(local_max, -10e10) - - min[modality] = torch.minimum(min[modality], local_min) - max[modality] = torch.maximum(max[modality], local_max) - - pprint.pp(cfg.dataset_name) - pprint.pp({ - "max": max, - "min": min - }) + for modality, img in data["image"].items(): + reduce_dim = list(range(1, img.ndim)) + stats[modality].update(img, reduce_dim) + + # STATS finalization + for modality, stat in stats.items(): + print(modality) + print(stat.finalize()) + print("_" * 100)