Skip to content

Commit

Permalink
Expose timm constructor arguments (#960)
Browse files Browse the repository at this point in the history
* Expose timm constructor arguments

* Remove leak from other branch

* Rename dupls to duplicates
  • Loading branch information
DimitrisMantas authored Nov 6, 2024
1 parent b90b3c5 commit cd5d3c2
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 36 deletions.
15 changes: 12 additions & 3 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from typing import Any, Optional

from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder


Expand Down Expand Up @@ -36,6 +37,8 @@ class DeepLabV3(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **DeepLabV3**
Expand All @@ -55,6 +58,7 @@ def __init__(
activation: Optional[str] = None,
upsampling: int = 8,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -64,6 +68,7 @@ def __init__(
depth=encoder_depth,
weights=encoder_weights,
output_stride=8,
**kwargs,
)

self.decoder = DeepLabV3Decoder(
Expand Down Expand Up @@ -116,6 +121,8 @@ class DeepLabV3Plus(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **DeepLabV3Plus**
Expand All @@ -137,6 +144,7 @@ def __init__(
activation: Optional[str] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -153,6 +161,7 @@ def __init__(
depth=encoder_depth,
weights=encoder_weights,
output_stride=encoder_output_stride,
**kwargs,
)

self.decoder = DeepLabV3PlusDecoder(
Expand Down
10 changes: 7 additions & 3 deletions segmentation_models_pytorch/decoders/fpn/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from typing import Any, Optional

from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import FPNDecoder


Expand Down Expand Up @@ -40,6 +41,7 @@ class FPN(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **FPN**
Expand All @@ -63,6 +65,7 @@ def __init__(
activation: Optional[str] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -77,6 +80,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = FPNDecoder(
Expand Down
8 changes: 6 additions & 2 deletions segmentation_models_pytorch/decoders/linknet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.base import (
ClassificationHead,
SegmentationHead,
SegmentationModel,
ClassificationHead,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import LinknetDecoder


Expand Down Expand Up @@ -43,6 +44,7 @@ class Linknet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **Linknet**
Expand All @@ -61,6 +63,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -74,6 +77,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = LinknetDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/manet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union, List
from typing import Any, List, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import MAnetDecoder


Expand Down Expand Up @@ -45,6 +46,7 @@ class MAnet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **MAnet**
Expand All @@ -66,6 +68,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -74,6 +77,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = MAnetDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/pan/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import PANDecoder


Expand Down Expand Up @@ -38,6 +39,7 @@ class PAN(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **PAN**
Expand All @@ -58,6 +60,7 @@ def __init__(
activation: Optional[Union[str, callable]] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -74,6 +77,7 @@ def __init__(
depth=5,
weights=encoder_weights,
output_stride=encoder_output_stride,
**kwargs,
)

self.decoder = PANDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/pspnet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import PSPDecoder


Expand Down Expand Up @@ -44,6 +45,7 @@ class PSPNet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **PSPNet**
Expand All @@ -65,6 +67,7 @@ def __init__(
activation: Optional[Union[str, callable]] = None,
upsampling: int = 8,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -73,6 +76,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = PSPDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/unet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union, List
from typing import Any, List, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import UnetDecoder


Expand Down Expand Up @@ -44,6 +45,7 @@ class Unet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: Unet
Expand All @@ -65,6 +67,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -73,6 +76,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = UnetDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/unetplusplus/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union, List
from typing import Any, List, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import UnetPlusPlusDecoder


Expand Down Expand Up @@ -44,6 +45,7 @@ class UnetPlusPlus(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **Unet++**
Expand All @@ -65,6 +67,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -78,6 +81,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = UnetPlusPlusDecoder(
Expand Down
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/decoders/upernet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import UPerNetDecoder


Expand Down Expand Up @@ -36,6 +37,7 @@ class UPerNet(SegmentationModel):
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **UPerNet**
Expand All @@ -56,6 +58,7 @@ def __init__(
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
super().__init__()

Expand All @@ -64,6 +67,7 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
**kwargs,
)

self.decoder = UPerNetDecoder(
Expand Down
Loading

0 comments on commit cd5d3c2

Please sign in to comment.