diff --git a/mmyolo/models/layers/yolo_bricks.py b/mmyolo/models/layers/yolo_bricks.py index 19175be1a..4f4c3e366 100644 --- a/mmyolo/models/layers/yolo_bricks.py +++ b/mmyolo/models/layers/yolo_bricks.py @@ -1504,11 +1504,27 @@ def __init__( def forward(self, x: Tensor) -> Tensor: """Forward process.""" + + #; +---------------------------------+ + #; | the original design uses `list` | + #; | which is a dynamic operation | + #; | could not traced by `torch.fx` | + #; +---------------------------------+ + # x_main = self.main_conv(x) + # x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1)) + # x_main.extend(blocks(x_main[-1]) for blocks in self.blocks) + # return self.final_conv(torch.cat(x_main, 1)) + + #; the following design keeps the same calculation + #; but in a static way so that `torch.fx` could trace it + #; x = (bs, 64, h, w) + #; x_main = (bs, 64, h, w) x_main = self.main_conv(x) - x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1)) - x_main.extend(blocks(x_main[-1]) for blocks in self.blocks) - return self.final_conv(torch.cat(x_main, 1)) - + x_main_first_32, x_main_last_32 = \ + x_main.split((self.mid_channels, self.mid_channels), dim=1) + for block in self.blocks: + x_main = torch.cat([x_main, block(x_main_last_32)], dim=1) + return self.final_conv(x_main) class BiFusion(nn.Module): """BiFusion Block in YOLOv6.