From bd9a22797b95074f9856c6531844341b3d41db98 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Sun, 22 Dec 2024 02:19:25 +0000 Subject: [PATCH] fix orbax to hf converter --- MaxText/llama_mistral_mixtral_orbax_to_hf.py | 10 ++- .../scratch_code/golden_llama3-70b_export.py | 76 ---------------- .../scratch_code/golden_llama3_1_export.py | 86 +++++++++++++++++++ MaxText/tests/forward_pass_logit_checker.py | 16 ++-- MaxText/weight_inspector.py | 70 +++++++++++++++ end_to_end/tpu/test_orbax_to_hf.sh | 36 ++++++++ 6 files changed, 208 insertions(+), 86 deletions(-) delete mode 100644 MaxText/scratch_code/golden_llama3-70b_export.py create mode 100644 MaxText/scratch_code/golden_llama3_1_export.py create mode 100644 MaxText/weight_inspector.py create mode 100644 end_to_end/tpu/test_orbax_to_hf.sh diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py index ec126bd4a..95227f51e 100644 --- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py +++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py @@ -45,16 +45,15 @@ import checkpointing from generate_param_only_checkpoint import _read_train_checkpoint import llama_or_mistral_ckpt -from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM +from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig def unpermute_from_match_maxtext_rope(arr): """ Function to get the RoPE values in correct ordering """ - split_size = arr.shape[-1] // 2 # Assuming half for evens, half for odds - evens = arr[..., :split_size] - odds = arr[..., split_size:] + evens = arr[..., ::2] + odds = arr[..., 1::2] return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) @@ -77,6 +76,9 @@ def load_hf_model(model_size): model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") elif model_size == "mixtral-8x7b": model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto") + elif model_size == 
"llama3.1-8b": + config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B") + model = AutoModelForCausalLM.from_config(config) else: raise NotImplementedError return model diff --git a/MaxText/scratch_code/golden_llama3-70b_export.py b/MaxText/scratch_code/golden_llama3-70b_export.py deleted file mode 100644 index 18db403ee..000000000 --- a/MaxText/scratch_code/golden_llama3-70b_export.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Copyright 2024 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -Usage: python3 golden_llama3-70b_export.py --model-id meta-llama/Meta-Llama-3-70B --output-path llama3-70b/golden_logits/golden_data_llama3-70b.jsonl -""" - -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM -import jsonlines -from google.cloud import storage - -# Load the tokenizer and model from Hugging Face - -model_id = "meta-llama/Meta-Llama-3-70B" - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype=torch.float32, -) - - -# Your prompt text -prompt_texts = ["I love to"] -all_data_to_save = [] - -output_path = "golden_data_llama3-70b.jsonl" - - -for prompt_text in prompt_texts: - # Encode the prompt text - input_ids = tokenizer.encode(prompt_text, return_tensors="pt") - - # Get the logits for the prompt + completion - with torch.no_grad(): - outputs = model(input_ids) - logits = outputs.logits - - # Convert logits to fp32 - logits = logits.cpu().numpy().astype("float32") - - # Prepare 
data to be saved - data_to_save = { - "prompt": prompt_text, - "tokens": input_ids.tolist()[0], - "logits": logits.tolist()[0], # Convert numpy array to list for JSON serialization - } - all_data_to_save.append(data_to_save) - -with jsonlines.open(output_path, "w") as f: - f.write_all(all_data_to_save) - - -def upload_blob(bucket_name, source_file_name, destination_blob_name): - """Uploads a file to the bucket.""" - storage_client = storage.Client() - bucket = storage_client.get_bucket(bucket_name) - blob = bucket.blob(destination_blob_name) - - blob.upload_from_filename(source_file_name) - - -upload_blob("maxtext-llama", output_path, "llama3-70b/golden-logits/" + output_path) -print("File {} uploaded to {}.".format(output_path, "llama3-70b/golden-logits/" + output_path)) diff --git a/MaxText/scratch_code/golden_llama3_1_export.py b/MaxText/scratch_code/golden_llama3_1_export.py new file mode 100644 index 000000000..79559bf9c --- /dev/null +++ b/MaxText/scratch_code/golden_llama3_1_export.py @@ -0,0 +1,86 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Usage: python3 MaxText/scratch_code/golden_llama3_1_export.py --model-id meta-llama/Llama-3.1-8B --output-path golden_data_llama3_1_8b.jsonl +""" + +import os +import torch +import argparse +from transformers import AutoTokenizer, AutoModelForCausalLM +import jsonlines +from google.cloud import storage + +# Load the tokenizer and model from Hugging Face + + +def upload_blob(bucket_name, source_file_name, destination_blob_name): + """Uploads a file to the bucket.""" + storage_client = storage.Client() + bucket = storage_client.get_bucket(bucket_name) + blob = bucket.blob(destination_blob_name) + + blob.upload_from_filename(source_file_name) + + +def save_golden_logits(model_id, output_path): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B") + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float32, + ) + + # Your prompt text + prompt_texts = ["I love to"] + all_data_to_save = [] + + for prompt_text in prompt_texts: + # Encode the prompt text + input_ids = tokenizer.encode(prompt_text, return_tensors="pt") + + # Get the logits for the prompt + completion + with torch.no_grad(): + outputs = model(input_ids) + logits = outputs.logits + + # Convert logits to fp32 + logits = logits.cpu().numpy().astype("float32") + + # Prepare data to be saved + data_to_save = { + "prompt": prompt_text, + "tokens": input_ids.tolist()[0], + "logits": logits.tolist()[0], # Convert numpy array to list for JSON serialization + } + all_data_to_save.append(data_to_save) + + with jsonlines.open(output_path, "w") as f: + f.write_all(all_data_to_save) + + upload_blob("maxtext-llama", output_path, "Llama3_1_8B/golden-logits/" + output_path) + print("File {} uploaded to {}.".format(output_path, "Llama3_1_8B/golden-logits/" + output_path)) + os.remove(output_path) + + +def main(raw_args=None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-id", type=str, required=False, default="meta-llama/Llama-3.1-8B") + 
parser.add_argument("--output-path", type=str, required=True) + args = parser.parse_args(raw_args) + save_golden_logits(args.model_id, args.output_path) + + +if __name__ == "__main__": + main() diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py index 479f962c4..fc3870c26 100644 --- a/MaxText/tests/forward_pass_logit_checker.py +++ b/MaxText/tests/forward_pass_logit_checker.py @@ -86,7 +86,10 @@ def main(config, test_args): model = models.Transformer(config, mesh=mesh, quant=quant) state, _ = max_utils.setup_decode_state(model, config, rng1, mesh, None) - input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl" + if test_args.golden_logits_path == "": + input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl" + else: + input_golden_data_path = test_args.golden_logits_path with jsonlines.open(input_golden_data_path, "r") as f: golden_data = list(f) @@ -102,8 +105,8 @@ def main(config, test_args): rngs={"aqt": init_rng}, ) full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits) - max_logging.log(f"{golden_logits[0]=}") - max_logging.log(f"{full_train_logits[0, 0, :]=}") + max_logging.log(f"{golden_logits[2]=}") + max_logging.log(f"{full_train_logits[0, 2, :]=}") token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0] max_logging.log( f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}" @@ -112,8 +115,8 @@ def main(config, test_args): model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1) golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1) - max_logging.log(f"{golden_probabilities[0]=}") - max_logging.log(f"{model_probabilities[0]=}") + max_logging.log(f"{golden_probabilities[1]=}") + max_logging.log(f"{model_probabilities[1]=}") kl_div = 
jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1) max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}") @@ -143,11 +146,12 @@ def main(config, test_args): parser.add_argument("--rtol", type=float, required=False, default=0.1) parser.add_argument("--token_size", type=int, required=False) parser.add_argument("--max_kl_div", type=float, required=False, default=None) + parser.add_argument("--golden_logits_path", type=str, required=False, default="") test_args, _ = parser.parse_known_args() # Remove args defined in this test file to avoid error from pyconfig model_args = sys.argv - to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div"] + to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div", "--golden_logits_path"] for arg in to_remove_args: model_args = [s for s in model_args if not s.startswith(arg)] diff --git a/MaxText/weight_inspector.py b/MaxText/weight_inspector.py new file mode 100644 index 000000000..6c6e817f9 --- /dev/null +++ b/MaxText/weight_inspector.py @@ -0,0 +1,70 @@ +""" + Copyright 2024 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +r"""This is to inspect/analyze two weights with the same structure to find differences. 
+This assumes weights are dumped in a pickle file + +Usage: + +python MaxText/weight_inspector.py --lhs left_hand.pkl --rhs right_hand.pkl + +""" + +import argparse +import pickle +import numpy as np +import torch +import max_logging + + +def inspect_weights(left_path, right_path): + """ Load the pickle files and compare contents.""" + with open(left_path, "rb") as file: + left_weights = pickle.load(file) + + with open(right_path, "rb") as file: + right_weights = pickle.load(file) + assert sorted(left_weights.keys()) == sorted( + right_weights.keys() + ), f"Weights structure does not match! {list(set(left_weights.keys()).symmetric_difference(right_weights.keys()))}" + + mismatched_keys = [] + # Iterate through keys common to both dictionaries + for key in left_weights.keys() & right_weights.keys(): # Intersection of keys + if ".0." in key: # check only layer 0 of the model + assert ( + left_weights[key].shape == right_weights[key].shape + ), f"Mismatched shapes left {left_weights[key].shape}, right {right_weights[key].shape}" + if not np.allclose( + left_weights[key].type(torch.float16).numpy(), right_weights[key].type(torch.float16).numpy(), atol=1e-8 + ): + mismatched_keys.append(key) + + if mismatched_keys: + max_logging.log("Contents of mismatched keys") + for key in mismatched_keys: + max_logging.log(f"Key: {key}") + max_logging.log(f"{left_weights[key]=}") + max_logging.log(f"{right_weights[key]=}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--lhs", type=str, required=True) + parser.add_argument("--rhs", type=str, required=True) + + args = parser.parse_args() + + inspect_weights(args.lhs, args.rhs) + + args = parser.parse_args() diff --git a/end_to_end/tpu/test_orbax_to_hf.sh b/end_to_end/tpu/test_orbax_to_hf.sh new file mode 100644 index 000000000..9f462172a --- /dev/null +++ b/end_to_end/tpu/test_orbax_to_hf.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# This script is to test the flow of 
MaxText/llama_mistral_mixtral_orbax_to_hf.py. +# Steps in the script: +# 1. Convert MaxText orbax ckpt to HF using MaxText/llama_mistral_mixtral_orbax_to_hf.py +# 2. Confirm the logits match for MaxText orbax ckpt and the new HF ckpt created in step 1. +set -ex +export MODEL_VARIATION='llama3.1-8b' + +export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/2024-12-18-17-35 + +export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items +export RUN_NAME=unscann_llama3.1 +# We define the path to the unscanned checkpoint created in 1_test_llama3.1_8b.sh +export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items + +# converting MaxText orbax ckpt to HF + +JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs \ + load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf \ + model_name=${MODEL_VARIATION} hf_model_path=/home/mohitkhatwani/maxtext/hf_llama3.1_new/ + + +python MaxText/scratch_code/golden_llama3_1_export.py --model-id /home/mohitkhatwani/maxtext/hf_llama3.1_new/ --output-path golden_data_new_llama3_1_8b.jsonl + +pushd MaxText/test_assets +gcloud storage cp gs://maxtext-llama/Llama3_1_8B/golden-logits/golden_data_new_llama3_1_8b.jsonl . +popd + +# comparing logits of the HF ckpt above + +python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} \ + tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} \ + run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=3 max_target_length=4 \ + dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false \ + scan_layers=false --golden_logits_path="MaxText/test_assets/golden_data_new_llama3_1_8b.jsonl"