
Fix orbax to hf converter for Llama3.1-8B #1123

Open · wants to merge 1 commit into main
Conversation

@khatwanimohit (Collaborator) commented Dec 23, 2024

Description

Fixes unpermute_from_match_maxtext_rope for the Orbax checkpoint to HF checkpoint conversion.
Tested for Llama3.1-8B using end_to_end/tpu/test_orbax_to_hf.sh.
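
For context, a minimal numpy sketch of the kind of per-head permutation involved when moving Llama q/k projection weights between an interleaved rotary-embedding layout and HF's split-half ("rotate_half") layout. The function names, shapes, and conventions below are illustrative assumptions, not the actual MaxText implementation:

# Illustrative only: a generic permute/unpermute pair for RoPE weight layouts.
import numpy as np

def permute_to_hf(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Reorder rows so interleaved (even, odd) pairs become two contiguous halves per head."""
    num_rows, hidden = w.shape
    head_dim = num_rows // n_heads
    return (w.reshape(n_heads, head_dim // 2, 2, hidden)
             .transpose(0, 2, 1, 3)
             .reshape(num_rows, hidden))

def unpermute_from_hf(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Inverse of permute_to_hf: split halves back into interleaved pairs."""
    num_rows, hidden = w.shape
    head_dim = num_rows // n_heads
    return (w.reshape(n_heads, 2, head_dim // 2, hidden)
             .transpose(0, 2, 1, 3)
             .reshape(num_rows, hidden))

# Round-trip sanity check on random data (Llama3.1-8B-like shape).
w = np.random.randn(4096, 4096).astype(np.float32)
assert np.array_equal(unpermute_from_hf(permute_to_hf(w, n_heads=32), n_heads=32), w)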

This PR also introduces a script (weight_inspector.py) for comparing two sets of weights that share the same key structure. The weights should be dumped as pickle files beforehand.
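
A minimal sketch of the kind of comparison such an inspector performs, assuming two pickle files that each contain a flat dict of arrays with identical keys; the file names and tolerance below are placeholders:

import pickle
import numpy as np

def load_weights(path):
    with open(path, "rb") as f:
        return pickle.load(f)

def compare(a, b, atol=1e-5):
    """Report keys whose tensors differ in shape or value."""
    assert a.keys() == b.keys(), "key structures differ"
    for key in sorted(a):
        x, y = np.asarray(a[key]), np.asarray(b[key])
        if x.shape != y.shape:
            print(f"{key}: shape mismatch {x.shape} vs {y.shape}")
        elif not np.allclose(x, y, atol=atol):
            print(f"{key}: max abs diff {np.max(np.abs(x - y)):.3e}")

if __name__ == "__main__":
    compare(load_weights("weights_a.pkl"), load_weights("weights_b.pkl"))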

FIXES: b/384165076

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@A9isha (Collaborator) left a comment

This is great - Thank you Mohit!

Also, could you please add an orbax->hf test (e.g. orbax_to_hf_logit_checker.py) to the nightly tests?
For example, it could go directly after this. But ideally we want to put it in another file and share the CONVERTED_CHECKPOINT, just like we share CONVERTED_CHECKPOINT for 2_test_llama3.1_8b.sh, and run this third script at a different cadence than the rest.

For llama2-7b it would be llama2-7B.
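
A rough sketch of what such a separate nightly script could look like, reusing the conversion command shown later in this PR; the script name, dates, and variables below are placeholders, not an existing MaxText file:

# hypothetical end_to_end/tpu/llama3.1/8b/3_test_orbax_to_hf.sh
# Reuses the checkpoint shared with 2_test_llama3.1_8b.sh.
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/<date>/llama3.1-8b/scanned_chkpt/0/items
export HF_MODEL_PATH=/tmp/hf_llama3.1

# Convert the Orbax checkpoint to an HF checkpoint on CPU.
JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} \
  run_name=convert_to_hf model_name=llama3.1-8b hf_model_path=${HF_MODEL_PATH}

# Then run a logit check of the converted HF checkpoint against golden data,
# e.g. via MaxText/tests/forward_pass_logit_checker.py as quoted below.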


Usage: python3 golden_llama3-70b_export.py --model-id meta-llama/Meta-Llama-3-70B --output-path llama3-70b/golden_logits/golden_data_llama3-70b.jsonl

nit: Please update the usage to Llama 3.1


JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs \
load_parameters_path=gs://runner-maxtext-logs/2024-12-18-17-35/llama3.1-8b/scanned_chkpt/0/items run_name=convert_to_hf \
model_name=llama3.1-8b hf_model_path=/home/mohitkhatwani/maxtext/hf_llama3.1_new/

nit: Please update the hf_model_path name and the dates in the paths


# comparing logits of the HF ckpt above

python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} \

It is best to load the converted model in HF and ensure the logits match.

Please see an example here
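
For illustration, a minimal sketch of that check, assuming the converted checkpoint path and prompt below; the golden logits would come from the reference meta-llama/Llama-3.1-8B model (for example via a golden-data export like the one quoted above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_model_path = "/path/to/hf_llama3.1"  # output of llama_mistral_mixtral_orbax_to_hf.py
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.float32)

inputs = tokenizer("I love to", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, seq_len, vocab_size)

# Compare against golden logits computed from the reference HF checkpoint:
# torch.testing.assert_close(logits, golden_logits, atol=0.1, rtol=0.1)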
