From 09e454022bfc1f6e70e283ce22860f26071e3815 Mon Sep 17 00:00:00 2001
From: Jett
Date: Fri, 26 Apr 2024 13:59:11 +0200
Subject: [PATCH 1/3] get_next_logprobs revamp

---
 scripts/generate_logprobs.sh |  28 ---------
 scripts/get_next_logprobs.py | 115 +++++++++++++++++++++++++++++++++++
 scripts/inference.py         | 105 --------------------------------
 src/delphi/utils.py          |  50 +++++++++++++++
 tests/__init__.py            |   0
 tests/test_utils.py          |  14 +++++
 tests/utils.py               |   6 ++
 7 files changed, 185 insertions(+), 133 deletions(-)
 delete mode 100644 scripts/generate_logprobs.sh
 create mode 100755 scripts/get_next_logprobs.py
 delete mode 100644 scripts/inference.py
 create mode 100644 src/delphi/utils.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/test_utils.py
 create mode 100644 tests/utils.py

diff --git a/scripts/generate_logprobs.sh b/scripts/generate_logprobs.sh
deleted file mode 100644
index fc1f836a..00000000
--- a/scripts/generate_logprobs.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/bin/bash
-
-# Define the batch size
-BATCH_SIZE=80  # This worked well in my CPU, but 200 was too much
-DATASET_NAME="delphi-suite/tinystories-v2-clean-tokenized"
-USERNAME="transcendingvictor"  # your Hugging Face username
-TOKEN="hf_aaaaaaaaaaaaaaaaaaaaaaaaaa"  # your Hugging Face API token
-
-
-# List of models
-declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k"
-                        "delphi-suite/delphi-llama2-200k"
-                        "delphi-suite/delphi-llama2-400k"
-                        "delphi-suite/delphi-llama2-800k"
-                        "delphi-suite/delphi-llama2-1.6m"
-                        "delphi-suite/delphi-llama2-3.2m"
-                        "delphi-suite/delphi-llama2-6.4m"
-                        "delphi-suite/delphi-llama2-12.8m"
-                        "delphi-suite/delphi-llama2-25.6m")
-
-# Loop through each model and generate log probabilities
-for MODEL_NAME in "${MODEL_NAMES[@]}"
-do
-    echo "Processing $MODEL_NAME"
-    python scripts/inference.py "$MODEL_NAME" --batch-size "$BATCH_SIZE" --dataset-name "$DATASET_NAME" --username "$USERNAME" --token "$TOKEN"
-done
-
-echo "All models processed."
diff --git a/scripts/get_next_logprobs.py b/scripts/get_next_logprobs.py
new file mode 100755
index 00000000..002264ed
--- /dev/null
+++ b/scripts/get_next_logprobs.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+import argparse
+
+import numpy as np
+import torch
+from datasets import Dataset
+from tqdm.auto import trange
+from transformers import AutoModelForCausalLM
+
+from delphi import utils
+from delphi.eval.utils import get_all_and_next_logprobs
+
+torch.set_grad_enabled(False)
+
+
+def main(
+    in_model_repo_id: str,
+    in_dataset_repo_id: str,
+    split: str,
+    feature: str,
+    batch_size: int,
+    out_repo_id: str,
+):
+    """
+    Outputs the log probabilities of the next token for each token in the dataset,
+    and uploads the resulting dataset to Hugging Face.
+    """
+    model = AutoModelForCausalLM.from_pretrained(in_model_repo_id)
+    in_dataset_split = utils.load_dataset_split_sequence_int32_feature(
+        in_dataset_repo_id, split, feature
+    )
+    in_dataset_split.set_format("torch")
+    n_seq = len(in_dataset_split)
+    seq_len = len(in_dataset_split[0][feature])
+    logprobs = np.empty((n_seq, seq_len))
+    logprobs[:, 0] = float("nan")
+    print("Running inference...")
+    for i in trange(0, n_seq, batch_size):
+        batch_tokens = in_dataset_split[i : i + batch_size][feature]
+        logprobs[i : i + batch_size, 1:] = (
+            get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()
+        )
+
+    hf_dataset = Dataset.from_dict({"logprobs": [row for row in logprobs]})
+
+    hf_dataset.push_to_hub(
+        repo_id=out_repo_id,
+        split=utils.hf_split_to_split_name(split),
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run inference and generate log probabilities."
+    )
+    parser.add_argument(
+        "--in-model-repo-id",
+        "--im",
+        type=str,
+        required=True,
+        help="HF repo id of the input model",
+    )
+    parser.add_argument(
+        "--in-dataset-repo-id",
+        "--id",
+        type=str,
+        required=True,
+        help="The tokenized dataset",
+    )
+    parser.add_argument(
+        "--feature",
+        "-f",
+        type=str,
+        required=True,
+        help="Name of the column containing token sequences in the input dataset",
+    )
+    parser.add_argument(
+        "--split",
+        "-s",
+        type=str,
+        required=True,
+        help="Split of the tokenized dataset, supports slicing like 'train[:10%%]'",
+    )
+    parser.add_argument(
+        "--out-repo-id",
+        "-o",
+        type=str,
+        required=True,
+        help="Where to upload the next logprobs",
+    )
+    parser.add_argument(
+        "--batch-size",
+        "-b",
+        type=int,
+        default=80,
+        help="How many sequences to evaluate at once",
+    )
+    # TODO
+    # parser.add_argument(
+    #     "--chunk-size",
+    #     "-c",
+    #     type=int,
+    #     default=200_000,
+    #     help="Size of the parquet chunks uploaded to HuggingFace",
+    # )
+    args = parser.parse_args()
+
+    main(
+        in_model_repo_id=args.in_model_repo_id,
+        in_dataset_repo_id=args.in_dataset_repo_id,
+        split=args.split,
+        feature=args.feature,
+        batch_size=args.batch_size,
+        out_repo_id=args.out_repo_id,
+    )
diff --git a/scripts/inference.py b/scripts/inference.py
deleted file mode 100644
index abe9b48b..00000000
--- a/scripts/inference.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import argparse
-import os
-
-import numpy as np
-import pandas as pd
-import torch
-from datasets import Dataset, load_dataset
-from jaxtyping import Int
-from tqdm.auto import tqdm
-from transformers import AutoModelForCausalLM
-
-from delphi.eval.utils import get_all_and_next_logprobs, load_validation_dataset
-
-torch.set_grad_enabled(False)
-
-
-def main(
-    model_name: str,
-    batch_size: Int,
-    dataset_name: str,
-    username: str,
-    funct_test: bool = False,
-):
-    """
-    Outputs the log probabilities of the next token for each token in the validation dataset.
-    And uploads the resulting dataset to huggingface.
-    Args:
-    - model_name: The name of the model to use for inference
-    - batch_size: The batch size for processing. 80 worked well in CPU.
-    - dataset_name: The name of the dataset from which validation set will be loaded
-    - username: Hugging Face API username
-    """
-    val_ds = load_validation_dataset(dataset_name)
-
-    model = AutoModelForCausalLM.from_pretrained(model_name)
-
-    total_sequences = (
-        len(val_ds) if not funct_test else 320
-    )  # Use only 320 sequences if funct_test is True
-
-    logprobs = np.empty((total_sequences, 513))
-    logprobs[:, 0] = float("nan")
-    for i in tqdm(range(0, total_sequences, batch_size)):
-        batch_end = min(i + batch_size, total_sequences)
-        batch_sequences = [val_ds[j]["tokens"] for j in range(i, batch_end)]
-        batch_sequences_tensor = torch.tensor(batch_sequences)
-
-        logprobs_tensor = get_all_and_next_logprobs(model, batch_sequences_tensor)[1]
-        logprobs[i:batch_end, 1:] = logprobs_tensor.cpu().numpy()
-
-    df_dataset = pd.DataFrame({"logprobs": [row for row in logprobs]})
-    hf_dataset = Dataset.from_pandas(df_dataset)
-
-    # change the repo_id to your hf username in generate_logprobs.sh
-    # change the yout hf token in generate_logprobs.sh
-
-    repo_id = f"{username}/{model_name.rsplit('/', 1)[-1]}-validation-logprobs"
-    if funct_test:
-        repo_id += "-funct-test"
-    hf_dataset.push_to_hub(
-        repo_id=repo_id,
-        split="validation",
-        private=False,
-    )
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Run inference and generate log probabilities."
-    )
-    parser.add_argument(
-        "model_name", type=str, help="Model name with or without delphi-suite/ prefix"
-    )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=80,
-        help="Batch size for processing (default: 80)",
-    )
-    parser.add_argument(
-        "--dataset-name",
-        type=str,
-        help="Dataset name with or without delphi-suite/ prefix",
-    )
-    parser.add_argument(
-        "--username",
-        type=str,
-        help="Hugging Face API username",
-    )
-    parser.add_argument(
-        "--test-funct", action="store_true", help="Enable test function mode"
-    )
-
-    args = parser.parse_args()
-
-    if "/" not in args.model_name:
-        args.model_name = "delphi-suite/" + args.model_name
-
-    main(
-        args.model_name,
-        args.batch_size,
-        args.dataset_name,
-        args.username,
-        args.test_funct,
-    )
diff --git a/src/delphi/utils.py b/src/delphi/utils.py
new file mode 100644
index 00000000..05721edc
--- /dev/null
+++ b/src/delphi/utils.py
@@ -0,0 +1,50 @@
+from typing import cast
+
+from datasets import Dataset, Features, Sequence, Value, load_dataset
+
+
+def hf_split_to_split_name(split: str) -> str:
+    return split.split("[")[0]
+
+
+# TODO: test load_dataset functions
+def load_dataset_split_features(
+    repo_id: str,
+    split: str,
+    features: Features,
+) -> Dataset:
+    dataset = load_dataset(
+        repo_id,
+        split=split,
+        features=features,
+    )
+    dataset = cast(Dataset, dataset)
+    return dataset
+
+
+def load_dataset_split_string_feature(
+    repo_id: str,
+    split: str,
+    feature_name: str,
+) -> Dataset:
+    print("Loading string dataset")
+    print(f"{repo_id=}, {split=}, {feature_name=}")
+    return load_dataset_split_features(
+        repo_id,
+        split,
+        Features({feature_name: Value("string")}),
+    )
+
+
+def load_dataset_split_sequence_int32_feature(
+    repo_id: str,
+    split: str,
+    feature_name: str,
+) -> Dataset:
+    print("Loading sequence int32 dataset")
+    print(f"{repo_id=}, {split=}, {feature_name=}")
+    return load_dataset_split_features(
+        repo_id,
+        split,
+        Features({feature_name: Sequence(Value("int32"))}),
+    )
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..597438ca
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,14 @@
+from delphi.utils import hf_split_to_split_name
+
+from .utils import random_string
+
+
+def test_hf_split_to_split_name():
+    random_split_name = random_string(5)
+    assert hf_split_to_split_name(random_split_name) == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..ed81b58a
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,6 @@
+import random
+import string
+
+
+def random_string(length: int) -> str:
+    return "".join(random.choices(string.ascii_lowercase, k=length))
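To make the new interface concrete: a sketch of a single invocation of the revamped script, roughly equivalent to one iteration of the loop in the deleted generate_logprobs.sh. The model and dataset repo ids come from that deleted script, and --feature tokens matches the tokens column the deleted inference.py read; the output repo id is illustrative, not taken from the patch:

    python scripts/get_next_logprobs.py \
        --in-model-repo-id delphi-suite/delphi-llama2-100k \
        --in-dataset-repo-id delphi-suite/tinystories-v2-clean-tokenized \
        --feature tokens \
        --split validation \
        --out-repo-id <your-hf-username>/delphi-llama2-100k-next-logprobs \
        --batch-size 80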
""" - model = AutoModelForCausalLM.from_pretrained(in_model_repo_id) in_dataset_split = utils.load_dataset_split_sequence_int32_feature( in_dataset_repo_id, split, feature ) in_dataset_split.set_format("torch") - n_seq = len(in_dataset_split) - seq_len = len(in_dataset_split[0][feature]) + for revision in revisions: + print(f"Loading model={in_model_repo_id}, {revision=}") + model = AutoModelForCausalLM.from_pretrained( + in_model_repo_id, revision=revision + ) + logprobs_dataset = get_logprobs_single_model( + model=model, + dataset=in_dataset_split, + feature=feature, + batch_size=batch_size, + ) + logprobs_dataset.push_to_hub( + repo_id=out_repo_id, + split=utils.hf_split_to_split_name(split), + revision=revision, + ) + + +def get_logprobs_single_model( + model: AutoModelForCausalLM, + dataset: Dataset, + feature: str, + batch_size: int, +) -> Dataset: + n_seq = len(dataset) + seq_len = len(dataset[0][feature]) logprobs = np.empty((n_seq, seq_len)) logprobs[:, 0] = float("nan") print("Running inference...") for i in trange(0, n_seq, batch_size): - batch_tokens = in_dataset_split[i : i + batch_size][feature] + batch_tokens = dataset[i : i + batch_size][feature] logprobs[i : i + batch_size, 1:] = ( - get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy() + get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy() # type: ignore ) - - hf_dataset = Dataset.from_dict({"logprobs": [row for row in logprobs]}) - - hf_dataset.push_to_hub( - repo_id=out_repo_id, - split=utils.hf_split_to_split_name(split), - ) + return Dataset.from_dict({"logprobs": [row for row in logprobs]}) if __name__ == "__main__": @@ -60,6 +79,15 @@ def main( required=True, help="The model", ) + parser.add_argument( + "--revisions", + "-r", + help="comma separated revisions of the model to use or 'ALL_BRANCHES' to use all branches", + type=str, + default="main", + required=False, + ) + parser.add_argument( "--in-dataset-repo-id", "--id", @@ -105,8 +133,15 @@ def main( # ) args = parser.parse_args() + revisions = ( + args.revisions.split(",") + if args.revisions != "ALL_BRANCHES" + else utils.get_all_hf_branch_names(args.in_model_repo_id) + ) + main( in_model_repo_id=args.in_model_repo_id, + revisions=revisions, in_dataset_repo_id=args.in_dataset_repo_id, split=args.split, feature=args.feature, diff --git a/src/delphi/utils.py b/src/delphi/utils.py index 05721edc..0ceb059a 100644 --- a/src/delphi/utils.py +++ b/src/delphi/utils.py @@ -48,3 +48,11 @@ def load_dataset_split_sequence_int32_feature( split, Features({feature_name: Sequence(Value("int32"))}), ) + + +def get_all_hf_branch_names(repo_id: str) -> list[str]: + from huggingface_hub import HfApi + + api = HfApi() + refs = api.list_repo_refs(repo_id) + return [branch.name for branch in refs.branches] From 79ce78cdb7b8965b0fbeec336f25c566f7c323e4 Mon Sep 17 00:00:00 2001 From: Jett Date: Fri, 26 Apr 2024 17:05:13 +0200 Subject: [PATCH 3/3] revisions -> branches, tested --- scripts/get_next_logprobs.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/scripts/get_next_logprobs.py b/scripts/get_next_logprobs.py index 1ba0365c..5cf0d26e 100755 --- a/scripts/get_next_logprobs.py +++ b/scripts/get_next_logprobs.py @@ -16,7 +16,7 @@ def main( in_model_repo_id: str, - revisions: Iterable[str], + branches: Iterable[str], in_dataset_repo_id: str, split: str, feature: str, @@ -31,11 +31,9 @@ def main( in_dataset_repo_id, split, feature ) in_dataset_split.set_format("torch") - for revision in revisions: - print(f"Loading 
From 79ce78cdb7b8965b0fbeec336f25c566f7c323e4 Mon Sep 17 00:00:00 2001
From: Jett
Date: Fri, 26 Apr 2024 17:05:13 +0200
Subject: [PATCH 3/3] revisions -> branches, tested

---
 scripts/get_next_logprobs.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/scripts/get_next_logprobs.py b/scripts/get_next_logprobs.py
index 1ba0365c..5cf0d26e 100755
--- a/scripts/get_next_logprobs.py
+++ b/scripts/get_next_logprobs.py
@@ -16,7 +16,7 @@
 
 def main(
     in_model_repo_id: str,
-    revisions: Iterable[str],
+    branches: Iterable[str],
     in_dataset_repo_id: str,
     split: str,
     feature: str,
@@ -31,11 +31,9 @@ def main(
         in_dataset_repo_id, split, feature
     )
     in_dataset_split.set_format("torch")
-    for revision in revisions:
-        print(f"Loading model={in_model_repo_id}, {revision=}")
-        model = AutoModelForCausalLM.from_pretrained(
-            in_model_repo_id, revision=revision
-        )
+    for branch in branches:
+        print(f"Loading model='{in_model_repo_id}', {branch=}")
+        model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch)
         logprobs_dataset = get_logprobs_single_model(
             model=model,
             dataset=in_dataset_split,
@@ -45,7 +43,7 @@ def main(
         logprobs_dataset.push_to_hub(
             repo_id=out_repo_id,
             split=utils.hf_split_to_split_name(split),
-            revision=revision,
+            revision=branch,
         )
@@ -80,9 +78,8 @@ def get_logprobs_single_model(
         help="HF repo id of the input model",
     )
     parser.add_argument(
-        "--revisions",
-        "-r",
-        help="comma separated revisions of the model to use or 'ALL_BRANCHES' to use all branches",
+        "--branches",
+        help="comma separated branches of the model to use or 'ALL' to use all branches",
         type=str,
         default="main",
         required=False,
@@ -133,15 +130,15 @@ def get_logprobs_single_model(
     # )
     args = parser.parse_args()
 
-    revisions = (
-        args.revisions.split(",")
-        if args.revisions != "ALL_BRANCHES"
+    branches = (
+        args.branches.split(",")
+        if args.branches != "ALL"
         else utils.get_all_hf_branch_names(args.in_model_repo_id)
     )
 
     main(
         in_model_repo_id=args.in_model_repo_id,
-        revisions=revisions,
+        branches=branches,
         in_dataset_repo_id=args.in_dataset_repo_id,
         split=args.split,
         feature=args.feature,
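After the rename, the same sweep is spelled with --branches, the catch-all sentinel is now ALL, and the short -r flag is gone. A final sketch of the tested interface; as before, the feature name and output repo id are illustrative:

    python scripts/get_next_logprobs.py \
        --in-model-repo-id delphi-suite/delphi-llama2-100k \
        --branches ALL \
        --in-dataset-repo-id delphi-suite/tinystories-v2-clean-tokenized \
        --feature tokens \
        --split validation \
        --out-repo-id <your-hf-username>/delphi-llama2-100k-next-logprobs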