From a938f8c8ea937d8de2abdec48716e4478ef49865 Mon Sep 17 00:00:00 2001 From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:27:52 -0800 Subject: [PATCH] [brief] Updates the unit tests to ensure the MAE metrics are tested properly. [detailed] --- test/test_metrics.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/test/test_metrics.py b/test/test_metrics.py index bada9dd..8f7842f 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -62,15 +62,21 @@ def test_mAP(self) -> None: mAP = metrics.CalculateMAP() self.check_almost_equal(mAP(targs, preds), functional.calculate_mAP(targs, preds)) - def test_MAE(self) -> None: + def check_mae(self, scale: float) -> None: rng.seed_rngs() - pred = torch.rand((32, 32)) - gt = torch.rand((32, 32)) - mae = metrics.CalculateMAE() + pred = torch.randn((1, 3, 32, 32)) * scale + gt = torch.randn((1, 3, 32, 32)) * scale + mae = metrics.CalculateMAE(scale) - self.check_almost_equal(mae(pred, gt), functional.calculate_mae_torch(pred, gt)) + self.check_almost_equal( + mae(pred, gt), functional.calculate_mae_torch(pred, gt, scale) + ) pred = pred.numpy() # type: ignore[assignment] gt = gt.numpy() # type: ignore[assignment] - self.check_almost_equal(mae(pred, gt), functional.calculate_mae(pred, gt)) # type: ignore[arg-type] + self.check_almost_equal(mae(pred, gt), functional.calculate_mae(pred, gt, scale)) # type: ignore[arg-type] + + def test_MAE(self) -> None: + self.check_mae(1.0) + self.check_mae(255.0)