Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose timm constructor arguments #960

Merged
merged 3 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from typing import Any, Optional

from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder


Expand Down Expand Up @@ -36,6 +37,8 @@ class DeepLabV3(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **DeepLabV3**

Expand All @@ -55,6 +58,7 @@ def __init__(
activation: Optional[str] = None,
upsampling: int = 8,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -64,6 +68,7 @@ def __init__(
depth=encoder_depth,
weights=encoder_weights,
output_stride=8,
**kwargs,
)

self.decoder = DeepLabV3Decoder(
Expand Down Expand Up @@ -116,6 +121,8 @@ class DeepLabV3Plus(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **DeepLabV3Plus**

Expand All @@ -137,6 +144,7 @@ def __init__(
activation: Optional[str] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -153,6 +161,7 @@ def __init__(
depth=encoder_depth,
weights=encoder_weights,
output_stride=encoder_output_stride,
**kwargs,
)

self.decoder = DeepLabV3PlusDecoder(
Expand Down
10 changes: 7 additions & 3 deletions segmentation_models_pytorch/decoders/fpn/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from typing import Any, Optional

from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import FPNDecoder


Expand Down Expand Up @@ -40,6 +41,7 @@ class FPN(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **FPN**
Expand All @@ -63,6 +65,7 @@ def __init__(
activation: Optional[str] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -77,6 +80,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = FPNDecoder(
Expand Down
8 changes: 6 additions & 2 deletions segmentation_models_pytorch/decoders/linknet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.base import (
ClassificationHead,
SegmentationHead,
SegmentationModel,
ClassificationHead,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import LinknetDecoder


Expand Down Expand Up @@ -43,6 +44,7 @@ class Linknet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **Linknet**
Expand All @@ -61,6 +63,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -74,6 +77,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = LinknetDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/manet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union, List
from typing import Any, List, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import MAnetDecoder


Expand Down Expand Up @@ -45,6 +46,7 @@ class MAnet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **MAnet**
Expand All @@ -66,6 +68,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -74,6 +77,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = MAnetDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/pan/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import PANDecoder


Expand Down Expand Up @@ -38,6 +39,7 @@ class PAN(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **PAN**
Expand All @@ -58,6 +60,7 @@ def __init__(
activation: Optional[Union[str, callable]] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -74,6 +77,7 @@ def __init__(
depth=5,
weights=encoder_weights,
output_stride=encoder_output_stride,
**kwargs,
)

self.decoder = PANDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/pspnet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import PSPDecoder


Expand Down Expand Up @@ -44,6 +45,7 @@ class PSPNet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **PSPNet**
Expand All @@ -65,6 +67,7 @@ def __init__(
activation: Optional[Union[str, callable]] = None,
upsampling: int = 8,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -73,6 +76,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = PSPDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/unet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union, List
from typing import Any, List, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import UnetDecoder


Expand Down Expand Up @@ -44,6 +45,7 @@ class Unet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: Unet
Expand All @@ -65,6 +67,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -73,6 +76,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = UnetDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/unetplusplus/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union, List
from typing import Any, List, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import UnetPlusPlusDecoder


Expand Down Expand Up @@ -44,6 +45,7 @@ class UnetPlusPlus(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **Unet++**
Expand All @@ -65,6 +67,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -78,6 +81,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = UnetPlusPlusDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/upernet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import UPerNetDecoder


Expand Down Expand Up @@ -36,6 +37,7 @@ class UPerNet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

Returns:
``torch.nn.Module``: **UPerNet**
Expand All @@ -56,6 +58,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -64,6 +67,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = UPerNetDecoder(
Expand Down
Loading