Skip to content

Commit

Permalink
Merge branch 'main' of github.com:yurujaja/geofm-bench
Browse files Browse the repository at this point in the history
  • Loading branch information
hfangcat committed Dec 4, 2024
2 parents 3d0fce7 + 5a1fae9 commit 73848fa
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
5 changes: 5 additions & 0 deletions configs/decoder/reg_unet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Decoder config: UNet head with a single output channel (num_classes: 1),
# i.e. a regression-style head — consistent with the filename reg_unet.yaml.
# NOTE(review): `_target_` is presumably resolved by the framework (Hydra-style
# instantiation) into pangaea.decoders.unet.UNet — confirm against the runner.
_target_: pangaea.decoders.unet.UNet

# NOTE(review): null here suggests the encoder is injected at build time rather
# than configured inline — verify against the instantiation code.
encoder: null
num_classes: 1  # single output channel
finetune: True # ${finetune}
21 changes: 21 additions & 0 deletions configs/encoder/vit_scratch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Encoder config: ViT trained from scratch — no pretrained checkpoint.
_target_: pangaea.encoders.vit_encoder.VIT_Encoder
# encoder_weights: null means load_encoder_weights() returns early and loads
# nothing (see the guard added in vit_encoder.py in this same commit).
encoder_weights: null
download_url: null

# Transformer hyperparameters. These values (768 / depth 12 / 12 heads /
# mlp_ratio 4) match the common ViT-Base configuration — NOTE(review): confirm
# that is the intent rather than coincidence.
embed_dim: 768
input_size: 224
patch_size: 16
depth: 12
num_heads: 12
mlp_ratio: 4

# Interpolated from the active dataset config; vit_encoder.py sizes the patch
# embedding as len(input_bands["optical"]) input channels.
input_bands: ${dataset.bands}


# Indices of the transformer blocks whose outputs are exposed to the decoder.
output_layers:
- 3
- 5
- 7
- 11

output_dim: 768
9 changes: 7 additions & 2 deletions pangaea/encoders/vit_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def __init__(
)

self.patch_size = patch_size
self.in_chans = len(input_bands["optical"])
self.patch_embed = PatchEmbed(
input_size, patch_size, in_chans=3, embed_dim=embed_dim
input_size, patch_size, in_chans=self.in_chans, embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches

Expand Down Expand Up @@ -113,6 +114,8 @@ def forward(self, images):
return output

def load_encoder_weights(self, logger: Logger) -> None:
if self.encoder_weights is None:
return
pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
k = pretrained_model.keys()
pretrained_encoder = {}
Expand Down Expand Up @@ -167,7 +170,7 @@ def __init__(
)

self.patch_size = patch_size
self.in_channels = 3
self.in_channels = len(input_bands["optical"])
self.patch_embed = PatchEmbed(
input_size, patch_size, in_chans=self.in_channels, embed_dim=embed_dim
)
Expand Down Expand Up @@ -233,6 +236,8 @@ def forward(self, images):
return output

def load_encoder_weights(self, logger: Logger) -> None:
if self.encoder_weights is None:
return
pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
k = pretrained_model.keys()
pretrained_encoder = {}
Expand Down

0 comments on commit 73848fa

Please sign in to comment.