Skip to content

Commit

Permalink
Update SegFormer initialization to include weights parameter for encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Nov 27, 2024
1 parent bb4fff6 commit 278e76c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions geo_deep_learning/models/segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ def forward(self, x):


class SegFormer(nn.Module):
def __init__(self, encoder, in_channels, num_classes) -> None:
def __init__(self, encoder, in_channels, weights, num_classes) -> None:
super().__init__()
self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels, depth=5, drop_path_rate=0.1)
self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels,
depth=5, weights=weights, drop_path_rate=0.1)
self.decoder = Decoder(encoder=encoder, num_classes=num_classes)

def forward(self, img):
Expand Down

0 comments on commit 278e76c

Please sign in to comment.