Fix orbax to hf converter for Llama3.1-8B #1123

Open · wants to merge 1 commit into base: main
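In brief, this PR makes the Orbax-to-Hugging-Face converter work for Llama 3.1 8B. It fixes the RoPE un-permutation helper and adds a llama3.1-8b case to load_hf_model in MaxText/llama_mistral_mixtral_orbax_to_hf.py, replaces the Llama 3 70B golden-logits export script with a Llama 3.1 version, adds a --golden_logits_path flag to MaxText/tests/forward_pass_logit_checker.py, and adds a weight-inspection utility (MaxText/weight_inspector.py) plus an end-to-end test script (end_to_end/tpu/test_orbax_to_hf.sh).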
10 changes: 6 additions & 4 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -45,16 +45,15 @@
 import checkpointing
 from generate_param_only_checkpoint import _read_train_checkpoint
 import llama_or_mistral_ckpt
-from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM
+from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig


 def unpermute_from_match_maxtext_rope(arr):
   """
   Function to get the RoPE values in correct ordering
   """
-  split_size = arr.shape[-1] // 2  # Assuming half for evens, half for odds
-  evens = arr[..., :split_size]
-  odds = arr[..., split_size:]
+  evens = arr[..., ::2]
+  odds = arr[..., 1::2]
   return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
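For context: the previous body split the last axis in half and concatenated the halves back together, which is an identity operation, so the values were never actually reordered. The new slicing gathers the even-indexed positions followed by the odd-indexed ones, i.e. it de-interleaves the last axis (the intent, as far as the diff shows, being to map MaxText's interleaved RoPE layout onto the half-split layout the Hugging Face Llama rotary code expects). A minimal sketch of the difference on a toy array:

import jax.numpy as jnp

arr = jnp.arange(8)  # stand-in for the last axis of a query/key projection

# Old behaviour: split in half and re-concatenate, which leaves the array unchanged: [0 1 2 3 4 5 6 7]
split_size = arr.shape[-1] // 2
old = jnp.concatenate((arr[..., :split_size], arr[..., split_size:]), axis=-1)

# New behaviour: evens then odds, i.e. de-interleaved: [0 2 4 6 1 3 5 7]
new = jnp.concatenate((arr[..., ::2], arr[..., 1::2]), axis=-1)

print(old)
print(new)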


@@ -77,6 +76,9 @@ def load_hf_model(model_size):
     model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
   elif model_size == "mixtral-8x7b":
     model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
+  elif model_size == "llama3.1-8b":
+    config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
+    model = AutoModelForCausalLM.from_config(config)
   else:
     raise NotImplementedError
   return model
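Note on the new llama3.1-8b branch: unlike the other cases it uses AutoModelForCausalLM.from_config(config) rather than from_pretrained, so only the model config is fetched and the weights start out randomly initialized. That appears to be deliberate, since the converter immediately overwrites them with the values read from the Orbax checkpoint, and it avoids downloading the full Llama 3.1 weights only to discard them.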
76 changes: 0 additions & 76 deletions MaxText/scratch_code/golden_llama3-70b_export.py

This file was deleted.

86 changes: 86 additions & 0 deletions MaxText/scratch_code/golden_llama3_1_export.py
@@ -0,0 +1,86 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Usage: python3 golden_llama3-70b_export.py --model-id meta-llama/Meta-Llama-3-70B --output-path llama3-70b/golden_logits/golden_data_llama3-70b.jsonl
"""

Collaborator comment (on the Usage line above): nit: Please update the usage to Llama 3.1

import os
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import jsonlines
from google.cloud import storage

# Load the tokenizer and model from Hugging Face


def upload_blob(bucket_name, source_file_name, destination_blob_name):
  """Uploads a file to the bucket."""
  storage_client = storage.Client()
  bucket = storage_client.get_bucket(bucket_name)
  blob = bucket.blob(destination_blob_name)

  blob.upload_from_filename(source_file_name)


def save_golden_logits(model_id, output_path):
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
  model = AutoModelForCausalLM.from_pretrained(
      model_id,
      torch_dtype=torch.float32,
  )

  # Your prompt text
  prompt_texts = ["I love to"]
  all_data_to_save = []

  for prompt_text in prompt_texts:
    # Encode the prompt text
    input_ids = tokenizer.encode(prompt_text, return_tensors="pt")

    # Get the logits for the prompt + completion
    with torch.no_grad():
      outputs = model(input_ids)
      logits = outputs.logits

    # Convert logits to fp32
    logits = logits.cpu().numpy().astype("float32")

    # Prepare data to be saved
    data_to_save = {
        "prompt": prompt_text,
        "tokens": input_ids.tolist()[0],
        "logits": logits.tolist()[0],  # Convert numpy array to list for JSON serialization
    }
    all_data_to_save.append(data_to_save)

  with jsonlines.open(output_path, "w") as f:
    f.write_all(all_data_to_save)

  upload_blob("maxtext-llama", output_path, "Llama3_1_8B/golden-logits/" + output_path)
  print("File {} uploaded to {}.".format(output_path, "Llama3_1_8B/golden-logits/" + output_path))
  os.remove(output_path)


def main(raw_args=None) -> None:
  parser = argparse.ArgumentParser()
  parser.add_argument("--model-id", type=str, required=False, default="meta-llama/Llama-3.1-8B")
  parser.add_argument("--output-path", type=str, required=True)
  args = parser.parse_args(raw_args)
  save_golden_logits(args.model_id, args.output_path)


if __name__ == "__main__":
  main()
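As a quick sanity check, the exported records can be read back with jsonlines; a minimal sketch (not part of the PR, and the filename is only an example):

import jsonlines
import numpy as np

with jsonlines.open("golden_data_new_llama3_1_8b.jsonl", "r") as reader:
  for record in reader:
    logits = np.asarray(record["logits"], dtype=np.float32)
    print(record["prompt"], len(record["tokens"]), logits.shape)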
16 changes: 10 additions & 6 deletions MaxText/tests/forward_pass_logit_checker.py
@@ -86,7 +86,10 @@ def main(config, test_args):
   model = models.Transformer(config, mesh=mesh, quant=quant)
   state, _ = max_utils.setup_decode_state(model, config, rng1, mesh, None)

-  input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl"
+  if test_args.golden_logits_path == "":
+    input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl"
+  else:
+    input_golden_data_path = test_args.golden_logits_path
   with jsonlines.open(input_golden_data_path, "r") as f:
     golden_data = list(f)

@@ -102,8 +105,8 @@ def main(config, test_args):
       rngs={"aqt": init_rng},
   )
   full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)
-  max_logging.log(f"{golden_logits[0]=}")
-  max_logging.log(f"{full_train_logits[0, 0, :]=}")
+  max_logging.log(f"{golden_logits[2]=}")
+  max_logging.log(f"{full_train_logits[0, 2, :]=}")
   token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
   max_logging.log(
       f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}"
@@ -112,8 +115,8 @@ def main(config, test_args):
   model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
   golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)

-  max_logging.log(f"{golden_probabilities[0]=}")
-  max_logging.log(f"{model_probabilities[0]=}")
+  max_logging.log(f"{golden_probabilities[1]=}")
+  max_logging.log(f"{model_probabilities[1]=}")

   kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
   max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")
@@ -143,11 +146,12 @@ def main(config, test_args):
   parser.add_argument("--rtol", type=float, required=False, default=0.1)
   parser.add_argument("--token_size", type=int, required=False)
   parser.add_argument("--max_kl_div", type=float, required=False, default=None)
+  parser.add_argument("--golden_logits_path", type=str, required=False, default="")
   test_args, _ = parser.parse_known_args()

   # Remove args defined in this test file to avoid error from pyconfig
   model_args = sys.argv
-  to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div"]
+  to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div", "--golden_logits_path"]
   for arg in to_remove_args:
     model_args = [s for s in model_args if not s.startswith(arg)]

70 changes: 70 additions & 0 deletions MaxText/weight_inspector.py
@@ -0,0 +1,70 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

r"""This is to inspect/analyze two weights with the same structure to find differences.
This assumes weights are dumped in a pickle file
Usage:
python MaxText/weight_inspector.py --lhs left_hand.pkl --rhs right_hand.pkl
"""

import argparse
import pickle
import numpy as np
import torch
import max_logging


def inspect_weights(left_path, right_path):
  """Load the pickle files and compare contents."""
  with open(left_path, "rb") as file:
    left_weights = pickle.load(file)

  with open(right_path, "rb") as file:
    right_weights = pickle.load(file)
  assert sorted(left_weights.keys()) == sorted(
      right_weights.keys()
  ), f"Weights structure does not match! {list(set(left_weights.keys()).symmetric_difference(right_weights.keys()))}"

  mismatched_keys = []
  # Iterate through keys common to both dictionaries
  for key in left_weights.keys() & right_weights.keys():  # Intersection of keys
    if ".0." in key:  # check only layer 0 of the model
      assert (
          left_weights[key].shape == right_weights[key].shape
      ), f"Mismatched shapes left {left_weights[key].shape}, right {right_weights[key].shape}"
      if not np.allclose(
          left_weights[key].type(torch.float16).numpy(), right_weights[key].type(torch.float16).numpy(), atol=1e-8
      ):
        mismatched_keys.append(key)

  if mismatched_keys:
    max_logging.log("Contents of mismatched keys")
    for key in mismatched_keys:
      max_logging.log(f"Key: {key}")
      max_logging.log(f"{left_weights[key]=}")
      max_logging.log(f"{right_weights[key]=}")


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--lhs", type=str, required=True)
  parser.add_argument("--rhs", type=str, required=True)

  args = parser.parse_args()

  inspect_weights(args.lhs, args.rhs)
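For reference, the .pkl inputs are expected to be pickled dicts mapping parameter names to torch tensors; one way to produce such a file from a Hugging Face checkpoint is sketched below (not part of the PR, and the checkpoint path is hypothetical):

import pickle

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/path/to/hf_ckpt", torch_dtype=torch.float32)  # hypothetical path
weights = {name: tensor.cpu() for name, tensor in model.state_dict().items()}

with open("left_hand.pkl", "wb") as f:
  pickle.dump(weights, f)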
36 changes: 36 additions & 0 deletions end_to_end/tpu/test_orbax_to_hf.sh
@@ -0,0 +1,36 @@
#!/bin/bash

# This script tests the flow of MaxText/llama_mistral_mixtral_orbax_to_hf.py.
# Steps in the script:
# 1. Convert the MaxText Orbax ckpt to HF using MaxText/llama_mistral_mixtral_orbax_to_hf.py
# 2. Confirm the logits match between the MaxText Orbax ckpt and the new HF ckpt created in step 1.
set -ex
export MODEL_VARIATION='llama3.1-8b'

export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/2024-12-18-17-35

export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
export RUN_NAME=unscann_llama3.1
# Path to the unscanned checkpoint created in 1_test_llama3.1_8b.sh
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items

# converting MaxText orbax ckpt to HF

JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs \
load_parameters_path=gs://runner-maxtext-logs/2024-12-18-17-35/llama3.1-8b/scanned_chkpt/0/items run_name=convert_to_hf \
model_name=llama3.1-8b hf_model_path=/home/mohitkhatwani/maxtext/hf_llama3.1_new/
Collaborator comment (on the hf_model_path above): nit: Please update the hf_model_path name and the dates in the paths


python MaxText/scratch_code/golden_llama3_1_export.py --model-id /home/mohitkhatwani/maxtext/hf_llama3.1_new/ --output-path golden_data_new_llama3_1_8b.jsonl

pushd MaxText/test_assets
gcloud storage cp gs://maxtext-llama/Llama3_1_8B/golden-logits/golden_data_new_llama3_1_8b.jsonl .
popd

# comparing logits of the HF ckpt above

python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} \
tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} \
run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=3 max_target_length=4 \
dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false \
scan_layers=false --golden_logits_path="MaxText/test_assets/golden_data_new_llama3_1_8b.jsonl"

Collaborator comment (on this comparison step): It is best to load the converted model in HF and ensure the logits match. Please see an example here.