add debugging functionality to hf script
goliaro committed Jan 15, 2024
1 parent a348152 commit 9b97772
Showing 1 changed file with 51 additions and 1 deletion.
tests/inference/huggingface_inference.py (51 additions, 1 deletion)
@@ -1,6 +1,7 @@
import argparse
import json
import os
import shutil
import torch
from transformers import (
    AutoModelForCausalLM,
@@ -9,7 +10,30 @@
    LlamaTokenizer,
    GenerationConfig,
)

######################### debugging helper functions #########################
def pre_forward_hook(module, input):
    assert module.name is not None and module.decoding_step is not None
    name = module.name.replace("model.", "")
    print(
        f"Pre-forward hook activated on module: {name}, decoding step: {module.decoding_step}"
    )
    print("Pre-Input: ", input[0].shape)
    torch.save(
        input, f"./hf_tensors/decoding_step_{module.decoding_step}_{name}.input"
    )

def post_forward_hook(module, input, output):
    assert module.name is not None and module.decoding_step is not None
    name = module.name.replace("model.", "")
    print(
        f"Post-forward hook activated on module: {name}, decoding step: {module.decoding_step}"
    )
    print("Post-Input/Output: ", input[0].shape, output[0].shape)
    torch.save(
        output, f"./hf_tensors/decoding_step_{module.decoding_step}_{name}.output"
    )
    print("===")
    module.decoding_step += 1
##############################################################################

def main():
    # Change working dir to folder storing this script
@@ -28,6 +52,11 @@ def main():
    )
    parser.add_argument("--do-sample", action="store_true", help="Use sampling")
    parser.add_argument("--gpu", action="store_true", help="Run on GPU")
    parser.add_argument(
        "--inference-debugging",
        action="store_true",
        help="Print debugging info and save hidden states/weights to file",
    )
    args = parser.parse_args()
    # Check if max-length is greater than 0
    if args.max_length <= 0:
@@ -64,6 +93,27 @@ def main():
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    generation_config = GenerationConfig.from_pretrained(args.model_name)
    generation_config.do_sample = args.do_sample
    ################# debugging #################
    if args.inference_debugging:
        # Print model and configs
        print(hf_config)
        print(model)
        # Remove stale tensors from previous runs; ignore_errors avoids a crash on the first run
        shutil.rmtree("./hf_tensors", ignore_errors=True)
        # Re-create the output folder
        os.makedirs("./hf_tensors", exist_ok=True)
        # Save weights
        for name, params in model.named_parameters():
            torch.save(params, f"./hf_tensors/{name}")
            # params.detach().cpu().numpy().tofile(f"./hf_tensors/{name}")
        # Register hooks to save per-op hidden states
        for name, layer in dict(model.named_modules()).items():
            layer.name = name
            layer.decoding_step = 0
            print(f"Adding hooks to layer {layer.name}")
            layer.register_forward_pre_hook(pre_forward_hook)
            layer.register_forward_hook(post_forward_hook)
    ###############################################
    # Generate output
    with open(args.output_file, "w") as f:
        for i, prompt in enumerate(prompt_list):

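The files written to ./hf_tensors/ can be reloaded with torch.load for offline, per-module comparison. The snippet below is a minimal sketch, not part of the commit: the module name layers.0.self_attn and decoding step 0 are hypothetical examples; the real file names depend on the model passed via --model-name, and the script must have been run with --inference-debugging first.

import torch

# Hypothetical module name and decoding step; actual names depend on the model.
# The hooks strip the leading "model." prefix before building the file name.
step, name = 0, "layers.0.self_attn"

# Inputs were saved as the argument tuple passed to forward(); outputs as the
# value forward() returned (a tensor or a tuple, depending on the module).
inp = torch.load(f"./hf_tensors/decoding_step_{step}_{name}.input")
out = torch.load(f"./hf_tensors/decoding_step_{step}_{name}.output")
print(inp[0].shape, out[0].shape)

# Weights were saved under their full parameter names, e.g. (hypothetical name):
w = torch.load("./hf_tensors/model.layers.0.self_attn.q_proj.weight")
print(w.shape)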