[MoE] notebook for numerical verification #1026

Merged
merged 1 commit into main from lizhiyu/fix_moe on Dec 27, 2024

Conversation

@ZhiyuLi-goog (Collaborator) commented on Nov 12, 2024:

Description

In addition to the typo, it looks like we are missing a normalization step in MaxText: re-scaling top_k_weights so that they sum to 1. I was able to match layer outputs after adding the normalization. This normalization won't affect training or inference, since it is a constant factor per token and doesn't change the relative softmax probabilities, but we can still add it for better alignment.
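For illustration, here is a minimal sketch of the gating step with this normalization applied. This is not the actual MaxText code; the function name, argument names, and shapes are assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical gating helper, assuming gate_logits has shape [num_tokens, num_experts].
def top_k_gating(gate_logits, k=2):
    softmax_probs = jax.nn.softmax(gate_logits, axis=-1)
    top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, k)
    # Re-scale so the top-k weights sum to 1 per token. The divisor is a
    # per-token constant, so the relative weighting of the experts is unchanged.
    top_k_weights = top_k_weights / top_k_weights.sum(-1, keepdims=True)
    return top_k_weights, top_k_indices
```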

Update:
The typo fix and the top_k_weights normalization were covered in PRs #1100 and #1064.
This PR only adds the numerical verification notebook.

Tests

Tested locally.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@gobbleturk (Collaborator) commented:

Thanks for the fix, Zhiyu. I'll leave it for @RissyRan to review.

@RissyRan (Collaborator) commented:

Thanks, Zhiyu, for the fix! I am fine with adding this normalization if it makes it convenient to compare weights. Could you add top_k_weights /= top_k_weights.sum(-1, keepdims=True) before the line weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)? top_k_weights is not used anywhere before that.
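As a quick standalone check of what that one line does (illustrative numbers only, not MaxText code):

```python
import jax.numpy as jnp

top_k_weights = jnp.array([[0.5, 0.3], [0.4, 0.1]])    # per-token top-k gate probabilities
top_k_weights /= top_k_weights.sum(-1, keepdims=True)  # the suggested normalization
print(top_k_weights.sum(-1))                           # -> [1. 1.]
```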

@RissyRan (Collaborator) commented:

I think we also need to add this normalization somewhere [here], in the dropping path:

softmax_probs *= combined_expert_mask
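A rough sketch of what that could look like (illustrative only; the names and shapes follow the quoted line, not the actual MaxText implementation):

```python
import jax.numpy as jnp

def mask_and_renormalize(softmax_probs, combined_expert_mask, eps=1e-9):
    # Zero out assignments that were dropped for exceeding expert capacity.
    softmax_probs = softmax_probs * combined_expert_mask
    # Re-normalize the remaining weights per token, guarding against all-dropped tokens.
    denom = softmax_probs.sum(-1, keepdims=True)
    return softmax_probs / jnp.maximum(denom, eps)
```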

Have you checked the benchmark scores with and without this normalization, using the correct checkpoint, on 8x22b or 8x7b?

@ZhiyuLi-goog (Collaborator, Author) commented:

> Thanks, Zhiyu, for the fix! I am fine with adding this normalization if it makes it convenient to compare weights. Could you add top_k_weights /= top_k_weights.sum(-1, keepdims=True) before the line weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)? top_k_weights is not used anywhere before that.

Done.

@ZhiyuLi-goog (Collaborator, Author) commented on Nov 15, 2024:

> I think we also need to add this normalization somewhere [here], in the dropping path:
>
> softmax_probs *= combined_expert_mask
>
> Have you checked the benchmark scores with and without this normalization, using the correct checkpoint, on 8x22b or 8x7b?

I have checked the benchmark scores; the results are the same with and without this normalization.

Additionally, I added a numerical verification notebook.
It currently matches logits to within a tolerance of ~0.1 for both megablox=True and megablox=False.
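Roughly, this is the kind of check the notebook performs (an illustrative sketch; the notebook itself loads the real MaxText and reference-model outputs):

```python
import numpy as np

def compare_logits(maxtext_logits, reference_logits, atol=0.1):
    # Report the largest elementwise difference and assert it is within tolerance.
    max_diff = np.max(np.abs(maxtext_logits - reference_logits))
    print(f"max abs diff: {max_diff:.4f}")
    assert np.allclose(maxtext_logits, reference_logits, atol=atol), "logits mismatch"
```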

@RissyRan (Collaborator) left a comment:

Thank you!

@ZhiyuLi-goog (Collaborator, Author) commented:

@gobbleturk could you review it at your convenience?

@ZhiyuLi-goog (Collaborator, Author) commented:

> @gobbleturk could you review it at your convenience?

Hi @gobbleturk, I need a code owner's review as the final step of this PR, thank you!

@ZhiyuLi-goog changed the title from "[MoE] fix typo and add normalization for top_k_weights" to "[MoE] notebook for numerical verification" on Dec 27, 2024
@copybara-service bot merged commit c106fe1 into main on Dec 27, 2024
15 of 17 checks passed
@copybara-service bot deleted the lizhiyu/fix_moe branch on December 27, 2024 at 23:21