Skip to content

Commit

Permalink
Refactor DOFASeg model to integrate MultiLevelNeck and adjust interpolation sizes
Browse files Browse the repository at this point in the history

This commit updates the DOFASeg model by inserting a MultiLevelNeck between the encoder and the decoder for improved feature extraction across multiple scales (scale factors 4, 2, 1 and 0.5). The Decoder class no longer interpolates every feature map to a hard-coded 128×128; each map is now resized to the spatial size of the first input feature map (`c1`), so the decoder adapts to the input resolution. Additionally, the input tensor shape in the main execution block has been corrected from 4 to 3 channels.
  • Loading branch information
valhassan committed Nov 25, 2024
1 parent 29cdd1b commit 814cd10
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions geo_deep_learning/models/dofa/dofa_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from neckhead import MultiLevelNeck
from pathlib import Path
from timm.models.vision_transformer import Block
from torch import Tensor
Expand Down Expand Up @@ -492,19 +493,18 @@ def __init__(self,
def forward(self, x):
c1, c2, c3, c4 = x
n, _, h, w = c4.shape
reshape_size = (128, 128)

_c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous()
_c4 = F.interpolate(input=_c4, size=reshape_size, mode='bilinear', align_corners=False)
_c4 = F.interpolate(input=_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

_c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous()
_c3 = F.interpolate(input=_c3, size=reshape_size, mode='bilinear', align_corners=False)
_c3 = F.interpolate(input=_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

_c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous()
_c2 = F.interpolate(input=_c2, size=reshape_size, mode='bilinear', align_corners=False)
_c2 = F.interpolate(input=_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

_c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous()
_c1 = F.interpolate(input=_c1, size=reshape_size, mode='bilinear', align_corners=False)
_c1 = F.interpolate(input=_c1, size=c1.size()[2:], mode='bilinear', align_corners=False)

_c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
x = self.dropout(_c)
Expand Down Expand Up @@ -538,20 +538,25 @@ def __init__(self,
else:
raise ValueError(f"Unknown encoder: {encoder}")

self.neck = MultiLevelNeck(in_channels=self.in_channels,
out_channels=self.embedding_dim,
scales=[4, 2, 1, 0.5])

self.decoder = Decoder(in_channels=self.in_channels,
embedding_dim=self.embedding_dim,
num_classes=num_classes)
# Forward pass of the segmentation model: encoder -> neck -> decoder, then
# resize the decoder output back to the spatial size of the input.
# NOTE(review): indentation was lost in this rendered diff; the code lines
# below are kept verbatim.
def forward(self, x):
# Remember the input's trailing dims (index 2 onward) as the final resize
# target; presumably (H, W) of an (N, C, H, W) batch -- confirm with callers.
image_size = x.shape[2:]
x = self.encoder(x)
# The neck sits between the encoder and the decoder; the decoder unpacks
# four feature maps (c1..c4) from what it receives.
x = self.neck(x)
x = self.decoder(x)
# Bilinear upsampling to the saved input size; scale_factor=None because
# the target is given explicitly through `size`.
x = F.interpolate(input=x, size=image_size, scale_factor=None, mode='bilinear', align_corners=False)
return x


if __name__ == '__main__':
batch_size = 8
img = torch.rand(batch_size, 4, 512, 512)
img = torch.rand(batch_size, 3, 512, 512)
model = DOFASeg(encoder='dofa_base',
pretrained=True,
image_size=(512, 512),
Expand Down

0 comments on commit 814cd10

Please sign in to comment.