From 792e24c053f46bd68218d00b7d7c6db457fb1a1b Mon Sep 17 00:00:00 2001 From: Ritu Yadav <40523539+RituYadav92@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:51:30 +0100 Subject: [PATCH] Reg metric fix (#118) * Update evaluator.py - Fixed the dimension error - Fixed the metric calculation * Update evaluator.py * updated sar band names --- configs/dataset/biomassters.yaml | 8 ++++---- pangaea/engine/evaluator.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/configs/dataset/biomassters.yaml b/configs/dataset/biomassters.yaml index c4e5561c..d98014a2 100644 --- a/configs/dataset/biomassters.yaml +++ b/configs/dataset/biomassters.yaml @@ -4,7 +4,7 @@ root_path: ./data/Biomassters download_url: auto_download: False img_size: 256 -temp: 6 #6 (select month to use if single temporal (multi_temp : 1)) +temp: 6 #6 (select month to use if single temporal (multi_temporal : 1)) multi_temporal: 12 multi_modal: True @@ -35,12 +35,12 @@ bands: - B12 - CLP sar: - - ASC_VV - - ASC_VH + - VV #set band name to match the input band name of the model e.g. VV for CROMA, ASC_VV for DOFA + - VH #set band name to match the input band name of the model e.g. VH for CROMA, ASC_VH for DOFA - DSC_VV - DSC_VH -# TODO: fix the normalization +# TODO: add mean and std normalization values data_mean: optical: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] sar: [0, 0, 0, 0] diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index 0455c4b7..d3c770cb 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -372,16 +372,16 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): if self.inference_mode == "sliding": input_size = model.module.encoder.input_size logits = self.sliding_inference(model, image, input_size, output_shape=target.shape[-2:], - max_batch=self.sliding_inference_batch) + max_batch=self.sliding_inference_batch).squeeze(dim=1) elif self.inference_mode == "whole": logits = model(image, output_shape=target.shape[-2:]).squeeze(dim=1) else: raise NotImplementedError((f"Inference mode {self.inference_mode} is not implemented.")) - mse += F.mse_loss(logits, target, reduction='sum') + mse += F.mse_loss(logits, target) torch.distributed.all_reduce(mse, op=torch.distributed.ReduceOp.SUM) - mse = mse / len(self.val_loader.dataset) + mse = mse / len(self.val_loader) metrics = {"MSE": mse.item(), "RMSE": torch.sqrt(mse).item()} self.log_metrics(metrics)