diff --git a/configs/encoder/vit_scratch.yaml b/configs/encoder/vit_scratch.yaml
new file mode 100644
index 0000000..de0c30a
--- /dev/null
+++ b/configs/encoder/vit_scratch.yaml
@@ -0,0 +1,21 @@
+_target_: pangaea.encoders.vit_encoder.VIT_Encoder
+encoder_weights: null
+download_url: null
+
+embed_dim: 768
+input_size: 224
+patch_size: 16
+depth: 12
+num_heads: 12
+mlp_ratio: 4
+
+input_bands: ${dataset.bands}
+
+
+output_layers:
+  - 3
+  - 5
+  - 7
+  - 11
+
+output_dim: 768
diff --git a/pangaea/encoders/vit_encoder.py b/pangaea/encoders/vit_encoder.py
index f22e4f6..9dfce57 100644
--- a/pangaea/encoders/vit_encoder.py
+++ b/pangaea/encoders/vit_encoder.py
@@ -56,8 +56,9 @@ def __init__(
         )
 
         self.patch_size = patch_size
+        self.in_chans = len(input_bands["optical"])
         self.patch_embed = PatchEmbed(
-            input_size, patch_size, in_chans=3, embed_dim=embed_dim
+            input_size, patch_size, in_chans=self.in_chans, embed_dim=embed_dim
         )
         num_patches = self.patch_embed.num_patches
 
@@ -113,6 +114,8 @@ def forward(self, images):
         return output
 
     def load_encoder_weights(self, logger: Logger) -> None:
+        if self.encoder_weights is None:
+            return
         pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
         k = pretrained_model.keys()
         pretrained_encoder = {}
@@ -167,7 +170,7 @@ def __init__(
         )
 
         self.patch_size = patch_size
-        self.in_channels = 3
+        self.in_channels = len(input_bands["optical"])
         self.patch_embed = PatchEmbed(
             input_size, patch_size, in_chans=self.in_channels, embed_dim=embed_dim
         )
@@ -233,6 +236,8 @@ def forward(self, images):
         return output
 
     def load_encoder_weights(self, logger: Logger) -> None:
+        if self.encoder_weights is None:
+            return
         pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
         k = pretrained_model.keys()
         pretrained_encoder = {}
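
For reference, a minimal sketch of what the two code changes amount to. The band names and the timm-style `PatchEmbed` import below are illustrative assumptions, not taken from this repo; the point is only that the patch embedding's channel count now follows the dataset's optical band list, and that weight loading is skipped when `encoder_weights` is null.

```python
import torch
from timm.layers import PatchEmbed  # assumption: the repo's PatchEmbed behaves like timm's

# Hypothetical 4-band optical dataset config (stand-in for ${dataset.bands}).
input_bands = {"optical": ["B02", "B03", "B04", "B08"]}

# The encoder now derives in_chans from the band list instead of hard-coding 3.
in_chans = len(input_bands["optical"])  # 4
patch_embed = PatchEmbed(224, 16, in_chans=in_chans, embed_dim=768)

x = torch.randn(1, in_chans, 224, 224)  # batch of one 4-channel image
print(patch_embed(x).shape)             # torch.Size([1, 196, 768]); (224/16)**2 = 196 patches

# With encoder_weights: null (training from scratch), load_encoder_weights() now
# returns early instead of calling torch.load(None, ...).
encoder_weights = None
if encoder_weights is None:
    print("no checkpoint configured; skipping weight loading")
```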