diff --git a/repvgg.py b/repvgg.py index bddbf2e..e4edf86 100644 --- a/repvgg.py +++ b/repvgg.py @@ -181,7 +181,7 @@ def forward(self, x): out = self.stage3(out) out = self.stage4(out) out = self.gap(out) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = self.linear(out) return out