diff --git a/pangaea/datasets/sen1floods11.py b/pangaea/datasets/sen1floods11.py
index 235a186..66baa49 100644
--- a/pangaea/datasets/sen1floods11.py
+++ b/pangaea/datasets/sen1floods11.py
@@ -14,24 +14,8 @@ class Sen1Floods11(RawGeoFMDataset):
 
     def __init__(
         self,
-        split: str,
-        dataset_name: str,
-        multi_modal: bool,
-        multi_temporal: int,
-        root_path: str,
-        classes: list,
-        num_classes: int,
-        ignore_index: int,
-        img_size: int,
-        bands: dict[str, list[str]],
-        distribution: list[int],
-        data_mean: dict[str, list[str]],
-        data_std: dict[str, list[str]],
-        data_min: dict[str, list[str]],
-        data_max: dict[str, list[str]],
-        download_url: str,
-        auto_download: bool,
-        gcs_bucket: str,
+        gcs_bucket: str,
+        **kwargs
     ):
         """Initialize the Sen1Floods11 dataset.
         Link: https://github.com/cloudtostreet/Sen1Floods11
@@ -67,45 +51,12 @@ def __init__(
         self.gcs_bucket = gcs_bucket
 
-        super(Sen1Floods11, self).__init__(
-            split=split,
-            dataset_name=dataset_name,
-            multi_modal=multi_modal,
-            multi_temporal=multi_temporal,
-            root_path=root_path,
-            classes=classes,
-            num_classes=num_classes,
-            ignore_index=ignore_index,
-            img_size=img_size,
-            bands=bands,
-            distribution=distribution,
-            data_mean=data_mean,
-            data_std=data_std,
-            data_min=data_min,
-            data_max=data_max,
-            download_url=download_url,
-            auto_download=auto_download,
-        )
-
-        self.root_path = root_path
-        self.classes = classes
-        self.split = split
-
-        self.data_mean = data_mean
-        self.data_std = data_std
-        self.data_min = data_min
-        self.data_max = data_max
-        self.classes = classes
-        self.img_size = img_size
-        self.distribution = distribution
-        self.num_classes = num_classes
-        self.ignore_index = ignore_index
-        self.download_url = download_url
-        self.auto_download = auto_download
+        super(Sen1Floods11, self).__init__(**kwargs)
+
         self.split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'}
 
-        split_file = os.path.join(self.root_path, "v1.1", f"splits/flood_handlabeled/flood_{self.split_mapping[split]}_data.csv")
+        split_file = os.path.join(self.root_path, "v1.1", f"splits/flood_handlabeled/flood_{self.split_mapping[self.split]}_data.csv")
         metadata_file = os.path.join(self.root_path, "v1.1", "Sen1Floods11_Metadata.geojson")
         data_root = os.path.join(self.root_path, "v1.1", "data/flood_events/HandLabeled/")
@@ -153,18 +104,19 @@ def __getitem__(self, index):
         s2_image = torch.from_numpy(s2_image).float()
         s1_image = torch.from_numpy(s1_image).float()
-        target = torch.from_numpy(target)
+        target = torch.from_numpy(target).long()
 
         output = {
             'image': {
-                'optical': s2_image,
-                'sar' : s1_image,
+                'optical': s2_image.unsqueeze(1),
+                'sar' : s1_image.unsqueeze(1),
             },
             'target': target,
             'metadata': {
                 "timestamp": timestamp,
             }
         }
+
         return output
 
     @staticmethod
diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py
index b257827..5e6ec7e 100644
--- a/pangaea/engine/evaluator.py
+++ b/pangaea/engine/evaluator.py
@@ -110,7 +110,7 @@ def sliding_inference(model, img, input_size, output_shape=None, stride=None, ma
             pred_ = model.forward(img_, output_shape=(input_size, input_size))
             pred.append(pred_)
         pred = torch.cat(pred, dim=0)
-        pred = pred.view(b, num_crops_per_img, -1, input_size, input_size)
+        pred = pred.view(num_crops_per_img, b, -1, input_size, input_size).transpose(0, 1)
 
         merged_pred = torch.zeros((b, pred.shape[2], height, width), device=pred.device)
         pred_count = torch.zeros((b, height, width), dtype=torch.long, device=pred.device)
@@ -401,4 +401,4 @@ def log_metrics(self, metrics):
         self.logger.info(header + mse + rmse)
         if self.use_wandb:
-            wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
\ No newline at end of file
+            wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
diff --git a/pretrained_models b/pretrained_models
new file mode 120000
index 0000000..4b70399
--- /dev/null
+++ b/pretrained_models
@@ -0,0 +1 @@
+../workspace/pangaea-bench/pretrained_models/
\ No newline at end of file