
Add bfp16 support to inhibit transforming to fp32 when using llama3 #31

Merged
merged 1 commit into main from tl_dtype on Jul 15, 2024

Conversation

StarConnor
Collaborator

Llama3 is trained in bfloat16 (bf16), so we should not cast bf16 to fp32 during inference. This PR changes the two precision transformations in transformer_lens: basically, I add torch.bfloat16 to the precision check lists so the upcast to fp32 is skipped (a sketch of the change follows the list below).

  1. Attention
    In abstract_attention.py.
  2. Layer norm
    In rms_norm.py and rms_norm_pre.py (because Llama3 uses RMSNorm).
    Note: I did not change the precision check lists in layernorm_pre.py and layernorm.py because Llama3 doesn't use them. Should I also modify these two files?
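To make the intent concrete, here is a minimal sketch of the kind of guard being changed, assuming the precision check in transformer_lens follows its usual `if self.cfg.dtype not in [...]` pattern. The names below (`KEEP_NATIVE_DTYPES`, `maybe_upcast`) are illustrative, not the actual identifiers in the library:

```python
import torch

# Illustrative sketch only: KEEP_NATIVE_DTYPES and maybe_upcast are hypothetical
# names, not identifiers from transformer_lens. The point is the shape of the
# dtype guard: torch.bfloat16 is added to the check list so bf16 tensors are no
# longer upcast to fp32 inside attention / RMSNorm.
KEEP_NATIVE_DTYPES = [torch.float32, torch.float64, torch.bfloat16]  # bfloat16 newly added

def maybe_upcast(x: torch.Tensor, cfg_dtype: torch.dtype) -> torch.Tensor:
    """Upcast to fp32 only for dtypes outside the check list (e.g. fp16),
    leaving bf16 activations in bf16."""
    if cfg_dtype not in KEEP_NATIVE_DTYPES:
        return x.to(torch.float32)
    return x

# Example: a bf16 tensor now stays in bf16 instead of being cast to fp32.
x = torch.randn(2, 4, dtype=torch.bfloat16)
print(maybe_upcast(x, torch.bfloat16).dtype)  # torch.bfloat16
```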

@StarConnor requested a review from dest1n1s July 14, 2024 13:01
@dest1n1s
Collaborator

From my perspective, changing other precision checks could be postponed to avoid other issues.

@dest1n1s merged commit 4a14ae0 into main Jul 15, 2024
1 check passed
@dest1n1s deleted the tl_dtype branch July 15, 2024 04:47