Almost never a desirable behavior to run LayerNormGeneral in FP16 #17

Open
WangTaoAs opened this issue Nov 13, 2024 · 1 comment

WangTaoAs commented Nov 13, 2024

Hi, thanks for your great work.
I used your network to train on a custom dataset and it works well. However, when I run inference in FP16, performance drops a lot. I found that the self-implemented LayerNormGeneral module causes large errors between FP16 and FP32. When I replace it with the LayerNorm implemented by apex, the FP16 and FP32 outputs match. Is there a way to fix this?

```python
import torch
import torch.nn as nn


class LayerNormGeneral(nn.Module):
    def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        # Center, then divide by the (biased) standard deviation over normalized_dim.
        c = x - x.mean(self.normalized_dim, keepdim=True)
        s = c.pow(2).mean(self.normalized_dim, keepdim=True)
        x = c / torch.sqrt(s + self.eps)
        # if self.use_scale:
        x = x * self.weight
        # if self.use_bias:
        #     x = x + self.bias
        return x
```

Collaborator

yuweihao commented Dec 1, 2024

Hi @WangTaoAs, sorry, I have no experience with FP16 inference. You could try training your model with this norm implementation instead: https://github.com/sail-sg/metaformer/blob/main/metaformer_baselines.py#L367
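
One possible workaround, sketched here and not taken from the repository, is to keep the module's interface but do the normalization arithmetic in FP32 and cast the result back to the input dtype. The likely failure mode in pure FP16 is that `c.pow(2)` overflows once `|c|` exceeds roughly 255 (the FP16 maximum is 65504) and that the mean reductions lose precision. The class name `LayerNormGeneralFP32Stats` is hypothetical, and unlike the snippet posted above it applies the scale and bias conditionally.

```python
import torch
import torch.nn as nn


class LayerNormGeneralFP32Stats(nn.Module):
    """Hypothetical variant of LayerNormGeneral that normalizes in float32."""

    def __init__(self, affine_shape=None, normalized_dim=(-1,), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        orig_dtype = x.dtype
        # Upcast so the mean, square, and sqrt run in float32.
        x = x.float()
        c = x - x.mean(self.normalized_dim, keepdim=True)
        s = c.pow(2).mean(self.normalized_dim, keepdim=True)
        x = c / torch.sqrt(s + self.eps)
        if self.use_scale:
            x = x * self.weight.float()
        if self.use_bias:
            x = x + self.bias.float()
        # Cast back so downstream layers still see the caller's dtype.
        return x.to(orig_dtype)
```

The normalized output stays within a few units of zero, so casting it back to FP16 should be safe; only the intermediate statistics need the wider range. Whether this recovers the FP32 accuracy on a given model is something to verify against an FP32 reference run.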
