diff --git a/pangaea/decoders/upernet.py b/pangaea/decoders/upernet.py index df74fef..516a68d 100644 --- a/pangaea/decoders/upernet.py +++ b/pangaea/decoders/upernet.py @@ -302,7 +302,7 @@ def __init__( def get_decoder_in_channels( self, multi_temporal_strategy: str | None, encoder: Encoder ) -> list[int]: - if multi_temporal_strategy == "ltae": + if multi_temporal_strategy == "ltae" and encoder.multi_temporal_output: # if the encoder output channels vary we must use an adaptor before the LTAE ltae_in_channels = max(encoder.output_dim) if ltae_in_channels != min(encoder.output_dim): @@ -728,7 +728,7 @@ def __init__( def get_decoder_in_channels( self, multi_temporal_strategy: str | None, encoder: Encoder ) -> list[int]: - if multi_temporal_strategy == "ltae": + if multi_temporal_strategy == "ltae" and encoder.multi_temporal_output: # if the encoder output channels vary we must use an adaptor before the LTAE ltae_in_channels = max(encoder.output_dim) if ltae_in_channels != min(encoder.output_dim):