From 5a1fae9dce6d8131da9dd44499134ad0b5191bdf Mon Sep 17 00:00:00 2001
From: Yuru Jia <91590963+yurujaja@users.noreply.github.com>
Date: Fri, 22 Nov 2024 13:11:53 +0100
Subject: [PATCH] update vit_scratch config (#122)

---
 configs/encoder/vit_scratch.yaml | 21 +++++++++++++++++++++
 pangaea/encoders/vit_encoder.py  |  9 +++++++--
 2 files changed, 28 insertions(+), 2 deletions(-)
 create mode 100644 configs/encoder/vit_scratch.yaml

diff --git a/configs/encoder/vit_scratch.yaml b/configs/encoder/vit_scratch.yaml
new file mode 100644
index 0000000..de0c30a
--- /dev/null
+++ b/configs/encoder/vit_scratch.yaml
@@ -0,0 +1,21 @@
+_target_: pangaea.encoders.vit_encoder.VIT_Encoder
+encoder_weights: null
+download_url: null
+
+embed_dim: 768
+input_size: 224
+patch_size: 16
+depth: 12
+num_heads: 12
+mlp_ratio: 4
+
+input_bands: ${dataset.bands}
+
+
+output_layers:
+  - 3
+  - 5
+  - 7
+  - 11
+
+output_dim: 768
diff --git a/pangaea/encoders/vit_encoder.py b/pangaea/encoders/vit_encoder.py
index f22e4f6..9dfce57 100644
--- a/pangaea/encoders/vit_encoder.py
+++ b/pangaea/encoders/vit_encoder.py
@@ -56,8 +56,9 @@ def __init__(
         )
 
         self.patch_size = patch_size
+        self.in_chans = len(input_bands["optical"])
         self.patch_embed = PatchEmbed(
-            input_size, patch_size, in_chans=3, embed_dim=embed_dim
+            input_size, patch_size, in_chans=self.in_chans, embed_dim=embed_dim
         )
         num_patches = self.patch_embed.num_patches
 
@@ -113,6 +114,8 @@ def forward(self, images):
         return output
 
     def load_encoder_weights(self, logger: Logger) -> None:
+        if self.encoder_weights is None:
+            return
         pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
         k = pretrained_model.keys()
         pretrained_encoder = {}
@@ -167,7 +170,7 @@ def __init__(
         )
 
         self.patch_size = patch_size
-        self.in_channels = 3
+        self.in_channels = len(input_bands["optical"])
         self.patch_embed = PatchEmbed(
             input_size, patch_size, in_chans=self.in_channels, embed_dim=embed_dim
         )
@@ -233,6 +236,8 @@ def forward(self, images):
         return output
 
     def load_encoder_weights(self, logger: Logger) -> None:
+        if self.encoder_weights is None:
+            return
         pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
         k = pretrained_model.keys()
         pretrained_encoder = {}
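
Note (illustrative, not part of the patch): with encoder_weights set to null, the new
vit_scratch config builds an untrained ViT whose patch embedding takes its channel count
from the dataset's optical bands, and load_encoder_weights now returns early instead of
calling torch.load(None). The sketch below shows one way this could be exercised through
Hydra; the band names and the manual stand-in for the ${dataset.bands} interpolation are
assumptions for illustration, not code from this repository.

    # Minimal usage sketch, assuming Hydra-style instantiation of the new config.
    # The four band names and the override of ${dataset.bands} are hypothetical.
    import logging

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/encoder/vit_scratch.yaml")
    # ${dataset.bands} normally resolves against the dataset config; supply a
    # stand-in with four optical bands so the interpolation is not needed here.
    cfg.input_bands = {"optical": ["B02", "B03", "B04", "B08"]}

    encoder = instantiate(cfg)  # patch embedding built with in_chans=4 (one per band)
    encoder.load_encoder_weights(logging.getLogger(__name__))  # returns early: weights are null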