From f1fb5562b9dc7b99c7c8ccfdb669f0488da09214 Mon Sep 17 00:00:00 2001 From: valhassan Date: Tue, 12 Nov 2024 13:46:13 -0500 Subject: [PATCH] added denormalization function in utils.py --- geo_deep_learning/tools/utils.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/geo_deep_learning/tools/utils.py b/geo_deep_learning/tools/utils.py index 6eb687e4..be6169e9 100644 --- a/geo_deep_learning/tools/utils.py +++ b/geo_deep_learning/tools/utils.py @@ -13,4 +13,19 @@ def standardization(input_tensor, mean, std): input_tensor = (input_tensor - mean) / std input_tensor = input_tensor.reshape(input_shape) return input_tensor - \ No newline at end of file + +def denormalization(image, mean, std, data_type_max): + if mean is not None and std is not None: + if not torch.is_tensor(mean): + mean = torch.tensor(mean, device=image.device) + if not torch.is_tensor(std): + std = torch.tensor(std, device=image.device) + + mean = mean.reshape(-1, 1, 1) + std = std.reshape(-1, 1, 1) + + image = image * std + mean + + image = (image * data_type_max).clamp(0, data_type_max).to(torch.uint8) + + return image \ No newline at end of file