Fix DeepLabV3Plus encoder depth #986

Merged: 3 commits merged into qubvel-org:main on Nov 29, 2024

Conversation

munehiro-k
Contributor

Hi, @qubvel. Thanks for your great work!

This PR fixes #377. It was previously submitted as PR #561, which went stale and was closed.

I ran into issue #377 again and found that the fix from that PR still works.
(My typical use case involves processing small images in real time, where a small encoder depth is preferred.)
So I did some maintenance and rebased it onto edadc0d.

Update

  1. Fix the feature index mismatch that occurs when encoder_depth is 3 or 4.
    • Please refer to the attached file for the combination of tensor shapes in each case: tensor_shapes.md.
  2. Modify the docstring for the upsampling argument to state the condition for preserving the input-output shape.
    • In the case (encoder_depth, encoder_output_stride) = (3, 16), upsampling should be set to 2.
  3. Fix a type hint and add a value check.
    • Type hint error: Sequence[int, ...] should be Sequence[int].
    • Value check: add validation to make sure encoder_depth is 3, 4, or 5.
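
For readers who have not opened the diff, the value check in item 3 could look roughly like the sketch below. This is only an illustration, not the code merged in this PR: the names `SUPPORTED_ENCODER_DEPTHS` and `check_encoder_depth` are hypothetical, and the error message is an assumption.

```python
# Hypothetical sketch of the encoder_depth validation described in item 3.
# These names do not come from segmentation_models_pytorch.
SUPPORTED_ENCODER_DEPTHS = (3, 4, 5)


def check_encoder_depth(encoder_depth: int) -> int:
    """Return encoder_depth unchanged if it is 3, 4, or 5; otherwise raise."""
    if encoder_depth not in SUPPORTED_ENCODER_DEPTHS:
        raise ValueError(
            f"encoder_depth should be one of {SUPPORTED_ENCODER_DEPTHS}, "
            f"got {encoder_depth}"
        )
    return encoder_depth
```

Failing early with a ValueError gives a clearer message than the feature index mismatch that otherwise shows up later in the decoder.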

Test Code

from itertools import product

import torch
import segmentation_models_pytorch as smp

input_shape = (10, 3, 192, 128)
input_tensor = torch.zeros(input_shape)
for up, depth, stride in product((2, 4), (3, 4, 5), (8, 16)):
    net = smp.DeepLabV3Plus(
        encoder_name="timm-mobilenetv3_small_minimal_100",
        encoder_weights="imagenet",
        encoder_depth=depth,
        encoder_output_stride=stride,
        upsampling=up,
        classes=1,
        activation=None
    )
    print(f"encoder_depth={depth}, encoder_output_stride={stride:2}, upsampling={up}")
    output_shape = tuple(net(input_tensor).shape)
    preserved = all(input_shape[i] == output_shape[i] for i in (0, 2, 3))
    print(f"  output shape: {output_shape}, preserve shape: {preserved}")

Output

encoder_depth=3, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=3, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=4, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=4, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=5, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=5, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=3, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=3, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 384, 256), preserve shape: False
encoder_depth=4, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=4, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=5, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=5, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
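
The table above can be summarized by a small rule, sketched here in plain Python. It is inferred only from the twelve rows printed above for this encoder and input size, so treat it as an assumption rather than a statement about the library in general. The helper names `shape_preserving_upsampling` and `predicted_output_hw` are hypothetical.

```python
from itertools import product


def shape_preserving_upsampling(encoder_depth: int, encoder_output_stride: int) -> int:
    """Upsampling factor that preserved H and W in the run reported above."""
    # Only (encoder_depth, encoder_output_stride) = (3, 16) needed upsampling=2;
    # every other tested combination preserved the shape with upsampling=4.
    if (encoder_depth, encoder_output_stride) == (3, 16):
        return 2
    return 4


def predicted_output_hw(input_hw, encoder_depth, encoder_output_stride, upsampling):
    """Predict output (H, W), assuming the pattern in the table holds."""
    factor = shape_preserving_upsampling(encoder_depth, encoder_output_stride)
    h, w = input_hw
    return (h * upsampling // factor, w * upsampling // factor)


# Same iteration order as the test code, without building any model.
for up, depth, stride in product((2, 4), (3, 4, 5), (8, 16)):
    hw = predicted_output_hw((192, 128), depth, stride, up)
    print(f"encoder_depth={depth}, encoder_output_stride={stride:2}, upsampling={up} -> {hw}")
```

This only restates the observed shapes; it does not run the network, so the test code above remains the real check.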

Contributor

brianhou0208 commented Nov 29, 2024

Hi @munehiro-k ,

I also ran into this problem. I think I can contribute another PR to solve the problems with DeepLabV3 and DeepLabV3+ at different encoder depths and output strides.

@qubvel qubvel self-requested a review November 29, 2024 18:01
Collaborator

@qubvel qubvel left a comment


Thanks for the update, it looks great to me! Special thanks for providing the testing code sample. 🤗

@qubvel qubvel merged commit cc482aa into qubvel-org:main Nov 29, 2024
12 checks passed
Development

Successfully merging this pull request may close these issues.

encoder_depth error
3 participants