
threshold appears to be NaN during the training process #27

Open · knightyxp opened this issue Feb 28, 2022 · 2 comments
knightyxp commented Feb 28, 2022

Hi Tao Han:
I am a graduate student at SEU, trying to replace the backbone of IIM (VGG16_FPN or HRNet) with my Transformer crowd counting model. However, even after I lower the initial lr from 1e-6 to 1e-7 on SHA, the threshold still becomes NaN around epoch 700. Also, the best MAE is only 126, which is far worse than what my model achieves on SHA when combined with other losses (beyond MSE).
I noticed that in #7 (comment) you mentioned we could also lower the initial threshold. I want to confirm whether that refers to the initial weight of 0.5 in the BinarizedModule. But even after I change the initial weight to 0.4, t_max still starts at 0.54. I am confused about the BinarizedModule. Looking forward to your reply; my email is [email protected] / [email protected]

> class BinarizedModule(nn.Module):
>     def __init__(self, input_channels=720):
>         super(BinarizedModule, self).__init__()
>         self.Threshold_Module = nn.Sequential(
>             nn.Conv2d(input_channels, 256, kernel_size=3, stride=1, padding=1, bias=False),
>             nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0, bias=False),
>             nn.AvgPool2d(15, stride=1, padding=7),
>         )
>         self.sig = compressedSigmoid()
>         # Change the initial threshold weight from 0.5 to 0.4
>         self.weight = nn.Parameter(torch.Tensor(1).fill_(0.4), requires_grad=True)
>         self.bias = nn.Parameter(torch.Tensor(1).fill_(0), requires_grad=True)
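For context, the observed t_max of 0.54 is consistent with the shape of the compressed sigmoid rather than the initial weight. A minimal sketch, assuming compressedSigmoid has the form 1 / (para + exp(-x)) + bias with defaults para=2.0 and bias=0.2 (an assumption about the IIM code, not confirmed in this thread):

```python
import torch
import torch.nn as nn

class compressedSigmoid(nn.Module):
    # Assumed form: 1 / (para + exp(-x)) + bias, which squashes any
    # input into roughly (bias, bias + 1/para) = (0.2, 0.7).
    def __init__(self, para=2.0, bias=0.2):
        super().__init__()
        self.para = para
        self.bias = bias

    def forward(self, x):
        return 1.0 / (self.para + torch.exp(-x)) + self.bias

# When the pre-activation is near zero at the start of training, the
# output is 1 / (2 + 1) + 0.2 ≈ 0.53 regardless of whether the initial
# weight is 0.5 or 0.4, which would explain t_max starting near 0.54.
print(compressedSigmoid()(torch.zeros(1)))  # tensor([0.5333])
```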
taohan10200 (Owner) commented

Thanks for your attention; I am sorry that I only just saw this issue. If lowering the initial lr doesn't work, I suggest you try changing the activation function in the last layer of the BinarizedModule. I changed the compressed sigmoid function to a linear function that limits the output to [0.25, 0.9], and it seems the NaN no longer appears.
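A minimal sketch of such a replacement, assuming a learnable affine map followed by a hard clamp into [0.25, 0.9] (the class name and the scale/shift initialization are illustrative, not taken from the IIM repository):

```python
import torch
import torch.nn as nn

class ClampedLinear(nn.Module):
    """Hypothetical replacement for compressedSigmoid: a learnable
    affine map whose output is clamped to [0.25, 0.9]."""
    def __init__(self, scale=1.0, shift=0.5):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(scale))
        self.shift = nn.Parameter(torch.tensor(shift))

    def forward(self, x):
        # Affine transform, then hard-limit the threshold map so the
        # binarization threshold stays in a numerically safe range.
        return torch.clamp(self.scale * x + self.shift, min=0.25, max=0.9)

# Drop-in usage inside BinarizedModule (assumption):
#     self.sig = ClampedLinear()
```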


knightyxp commented Apr 10, 2022

Hi Han Tao:
I just tried a linear function. I wonder whether the parameters of the linear activation function should be learnable (the weight and the bias). When I tried a non-learnable limiting linear function with the output limited to [0.25, 0.9], it did not work: the threshold value comes out too small even when I magnify it 1000 times.

The linear function is like this:

```python
import math
import torch
import torch.nn as nn
from torch.nn import init

class linear_limitation(nn.Module):
    def __init__(self, para=0.75, bias=0.15):
        super(linear_limitation, self).__init__()
        # para/bias are currently unused; the learnable version is disabled.
        # self.weight = nn.Parameter(torch.Tensor(self.x.shape, self.x.shape))
        # self.bias = nn.Parameter(torch.Tensor(self.x.shape))
        # self.reset_parameters()

    def reset_parameters(self):
        # Unused unless the learnable weight/bias above are enabled.
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        zero = torch.zeros_like(x)
        # output = x.matmul(self.weight) + self.bias
        # Zero out values above 0.09, then zero out values below 0.025,
        # so only values inside [0.025, 0.09] pass through unchanged.
        output = torch.where(x > 0.09, zero, x)
        output = torch.where(output < 0.025, zero, output)
        return output
```
The threshold values before the activation function look like this:

[screenshot of threshold values omitted]
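Note that this forward pass zeroes every value outside [0.025, 0.09] instead of clamping into [0.25, 0.9], so most of the threshold map collapses to exactly zero. A quick check (the random input here is a stand-in, not the actual pre-activation map):

```python
x = torch.rand(1, 1, 32, 32)             # hypothetical pre-activation values
out = linear_limitation()(x)
# Fraction of exact zeros; ≈ 0.94 for uniform [0, 1) input, since only
# values that happen to fall inside [0.025, 0.09] survive.
print((out == 0).float().mean().item())
```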
