Skip to content

Commit

Permalink
fixed evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
LeungTsang committed Oct 13, 2024
1 parent 8c3a376 commit fc50576
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 59 deletions.
66 changes: 9 additions & 57 deletions pangaea/datasets/sen1floods11.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,8 @@ class Sen1Floods11(RawGeoFMDataset):

def __init__(
self,
split: str,
dataset_name: str,
multi_modal: bool,
multi_temporal: int,
root_path: str,
classes: list,
num_classes: int,
ignore_index: int,
img_size: int,
bands: dict[str, list[str]],
distribution: list[int],
data_mean: dict[str, list[str]],
data_std: dict[str, list[str]],
data_min: dict[str, list[str]],
data_max: dict[str, list[str]],
download_url: str,
auto_download: bool,
gcs_bucket: str,
gcs_bucket: str,
**kwargs
):
"""Initialize the Sen1Floods11 dataset.
Link: https://github.com/cloudtostreet/Sen1Floods11
Expand Down Expand Up @@ -67,45 +51,12 @@ def __init__(

self.gcs_bucket = gcs_bucket

super(Sen1Floods11, self).__init__(
split=split,
dataset_name=dataset_name,
multi_modal=multi_modal,
multi_temporal=multi_temporal,
root_path=root_path,
classes=classes,
num_classes=num_classes,
ignore_index=ignore_index,
img_size=img_size,
bands=bands,
distribution=distribution,
data_mean=data_mean,
data_std=data_std,
data_min=data_min,
data_max=data_max,
download_url=download_url,
auto_download=auto_download,
)

self.root_path = root_path
self.classes = classes
self.split = split

self.data_mean = data_mean
self.data_std = data_std
self.data_min = data_min
self.data_max = data_max
self.classes = classes
self.img_size = img_size
self.distribution = distribution
self.num_classes = num_classes
self.ignore_index = ignore_index
self.download_url = download_url
self.auto_download = auto_download
super(Sen1Floods11, self).__init__(**kwargs)


self.split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'}

split_file = os.path.join(self.root_path, "v1.1", f"splits/flood_handlabeled/flood_{self.split_mapping[split]}_data.csv")
split_file = os.path.join(self.root_path, "v1.1", f"splits/flood_handlabeled/flood_{self.split_mapping[self.split]}_data.csv")
metadata_file = os.path.join(self.root_path, "v1.1", "Sen1Floods11_Metadata.geojson")
data_root = os.path.join(self.root_path, "v1.1", "data/flood_events/HandLabeled/")

Expand Down Expand Up @@ -153,18 +104,19 @@ def __getitem__(self, index):

s2_image = torch.from_numpy(s2_image).float()
s1_image = torch.from_numpy(s1_image).float()
target = torch.from_numpy(target)
target = torch.from_numpy(target).long()

output = {
'image': {
'optical': s2_image,
'sar' : s1_image,
'optical': s2_image.unsqueeze(1),
'sar' : s1_image.unsqueeze(1),
},
'target': target,
'metadata': {
"timestamp": timestamp,
}
}

return output

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions pangaea/engine/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def sliding_inference(model, img, input_size, output_shape=None, stride=None, ma
pred_ = model.forward(img_, output_shape=(input_size, input_size))
pred.append(pred_)
pred = torch.cat(pred, dim=0)
pred = pred.view(b, num_crops_per_img, -1, input_size, input_size)
pred = pred.view(num_crops_per_img, b, -1, input_size, input_size).transpose(0, 1)

merged_pred = torch.zeros((b, pred.shape[2], height, width), device=pred.device)
pred_count = torch.zeros((b, height, width), dtype=torch.long, device=pred.device)
Expand Down Expand Up @@ -401,4 +401,4 @@ def log_metrics(self, metrics):
self.logger.info(header + mse + rmse)

if self.use_wandb:
wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
1 change: 1 addition & 0 deletions pretrained_models

0 comments on commit fc50576

Please sign in to comment.