Skip to content

Commit

Permalink
added denormalization function in utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Nov 12, 2024
1 parent 84cc762 commit f1fb556
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion geo_deep_learning/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,19 @@ def standardization(input_tensor, mean, std):
input_tensor = (input_tensor - mean) / std
input_tensor = input_tensor.reshape(input_shape)
return input_tensor


def denormalization(image, mean, std, data_type_max):
if mean is not None and std is not None:
if not torch.is_tensor(mean):
mean = torch.tensor(mean, device=image.device)
if not torch.is_tensor(std):
std = torch.tensor(std, device=image.device)

mean = mean.reshape(-1, 1, 1)
std = std.reshape(-1, 1, 1)

image = image * std + mean

image = (image * data_type_max).clamp(0, data_type_max).to(torch.uint8)

return image

0 comments on commit f1fb556

Please sign in to comment.