GPU memory leak in class ResBlock (lines 44 and 46 in IFRNet_S.py) in PyTorch 1.13.0 #32

Open · VicRanger opened this issue on May 7, 2023 · 1 comment

@VicRanger

I found a GPU memory leak when the program runs the following lines (it might be a bug in newer versions of PyTorch):

https://github.com/ltkong218/IFRNet/blob/main/models/IFRNet_S.py#L44
https://github.com/ltkong218/IFRNet/blob/main/models/IFRNet_S.py#L46

Official code (which leads to the GPU memory leak):

  out = self.conv1(x)
  # write conv2's output back into the last side_channels channels of out, in place
  out[:, -self.side_channels:, :, :] = self.conv2(out[:, -self.side_channels:, :, :])
  out = self.conv3(out)
  # same in-place slice write for conv4
  out[:, -self.side_channels:, :, :] = self.conv4(out[:, -self.side_channels:, :, :])
  out = self.prelu(x + self.conv5(out))
  return out
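The official lines overwrite part of out in place. The small sketch below illustrates what that pattern does: under torch.no_grad the slice write keeps the same storage, and with autograd enabled it is recorded as a CopySlices node. The channel counts and the stand-in conv layer are hypothetical, not IFRNet's actual configuration, and the sketch does not by itself show the cause of the leak.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
side_channels = 4  # hypothetical value; IFRNet's actual setting may differ
conv = nn.Conv2d(side_channels, side_channels, kernel_size=3, padding=1)  # stand-in layer

# 1) Inference (no autograd graph): the slice write modifies out's storage in place.
with torch.no_grad():
    out = torch.randn(1, 8, 16, 16)
    ptr_before = out.data_ptr()
    out[:, -side_channels:, :, :] = conv(out[:, -side_channels:, :, :])
    same_storage = out.data_ptr() == ptr_before
print(same_storage)  # True

# 2) With autograd enabled, the same write is recorded as a CopySlices node.
x = torch.randn(1, 8, 16, 16, requires_grad=True)
out_grad = x * 2  # non-leaf tensor tracked by autograd
out_grad[:, -side_channels:, :, :] = conv(out_grad[:, -side_channels:, :, :])
grad_fn_name = type(out_grad.grad_fn).__name__
print(grad_fn_name)  # CopySlices
```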

Then I changed the code to the following, and the GPU memory leak disappeared. It uses torch.cat to build a new feature tensor after each side convolution instead of writing the result back into a slice of out:

  out = self.conv1(x)
  keep_ft = out[:, :-self.side_channels, :, :]  # channels passed through unchanged
  side_ft = out[:, -self.side_channels:, :, :]  # last side_channels channels
  side_ft = self.conv2(side_ft)
  out = torch.cat([keep_ft, side_ft], dim=1)    # new tensor instead of an in-place write
  out = self.conv3(out)
  keep_ft = out[:, :-self.side_channels, :, :]
  side_ft = out[:, -self.side_channels:, :, :]
  side_ft = self.conv4(side_ft)
  out = torch.cat([keep_ft, side_ft], dim=1)
  out = self.prelu(x + self.conv5(out))
  return out
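To check that the torch.cat workaround computes the same result as the in-place version, here is a self-contained sketch. ResBlockSketch is a hypothetical approximation of IFRNet's ResBlock: the Conv2d+PReLU layout and the channel counts are assumptions and may differ from the repository. Both forward variants share the same weights and are compared under torch.no_grad.

```python
import torch
import torch.nn as nn


class ResBlockSketch(nn.Module):
    """Approximation of IFRNet's ResBlock (layer layout assumed, may differ from the repo)."""

    def __init__(self, in_channels: int, side_channels: int, bias: bool = True):
        super().__init__()
        self.side_channels = side_channels

        def conv_prelu(c: int) -> nn.Sequential:
            # 3x3 conv, stride 1, padding 1, followed by a per-channel PReLU
            return nn.Sequential(nn.Conv2d(c, c, 3, 1, 1, bias=bias), nn.PReLU(c))

        self.conv1 = conv_prelu(in_channels)
        self.conv2 = conv_prelu(side_channels)
        self.conv3 = conv_prelu(in_channels)
        self.conv4 = conv_prelu(side_channels)
        self.conv5 = nn.Conv2d(in_channels, in_channels, 3, 1, 1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward_inplace(self, x):
        # Official pattern: overwrite the side channels of out in place.
        s = self.side_channels
        out = self.conv1(x)
        out[:, -s:] = self.conv2(out[:, -s:])
        out = self.conv3(out)
        out[:, -s:] = self.conv4(out[:, -s:])
        return self.prelu(x + self.conv5(out))

    def forward(self, x):
        # Workaround from this issue: build a new tensor with torch.cat.
        s = self.side_channels
        out = self.conv1(x)
        out = torch.cat([out[:, :-s], self.conv2(out[:, -s:])], dim=1)
        out = self.conv3(out)
        out = torch.cat([out[:, :-s], self.conv4(out[:, -s:])], dim=1)
        return self.prelu(x + self.conv5(out))


torch.manual_seed(0)
block = ResBlockSketch(in_channels=32, side_channels=8).eval()
x = torch.randn(2, 32, 24, 24)
with torch.no_grad():
    ref = block.forward_inplace(x)
    new = block(x)
print(torch.allclose(ref, new, atol=1e-6))  # True
```

Because the cat version never modifies a tensor after autograd has saved it, it can also be trained with a normal backward pass.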

My specs:
Ubuntu 20.04, Python 3.9, PyTorch 1.13.1+cu117, single V100 GPU.
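To check whether allocated GPU memory grows across iterations on a given setup, one option is to record torch.cuda.memory_allocated after each forward pass. The helper name track_cuda_memory is hypothetical. It covers inference only (torch.no_grad), so checking training-time growth would need a backward pass added, and it returns an empty list when CUDA is unavailable.

```python
from typing import List

import torch
import torch.nn as nn


def track_cuda_memory(model: nn.Module, x: torch.Tensor, steps: int = 20) -> List[int]:
    """Run repeated no-grad forward passes and record allocated CUDA memory in bytes.

    A steadily increasing list suggests memory is retained across iterations;
    a flat list suggests it is not. Returns [] when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return []
    device = torch.device("cuda")
    model = model.to(device).eval()
    x = x.to(device)
    readings: List[int] = []
    with torch.no_grad():
        for _ in range(steps):
            _ = model(x)  # previous output is released when _ is reassigned
            torch.cuda.synchronize()
            readings.append(torch.cuda.memory_allocated(device))
    return readings


# Usage with a stand-in module (replace with the ResBlock under test):
print(track_cuda_memory(nn.Conv2d(3, 3, 3, padding=1), torch.randn(1, 3, 8, 8), steps=5))
```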

@ltkong218
Owner

Thanks for the feedback. This may be a bug in newer versions of PyTorch.
