Skip to content

Commit

Permalink
modify type hint and value check
Browse files Browse the repository at this point in the history
  • Loading branch information
munehiro-k committed Nov 25, 2024
1 parent 2efd974 commit b39b8f3
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def forward(self, *features):
class DeepLabV3PlusDecoder(nn.Module):
def __init__(
self,
encoder_channels: Sequence[int, ...],
encoder_channels: Sequence[int],
encoder_depth: Literal[3, 4, 5],
out_channels: int,
atrous_rates: Iterable[int],
Expand All @@ -79,7 +79,11 @@ def __init__(
aspp_dropout: float,
):
super().__init__()
if output_stride not in {8, 16}:
if encoder_depth not in (3, 4, 5):
raise ValueError(
"Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth)
)
if output_stride not in (8, 16):
raise ValueError(
"Output stride should be 8 or 16, got {}.".format(output_stride)
)
Expand Down

0 comments on commit b39b8f3

Please sign in to comment.