diff --git a/geo_deep_learning/tools/utils.py b/geo_deep_learning/tools/utils.py index 6eb687e4..be6169e9 100644 --- a/geo_deep_learning/tools/utils.py +++ b/geo_deep_learning/tools/utils.py @@ -13,4 +13,19 @@ def standardization(input_tensor, mean, std): input_tensor = (input_tensor - mean) / std input_tensor = input_tensor.reshape(input_shape) return input_tensor - \ No newline at end of file + +def denormalization(image, mean, std, data_type_max): + if mean is not None and std is not None: + if not torch.is_tensor(mean): + mean = torch.tensor(mean, device=image.device) + if not torch.is_tensor(std): + std = torch.tensor(std, device=image.device) + + mean = mean.reshape(-1, 1, 1) + std = std.reshape(-1, 1, 1) + + image = image * std + mean + + image = (image * data_type_max).clamp(0, data_type_max).to(torch.uint8) + + return image \ No newline at end of file