From b39b8f351cda874bee4cd3d4699a2aa7468c1772 Mon Sep 17 00:00:00 2001 From: Munehiro Kobayashi Date: Mon, 25 Nov 2024 10:58:29 +0900 Subject: [PATCH] modify type hint and value check --- segmentation_models_pytorch/decoders/deeplabv3/decoder.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 54661ef8..e20acf3f 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -70,7 +70,7 @@ def forward(self, *features): class DeepLabV3PlusDecoder(nn.Module): def __init__( self, - encoder_channels: Sequence[int, ...], + encoder_channels: Sequence[int], encoder_depth: Literal[3, 4, 5], out_channels: int, atrous_rates: Iterable[int], @@ -79,7 +79,11 @@ def __init__( aspp_dropout: float, ): super().__init__() - if output_stride not in {8, 16}: + if encoder_depth not in (3, 4, 5): + raise ValueError( + "Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth) + ) + if output_stride not in (8, 16): raise ValueError( "Output stride should be 8 or 16, got {}.".format(output_stride) )