fix orbax to hf converter

khatwanimohit committed Dec 23, 2024
1 parent 6ec3368 commit 0ac022f

Showing 6 changed files with 208 additions and 86 deletions.
10 changes: 6 additions & 4 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -45,16 +45,15 @@
 import checkpointing
 from generate_param_only_checkpoint import _read_train_checkpoint
 import llama_or_mistral_ckpt
-from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM
+from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig


 def unpermute_from_match_maxtext_rope(arr):
   """
   Function to get the RoPE values in correct ordering
   """
-  split_size = arr.shape[-1] // 2  # Assuming half for evens, half for odds
-  evens = arr[..., :split_size]
-  odds = arr[..., split_size:]
+  evens = arr[..., ::2]
+  odds = arr[..., 1::2]
   return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
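For intuition, a minimal standalone sketch (a NumPy toy array, not MaxText code) of what the corrected unpermute does: the rotary features are stored interleaved at even and odd indices, and the fix gathers them back into two contiguous halves instead of assuming the halves were already contiguous.

```python
import numpy as np

# Toy head_dim of 8 with interleaved rotary features.
arr = np.arange(8)
evens, odds = arr[..., ::2], arr[..., 1::2]
# The even-indexed features come first, then the odd-indexed ones.
print(np.concatenate((evens, odds), axis=-1))  # [0 2 4 6 1 3 5 7]
```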


@@ -77,6 +76,9 @@ def load_hf_model(model_size):
     model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
   elif model_size == "mixtral-8x7b":
     model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
+  elif model_size == "llama3.1-8b":
+    config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
+    model = AutoModelForCausalLM.from_config(config)
   else:
     raise NotImplementedError
   return model
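One reading of the new llama3.1-8b branch (my inference, not stated in the commit): `from_config` builds the model architecture with freshly initialized weights rather than downloading the pretrained checkpoint, which suffices for a converter that overwrites every tensor with the Orbax values. A minimal sketch:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Fetches only the config from the hub; the returned model's weights are
# randomly initialized, so no multi-GB pretrained download is needed.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
model = AutoModelForCausalLM.from_config(config)
```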
76 changes: 0 additions & 76 deletions MaxText/scratch_code/golden_llama3-70b_export.py

This file was deleted.

86 changes: 86 additions & 0 deletions MaxText/scratch_code/golden_llama3_1_export.py
@@ -0,0 +1,86 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Usage: python3 golden_llama3_1_export.py --model-id meta-llama/Llama-3.1-8B --output-path golden_data_llama3_1_8b.jsonl
"""

import os
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import jsonlines
from google.cloud import storage



def upload_blob(bucket_name, source_file_name, destination_blob_name):
  """Uploads a file to the bucket."""
  storage_client = storage.Client()
  bucket = storage_client.get_bucket(bucket_name)
  blob = bucket.blob(destination_blob_name)

  blob.upload_from_filename(source_file_name)


def save_golden_logits(model_id, output_path):
  """Load the tokenizer and model from Hugging Face and save golden logits."""
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
  model = AutoModelForCausalLM.from_pretrained(
      model_id,
      torch_dtype=torch.float32,
  )

  # Your prompt text
  prompt_texts = ["I love to"]
  all_data_to_save = []

  for prompt_text in prompt_texts:
    # Encode the prompt text
    input_ids = tokenizer.encode(prompt_text, return_tensors="pt")

    # Get the logits for the prompt + completion
    with torch.no_grad():
      outputs = model(input_ids)
      logits = outputs.logits

      # Convert logits to fp32
      logits = logits.cpu().numpy().astype("float32")

      # Prepare data to be saved
      data_to_save = {
          "prompt": prompt_text,
          "tokens": input_ids.tolist()[0],
          "logits": logits.tolist()[0],  # Convert numpy array to list for JSON serialization
      }
      all_data_to_save.append(data_to_save)

  with jsonlines.open(output_path, "w") as f:
    f.write_all(all_data_to_save)

  upload_blob("maxtext-llama", output_path, "Llama3_1_8B/golden-logits/" + output_path)
  print("File {} uploaded to {}.".format(output_path, "Llama3_1_8B/golden-logits/" + output_path))
  os.remove(output_path)


def main(raw_args=None) -> None:
  parser = argparse.ArgumentParser()
  parser.add_argument("--model-id", type=str, required=False, default="meta-llama/Llama-3.1-8B")
  parser.add_argument("--output-path", type=str, required=True)
  args = parser.parse_args(raw_args)
  save_golden_logits(args.model_id, args.output_path)


if __name__ == "__main__":
  main()
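For reference, a quick sketch (hypothetical local path) of reading the saved golden records back, mirroring how forward_pass_logit_checker.py consumes them:

```python
import jsonlines
import numpy as np

# Each record holds the prompt, its token ids, and per-token logits.
with jsonlines.open("golden_data_llama3_1_8b.jsonl", "r") as f:
    golden_data = list(f)

record = golden_data[0]
golden_logits = np.asarray(record["logits"], dtype=np.float32)
print(record["prompt"], record["tokens"], golden_logits.shape)
```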
16 changes: 10 additions & 6 deletions MaxText/tests/forward_pass_logit_checker.py
@@ -86,7 +86,10 @@ def main(config, test_args):
   model = models.Transformer(config, mesh=mesh, quant=quant)
   state, _ = max_utils.setup_decode_state(model, config, rng1, mesh, None)

-  input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl"
+  if test_args.golden_logits_path == "":
+    input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl"
+  else:
+    input_golden_data_path = test_args.golden_logits_path
   with jsonlines.open(input_golden_data_path, "r") as f:
     golden_data = list(f)

@@ -102,8 +105,8 @@ def main(config, test_args):
       rngs={"aqt": init_rng},
   )
   full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)
-  max_logging.log(f"{golden_logits[0]=}")
-  max_logging.log(f"{full_train_logits[0, 0, :]=}")
+  max_logging.log(f"{golden_logits[2]=}")
+  max_logging.log(f"{full_train_logits[0, 2, :]=}")
   token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
   max_logging.log(
       f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}"
@@ ... @@
   model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
   golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)

-  max_logging.log(f"{golden_probabilities[0]=}")
-  max_logging.log(f"{model_probabilities[0]=}")
+  max_logging.log(f"{golden_probabilities[1]=}")
+  max_logging.log(f"{model_probabilities[1]=}")

   kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
   max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")
@@ -143,11 +146,12 @@ def main(config, test_args):
parser.add_argument("--rtol", type=float, required=False, default=0.1)
parser.add_argument("--token_size", type=int, required=False)
parser.add_argument("--max_kl_div", type=float, required=False, default=None)
parser.add_argument("--golden_logits_path", type=str, required=False, default="")
test_args, _ = parser.parse_known_args()

# Remove args defined in this test file to avoid error from pyconfig
model_args = sys.argv
to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div"]
to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div", "--golden_logits_path"]
for arg in to_remove_args:
model_args = [s for s in model_args if not s.startswith(arg)]

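A tiny sketch (hypothetical argv values) of the filtering pattern above, which keeps test-only flags away from pyconfig:

```python
# Hypothetical command line mixing test-only flags with pyconfig args.
argv = ["checker.py", "MaxText/configs/base.yml", "--golden_logits_path=x.jsonl", "model_name=llama3.1-8b"]
to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div", "--golden_logits_path"]
for arg in to_remove_args:
    argv = [s for s in argv if not s.startswith(arg)]
print(argv)  # ['checker.py', 'MaxText/configs/base.yml', 'model_name=llama3.1-8b']
```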
70 changes: 70 additions & 0 deletions MaxText/weight_inspector.py
@@ -0,0 +1,70 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

r"""This is to inspect/analyze two weights with the same structure to find differences.
This assumes weights are dumped in a pickle file
Usage:
python MaxText/weight_inspector.py --lhs left_hand.pkl --rhs right_hand.pkl
"""

import argparse
import pickle
import numpy as np
import torch
import max_logging


def inspect_weights(left_path, right_path):
  """Load the pickle files and compare contents."""
  with open(left_path, "rb") as file:
    left_weights = pickle.load(file)

  with open(right_path, "rb") as file:
    right_weights = pickle.load(file)
  assert sorted(left_weights.keys()) == sorted(
      right_weights.keys()
  ), f"Weights structure does not match! {list(set(left_weights.keys()).symmetric_difference(right_weights.keys()))}"

  mismatched_keys = []
  # Iterate through keys common to both dictionaries
  for key in left_weights.keys() & right_weights.keys():  # Intersection of keys
    if ".0." in key:  # check only layer 0 of the model
      assert (
          left_weights[key].shape == right_weights[key].shape
      ), f"Mismatched shapes left {left_weights[key].shape}, right {right_weights[key].shape}"
      if not np.allclose(
          left_weights[key].type(torch.float16).numpy(), right_weights[key].type(torch.float16).numpy(), atol=1e-8
      ):
        mismatched_keys.append(key)

  if mismatched_keys:
    max_logging.log("Contents of mismatched keys")
    for key in mismatched_keys:
      max_logging.log(f"Key: {key}")
      max_logging.log(f"{left_weights[key]=}")
      max_logging.log(f"{right_weights[key]=}")


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--lhs", type=str, required=True)
  parser.add_argument("--rhs", type=str, required=True)

  args = parser.parse_args()

  inspect_weights(args.lhs, args.rhs)
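To illustrate the per-key comparison the inspector performs, a minimal self-contained sketch (toy tensors and a hypothetical key name, not part of the commit):

```python
import numpy as np
import torch

# Two weight dicts keyed by name, as the inspector expects (torch tensors).
left = {"model.layers.0.q_proj": torch.ones(4, 4)}
right = {"model.layers.0.q_proj": torch.ones(4, 4) * 1.001}

for key in left.keys() & right.keys():
    close = np.allclose(
        left[key].type(torch.float16).numpy(),
        right[key].type(torch.float16).numpy(),
        atol=1e-8,
    )
    print(key, "matches" if close else "MISMATCH")  # prints MISMATCH here
```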
36 changes: 36 additions & 0 deletions end_to_end/tpu/test_orbax_to_hf.sh
@@ -0,0 +1,36 @@
#!/bin/bash

# This script tests the flow of MaxText/llama_mistral_mixtral_orbax_to_hf.py.
# Steps in the script:
# 1. Convert a MaxText Orbax ckpt to HF using MaxText/llama_mistral_mixtral_orbax_to_hf.py
# 2. Confirm the logits match between the MaxText Orbax ckpt and the new HF ckpt created in step 1.
set -ex
export MODEL_VARIATION='llama3.1-8b'

export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/2024-12-18-17-35

export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
export RUN_NAME=unscann_llama3.1
# Path to the unscanned checkpoint created in 1_test_llama3.1_8b.sh
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items

# Convert the MaxText Orbax ckpt to HF

JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs \
    load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf \
    model_name=${MODEL_VARIATION} hf_model_path=/home/mohitkhatwani/maxtext/hf_llama3.1_new/


python MaxText/scratch_code/golden_llama3_1_export.py --model-id /home/mohitkhatwani/maxtext/hf_llama3.1_new/ --output-path golden_data_new_llama3_1_8b.jsonl

pushd MaxText/test_assets
gcloud storage cp gs://maxtext-llama/Llama3_1_8B/golden-logits/golden_data_new_llama3_1_8b.jsonl .
popd

# Compare logits of the HF ckpt created above

python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} \
tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} \
run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=3 max_target_length=4 \
dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false \
scan_layers=false --golden_logits_path="MaxText/test_assets/golden_data_new_llama3_1_8b.jsonl"
