diff --git a/pangaea/encoders/ssl4eo_mae_encoder.py b/pangaea/encoders/ssl4eo_mae_encoder.py index 68bbc1b..162158a 100644 --- a/pangaea/encoders/ssl4eo_mae_encoder.py +++ b/pangaea/encoders/ssl4eo_mae_encoder.py @@ -158,7 +158,7 @@ def forward(self, image): def load_encoder_weights(self, logger: Logger) -> None: checkpoint = torch.load(self.encoder_weights, map_location="cpu") - pretrained_model = checkpoint["model"] + pretrained_model = checkpoint["state_dict"] k = pretrained_model.keys() pretrained_encoder = {}