[MoE] notebook for numerical verification #1026
Thanks for the fix, Zhiyu. I'll leave this for @RissyRan to review.
Thanks, Zhiyu, for the fix! I am fine with adding this normalization if it makes comparing weights more convenient. Could you add this
I think we also need to add this normalization somewhere around here: maxtext/MaxText/layers/linears.py, line 488 in 061abd8:
softmax_probs *= combined_expert_mask
Have you checked the benchmark scores with and without this normalization, using the correct checkpoint on 8x22b or 8x7b?
613daa2 to f6f6c2d
Done.
I checked the benchmark scores; the results are the same with and without this normalization. Additionally, I added a numerical verification notebook.
Thank you!
f6f6c2d to 28a634a
@gobbleturk, could you review it at your convenience?
Hi @gobbleturk, I need a code owner's review as the final step of this PR. Thank you!
28a634a to 5ea24b9
5ea24b9 to 0af8d97
Description
In addition to the typo, it looks like MaxText is missing this normalization, which re-scales top_k_weights so that they sum to 1. I was able to match layer outputs after adding it. The normalization won't affect training or inference, since it is a constant factor per token and won't change the softmax probability. We can still add it for better alignment.
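For reference, the re-scaling described above can be sketched in plain Python for a single token. This is only an illustration, not the MaxText code: the helper names (`softmax`, `top_k_normalized`) and the router logits are hypothetical, and the real layer works on batched arrays rather than lists.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_normalized(router_logits, k):
    """Pick the k largest router probabilities and re-scale them to sum to 1.

    Returns (expert indices, normalized weights), highest weight first.
    """
    probs = softmax(router_logits)
    # Indices of the k largest probabilities, highest first.
    top_idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    top_w = [probs[i] for i in top_idx]
    scale = sum(top_w)  # one constant per token
    return top_idx, [w / scale for w in top_w]


# Hypothetical router logits for one token over 8 experts.
logits = [0.1, 2.0, -1.0, 0.5, 1.5, -0.3, 0.0, 0.7]
indices, weights = top_k_normalized(logits, k=2)
```

Because every selected weight is divided by the same per-token sum, the ratio between the selected experts' weights is unchanged; only their overall scale is fixed to 1.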
Update:
The typo fix and the normalization for top_k_weights were covered in PRs #1100 and #1064.
This PR now only adds the numerical verification notebook.
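As a rough idea of what a numerical check of layer outputs can look like, here is a minimal, framework-free sketch. It is not the content of the notebook in this PR: the function names, tolerances, and sample values are all hypothetical, and a real comparison would run on the actual layer outputs of the two implementations.

```python
def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two flat lists."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same length")
    return max(abs(r - c) for r, c in zip(reference, candidate))


def outputs_match(reference, candidate, atol=1e-5, rtol=1e-5):
    """allclose-style check: |c - r| <= atol + rtol * |r| for every element."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same length")
    return all(abs(c - r) <= atol + rtol * abs(r)
               for r, c in zip(reference, candidate))


# Hypothetical layer outputs for one token (made-up values).
reference = [0.1234, -0.5678, 0.9012]
candidate = [0.1234001, -0.5678002, 0.9011999]  # differs only by rounding noise
rescaled = [0.7 * x for x in reference]         # off by a constant factor

close_ok = outputs_match(reference, candidate)
rescaled_ok = outputs_match(reference, rescaled)
```

A missing normalization shows up in a check like this as a consistent scale mismatch between the two outputs rather than as small rounding noise.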
Tests
Tested locally.
Checklist
Before submitting this PR, please make sure (put X in square brackets):