diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 54661ef8..e20acf3f 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -70,7 +70,7 @@ def forward(self, *features): class DeepLabV3PlusDecoder(nn.Module): def __init__( self, - encoder_channels: Sequence[int, ...], + encoder_channels: Sequence[int], encoder_depth: Literal[3, 4, 5], out_channels: int, atrous_rates: Iterable[int], @@ -79,7 +79,11 @@ def __init__( aspp_dropout: float, ): super().__init__() - if output_stride not in {8, 16}: + if encoder_depth not in (3, 4, 5): + raise ValueError( + "Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth) + ) + if output_stride not in (8, 16): raise ValueError( "Output stride should be 8 or 16, got {}.".format(output_stride) )