Skip to content

Commit

Permalink
Moves spectralgpt compatibility ensurance to config-level preprocessi…
Browse files Browse the repository at this point in the history
…ng in utils/configs.py
  • Loading branch information
SebastianGer committed Sep 17, 2024
1 parent 8d172f8 commit 110f46c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
5 changes: 2 additions & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import utils.losses
from utils.utils import fix_seed, get_generator, seed_worker, prepare_input
from utils.logger import init_logger
from utils.configs import load_configs
from utils.configs import load_configs, ensure_compatible_configs
from utils.registry import (
ENCODER_REGISTRY,
SEGMENTOR_REGISTRY,
Expand Down Expand Up @@ -120,6 +120,7 @@

def main():
cfg = load_configs(parser)
cfg = ensure_compatible_configs(cfg)

# fix all random seeds
fix_seed(cfg.seed)
Expand Down Expand Up @@ -204,8 +205,6 @@ def main():

# prepare the foundation model
download_model(cfg.encoder)
if cfg.encoder.encoder_name == "SpectralGPT_Encoder" and cfg.segmentor.segmentor_name == "UPerNetCD":
cfg.encoder.multi_temporal=1
encoder = ENCODER_REGISTRY.get(cfg.encoder.encoder_name)(
cfg.encoder, **cfg.encoder.encoder_model_args
)
Expand Down
8 changes: 8 additions & 0 deletions utils/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,11 @@ def _nest(
nested[key] = val

return dict(nested) if nested else None


def ensure_compatible_configs(cfg:OmegaConf) -> OmegaConf:
# SpectralGPT_Encoder can handle multi-temporal input, but in change detection, we encode each time step separately,
# to then compute the change from the different feature representations.
if cfg.encoder.encoder_name == "SpectralGPT_Encoder" and cfg.segmentor.task_name == "change-detection":
cfg.encoder.multi_temporal=1
return cfg

0 comments on commit 110f46c

Please sign in to comment.