diff --git a/geo_deep_learning/models/segformer.py b/geo_deep_learning/models/segformer.py index a59dca37..f5872306 100644 --- a/geo_deep_learning/models/segformer.py +++ b/geo_deep_learning/models/segformer.py @@ -71,9 +71,10 @@ def forward(self, x): class SegFormer(nn.Module): - def __init__(self, encoder, in_channels, num_classes) -> None: + def __init__(self, encoder, in_channels, weights, num_classes) -> None: super().__init__() - self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels, depth=5, drop_path_rate=0.1) + self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels, + depth=5, weights=weights, drop_path_rate=0.1) self.decoder = Decoder(encoder=encoder, num_classes=num_classes) def forward(self, img):