From 8f7ac081107aae683b2cf5eb586acd25b96f660b Mon Sep 17 00:00:00 2001 From: KerekesDavid Date: Mon, 28 Oct 2024 12:51:22 +0100 Subject: [PATCH 1/3] Fix evaluator on multi-gpu calls --- pangaea/engine/evaluator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index 5e6ec7e..0455c4b 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -290,7 +290,7 @@ def format_metric(name, values, mean_value): self.logger.info(recall_str) self.logger.info(macc_str) - if self.use_wandb: + if self.use_wandb and self.rank == 0: wandb.log( { f"{self.split}_mIoU": metrics["mIoU"], @@ -400,5 +400,5 @@ def log_metrics(self, metrics): rmse = "-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE']) self.logger.info(header + mse + rmse) - if self.use_wandb: + if self.use_wandb and self.rank == 0: wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]}) From 8497d6e9d44bfaf541fc0fa1af9499ffc74b99a3 Mon Sep 17 00:00:00 2001 From: KerekesDavid Date: Mon, 28 Oct 2024 14:35:57 +0100 Subject: [PATCH 2/3] Store best checkpoint on disk instead of GPU memory. --- pangaea/engine/trainer.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py index 5aebcda..2291f0a 100644 --- a/pangaea/engine/trainer.py +++ b/pangaea/engine/trainer.py @@ -76,7 +76,6 @@ def __init__( for name in ["loss", "data_time", "batch_time", "eval_time"] } self.training_metrics = {} - self.best_ckpt = None self.best_metric_comp = operator.gt self.num_classes = self.train_loader.dataset.num_classes @@ -105,7 +104,7 @@ def train(self) -> None: if epoch % self.eval_interval == 0: metrics, used_time = self.evaluator(self.model, f"epoch {epoch}") self.training_stats["eval_time"].update(used_time) - self.set_best_checkpoint(metrics, epoch) + self.save_best_checkpoint(metrics, epoch) self.logger.info("============ Starting epoch %i ... ============" % epoch) # set sampler @@ -117,17 +116,11 @@ def train(self) -> None: metrics, used_time = self.evaluator(self.model, "final model") self.training_stats["eval_time"].update(used_time) - self.set_best_checkpoint(metrics, self.n_epochs) + self.save_best_checkpoint(metrics, self.n_epochs) # save last model self.save_model(self.n_epochs, is_final=True) - # save best model - if self.best_ckpt: - self.save_model( - self.best_ckpt["epoch"], is_best=True, checkpoint=self.best_ckpt - ) - def train_one_epoch(self, epoch: int) -> None: """Train model for one epoch. @@ -186,7 +179,7 @@ def train_one_epoch(self, epoch: int) -> None: end_time = time.time() def get_checkpoint(self, epoch: int) -> dict[str, dict | int]: - """Create a checkpoint dictionary. + """Create a checkpoint dictionary, containing references to the pytorch tensors. Args: epoch (int): number of the epoch. 
@@ -201,7 +194,7 @@ def get_checkpoint(self, epoch: int) -> dict[str, dict | int]: "scaler": self.scaler.state_dict(), "epoch": epoch, } - return copy.deepcopy(checkpoint) + return checkpoint def save_model( self, @@ -222,8 +215,8 @@ def save_model( torch.distributed.barrier() return checkpoint = self.get_checkpoint(epoch) if checkpoint is None else checkpoint - suffix = "_best" if is_best else "_final" if is_final else "" - checkpoint_path = os.path.join(self.exp_dir, f"checkpoint_{epoch}{suffix}.pth") + suffix = "_best" if is_best else f"{epoch}_final" if is_final else f"{epoch}" + checkpoint_path = os.path.join(self.exp_dir, f"checkpoint_{suffix}.pth") torch.save(checkpoint, checkpoint_path) self.logger.info( f"Epoch {epoch} | Training checkpoint saved at {checkpoint_path}" @@ -267,7 +260,7 @@ def compute_loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tens """ raise NotImplementedError - def set_best_checkpoint( + def save_best_checkpoint( self, eval_metrics: dict[float, list[float]], epoch: int ) -> None: """Update the best checkpoint according to the evaluation metrics. @@ -281,7 +274,10 @@ def set_best_checkpoint( curr_metric = curr_metric[0] if self.num_classes == 1 else np.mean(curr_metric) if self.best_metric_comp(curr_metric, self.best_metric): self.best_metric = curr_metric - self.best_ckpt = self.get_checkpoint(epoch) + best_ckpt = self.get_checkpoint(epoch) + self.save_model( + epoch, is_best=True, checkpoint=best_ckpt + ) @torch.no_grad() def compute_logging_metrics( From ecfefa119f3e61b4841955afe302bae961b50c79 Mon Sep 17 00:00:00 2001 From: Georges Le Bellier Date: Thu, 31 Oct 2024 13:45:45 +0100 Subject: [PATCH 3/3] 110 add multi-temporal unet and ViT encoder (#114) * Add multi temporal unet encoder and corresponding config * Add multi temporal ViT --- configs/encoder/unet_encoder_mi.yaml | 15 ++ configs/encoder/vit_mi.yaml | 26 +++ pangaea/encoders/unet_encoder.py | 61 +++++++ pangaea/encoders/vit_encoder.py | 233 +++++++++++++++++++++------ 4 files changed, 288 insertions(+), 47 deletions(-) create mode 100644 configs/encoder/unet_encoder_mi.yaml create mode 100644 configs/encoder/vit_mi.yaml diff --git a/configs/encoder/unet_encoder_mi.yaml b/configs/encoder/unet_encoder_mi.yaml new file mode 100644 index 0000000..4a7ab92 --- /dev/null +++ b/configs/encoder/unet_encoder_mi.yaml @@ -0,0 +1,15 @@ +_target_: pangaea.encoders.unet_encoder.UNetMT +encoder_weights: null +download_url: null +input_size: ${dataset.img_size} +multi_temporal: ${dataset.multi_temporal} +topology: [64, 128, 256, 512,] + +input_bands: ${dataset.bands} + +output_dim: + - 64 + - 128 + - 256 + - 512 + diff --git a/configs/encoder/vit_mi.yaml b/configs/encoder/vit_mi.yaml new file mode 100644 index 0000000..e5b56d9 --- /dev/null +++ b/configs/encoder/vit_mi.yaml @@ -0,0 +1,26 @@ +_target_: pangaea.encoders.vit_encoder.VIT_EncoderMT +encoder_weights: ./pretrained_models/jx_vit_base_p16_224-80ecf9dd.pt +download_url: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth + +multi_temporal: ${dataset.multi_temporal} +embed_dim: 768 +input_size: 224 +patch_size: 16 +depth: 12 +num_heads: 12 +mlp_ratio: 4 + +input_bands: + optical: + - B4 + - B3 + - B2 + + +output_layers: + - 3 + - 5 + - 7 + - 11 + +output_dim: 768 diff --git a/pangaea/encoders/unet_encoder.py b/pangaea/encoders/unet_encoder.py index 8281db6..28f7f7b 100644 --- a/pangaea/encoders/unet_encoder.py +++ b/pangaea/encoders/unet_encoder.py @@ -60,6 +60,67 @@ def 
load_encoder_weights(self, logger: Logger) -> None:
         pass
 
 
+class UNetMT(Encoder):
+    """
+    Multi Temporal UNet Encoder for Supervised Baseline, to be trained from scratch.
+    It supports multi-temporal inputs with optical bands; the time frames are merged into the channel dimension.
+
+    Args:
+        input_bands (dict[str, list[str]]): Band names, specifically expecting the 'optical' key with a list of bands.
+        input_size (int): Size of the input images (height and width).
+        topology (Sequence[int]): The number of feature channels at each stage of the U-Net encoder.
+        multi_temporal (int): Number of input time frames, merged into the channel dimension.
+    """
+
+    def __init__(
+        self,
+        input_bands: dict[str, list[str]],
+        input_size: int,
+        multi_temporal: int,
+        topology: Sequence[int],
+        output_dim: int | list[int],
+        download_url: str,
+        encoder_weights: str | None = None,
+    ):
+        super().__init__(
+            model_name="unet_encoder",
+            encoder_weights=encoder_weights,  # no pre-trained weights, train from scratch
+            input_bands=input_bands,
+            input_size=input_size,
+            embed_dim=0,
+            output_dim=output_dim,
+            output_layers=None,
+            multi_temporal=multi_temporal,
+            multi_temporal_output=False,
+            pyramid_output=True,
+            download_url=download_url,
+        )
+
+        self.in_channels = len(input_bands["optical"])  # number of optical bands
+        self.topology = topology
+
+        self.in_conv = InConv(self.in_channels, self.topology[0], DoubleConv)
+        self.encoder = UNet_Encoder(self.topology)
+
+        self.time_merging = DoubleConv(
+            in_ch=self.in_channels * self.multi_temporal, out_ch=self.in_channels
+        )
+
+    def forward(self, image):
+        x = image["optical"]
+        b, c, t, h, w = x.shape
+        # merge time and channels dimension
+        x = x.reshape(b, c * t, h, w)
+        x = self.time_merging(x)
+
+        feat = self.in_conv(x)
+        output = self.encoder(feat)
+        return output
+
+    def load_encoder_weights(self, logger: Logger) -> None:
+        pass
+
+
 class UNet_Encoder(nn.Module):
     """
     UNet Encoder class that defines the architecture of the encoder part of the UNet.
diff --git a/pangaea/encoders/vit_encoder.py b/pangaea/encoders/vit_encoder.py index 382d676..f22e4f6 100644 --- a/pangaea/encoders/vit_encoder.py +++ b/pangaea/encoders/vit_encoder.py @@ -14,63 +14,79 @@ import torch import torch.nn as nn +from timm.models.vision_transformer import Block, PatchEmbed + +from pangaea.encoders.unet_encoder import DoubleConv -import timm.models.vision_transformer -from timm.models.vision_transformer import PatchEmbed, Block from .base import Encoder class VIT_Encoder(Encoder): - """ Vision Transformer with support for global average pooling - """ - - def __init__(self, - encoder_weights, - input_size, - input_bands, - embed_dim, - output_layers, - output_dim, - download_url, - patch_size=16, - depth=12, - num_heads=12, - mlp_ratio=4, - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6)): - - Encoder.__init__(self, - model_name="vit_encoder", - encoder_weights=encoder_weights, - input_bands=input_bands, - input_size=input_size, - embed_dim=embed_dim, - output_layers=output_layers, - output_dim=output_dim, - multi_temporal=False, - multi_temporal_output=False, - pyramid_output=False, - download_url=download_url - ) + """Vision Transformer with support for global average pooling""" + + def __init__( + self, + encoder_weights, + input_size, + input_bands, + embed_dim, + output_layers, + output_dim, + download_url, + patch_size=16, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ): + Encoder.__init__( + self, + model_name="vit_encoder", + encoder_weights=encoder_weights, + input_bands=input_bands, + input_size=input_size, + embed_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, + multi_temporal=False, + multi_temporal_output=False, + pyramid_output=False, + download_url=download_url, + ) self.patch_size = patch_size - self.patch_embed = PatchEmbed(input_size, patch_size, in_chans=3, embed_dim=embed_dim) + self.patch_embed = PatchEmbed( + input_size, patch_size, in_chans=3, embed_dim=embed_dim + ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding - - self.blocks = nn.ModuleList([ - Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer) - for i in range(depth)]) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False + ) # fixed sin-cos embedding + + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) self.norm = norm_layer(embed_dim) - def forward(self, images): x = images["optical"].squeeze(2) x = self.patch_embed(x) - cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand( + x.shape[0], -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + self.pos_embed @@ -82,16 +98,139 @@ def forward(self, images): if i in self.output_layers: out = x[:, 1:] - out = out.transpose(1, 2).view( - x.shape[0], - -1, - self.input_size // self.patch_size, - self.input_size // self.patch_size, - ).contiguous() + out = ( + out.transpose(1, 2) + .view( + x.shape[0], + -1, + self.input_size // self.patch_size, + self.input_size // self.patch_size, + ) + .contiguous() + ) output.append(out) return output + def 
load_encoder_weights(self, logger: Logger) -> None:
+        pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
+        k = pretrained_model.keys()
+        pretrained_encoder = {}
+        incompatible_shape = {}
+        missing = {}
+        for name, param in self.named_parameters():
+            if name not in k:
+                missing[name] = param.shape
+            elif pretrained_model[name].shape != param.shape:
+                incompatible_shape[name] = (param.shape, pretrained_model[name].shape)
+                pretrained_model.pop(name)
+            else:
+                pretrained_encoder[name] = pretrained_model.pop(name)
+
+        self.load_state_dict(pretrained_encoder, strict=False)
+        self.parameters_warning(missing, incompatible_shape, logger)
+
+
+class VIT_EncoderMT(Encoder):
+    """Multi-temporal Vision Transformer: the input time frames are merged into the channel dimension before patch embedding."""
+
+    def __init__(
+        self,
+        encoder_weights,
+        input_size,
+        input_bands,
+        embed_dim,
+        multi_temporal,
+        output_layers,
+        output_dim,
+        download_url,
+        patch_size=16,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+    ):
+        Encoder.__init__(
+            self,
+            model_name="vit_encoder",
+            encoder_weights=encoder_weights,
+            input_bands=input_bands,
+            input_size=input_size,
+            embed_dim=embed_dim,
+            output_layers=output_layers,
+            output_dim=output_dim,
+            multi_temporal=multi_temporal,
+            multi_temporal_output=False,
+            pyramid_output=False,
+            download_url=download_url,
+        )
+
+        self.patch_size = patch_size
+        self.in_channels = 3
+        self.patch_embed = PatchEmbed(
+            input_size, patch_size, in_chans=self.in_channels, embed_dim=embed_dim
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
+        )  # fixed sin-cos embedding
+
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    embed_dim,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    norm_layer=norm_layer,
+                )
+                for i in range(depth)
+            ]
+        )
+        self.norm = norm_layer(embed_dim)
+
+        self.time_merging = DoubleConv(
+            in_ch=self.in_channels * self.multi_temporal, out_ch=self.in_channels
+        )
+
+    def forward(self, images):
+        x = images["optical"]
+        b, c, t, h, w = x.shape
+        # merge time and channels dimension
+        x = x.reshape(b, c * t, h, w)
+        x = self.time_merging(x)
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(
+            x.shape[0], -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = x + self.pos_embed
+
+        output = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i == len(self.blocks) - 1:
+                x = self.norm(x)
+
+            if i in self.output_layers:
+                out = x[:, 1:]
+                out = (
+                    out.transpose(1, 2)
+                    .view(
+                        x.shape[0],
+                        -1,
+                        self.input_size // self.patch_size,
+                        self.input_size // self.patch_size,
+                    )
+                    .contiguous()
+                )
+                output.append(out)
+
+        return output
 
     def load_encoder_weights(self, logger: Logger) -> None:
         pretrained_model = torch.load(self.encoder_weights, map_location="cpu")