Skip to content

Commit

Permalink
Refactor SegFormer class to support dynamic encoder freezing.
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Dec 9, 2024
1 parent a248b53 commit b4c1b43
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion geo_deep_learning/models/segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,23 @@ def forward(self, x):


class SegFormer(nn.Module):
def __init__(self, encoder, in_channels, weights, num_classes) -> None:
def __init__(self,
encoder: str = "mit_b0",
in_channels: int = 3,
weights: str = None,
freeze_encoder: bool = False,
num_classes: int = 1) -> None:
super().__init__()
self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels,
depth=5, weights=weights, drop_path_rate=0.1)
if freeze_encoder:
self._freeze_encoder()
self.encoder.eval()
self.decoder = Decoder(encoder=encoder, num_classes=num_classes)

def _freeze_encoder(self):
for param in self.encoder.parameters():
param.requires_grad = False

def forward(self, img):
# print(f"{__name__}: Input shape: {img.shape}")
Expand Down

0 comments on commit b4c1b43

Please sign in to comment.