diff --git a/geo_deep_learning/models/dofa/dofa_seg.py b/geo_deep_learning/models/dofa/dofa_seg.py
index 8190b07f..2723d2dd 100644
--- a/geo_deep_learning/models/dofa/dofa_seg.py
+++ b/geo_deep_learning/models/dofa/dofa_seg.py
@@ -12,6 +12,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from neckhead import MultiLevelNeck
 from pathlib import Path
 from timm.models.vision_transformer import Block
 from torch import Tensor
@@ -492,19 +493,18 @@ def __init__(self,
     def forward(self, x):
         c1, c2, c3, c4 = x
         n, _, h, w = c4.shape
-        reshape_size = (128, 128)
 
         _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous()
-        _c4 = F.interpolate(input=_c4, size=reshape_size, mode='bilinear', align_corners=False)
+        _c4 = F.interpolate(input=_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
 
         _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous()
-        _c3 = F.interpolate(input=_c3, size=reshape_size, mode='bilinear', align_corners=False)
+        _c3 = F.interpolate(input=_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
 
         _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous()
-        _c2 = F.interpolate(input=_c2, size=reshape_size, mode='bilinear', align_corners=False)
+        _c2 = F.interpolate(input=_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
 
         _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous()
-        _c1 = F.interpolate(input=_c1, size=reshape_size, mode='bilinear', align_corners=False)
+        _c1 = F.interpolate(input=_c1, size=c1.size()[2:], mode='bilinear', align_corners=False)
 
         _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
         x = self.dropout(_c)
@@ -538,12 +538,17 @@ def __init__(self,
         else:
             raise ValueError(f"Unknown encoder: {encoder}")
+        self.neck = MultiLevelNeck(in_channels=self.in_channels,
+                                   out_channels=self.embedding_dim,
+                                   scales=[4, 2, 1, 0.5])
+
         self.decoder = Decoder(in_channels=self.in_channels,
                                embedding_dim=self.embedding_dim,
                                num_classes=num_classes)
 
     def forward(self, x):
         image_size = x.shape[2:]
         x = self.encoder(x)
+        x = self.neck(x)
         x = self.decoder(x)
         x = F.interpolate(input=x, size=image_size, scale_factor=None, mode='bilinear', align_corners=False)
         return x
@@ -551,7 +556,7 @@
 
 if __name__ == '__main__':
     batch_size = 8
-    img = torch.rand(batch_size, 4, 512, 512)
+    img = torch.rand(batch_size, 3, 512, 512)
     model = DOFASeg(encoder='dofa_base',
                     pretrained=True,
                     image_size=(512, 512),