diff --git a/scripts/generate_logprobs.sh b/scripts/generate_logprobs.sh
deleted file mode 100644
index fc1f836a..00000000
--- a/scripts/generate_logprobs.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/bin/bash
-
-# Define the batch size
-BATCH_SIZE=80 # This worked well in my CPU, but 200 was too much
-DATASET_NAME="delphi-suite/tinystories-v2-clean-tokenized"
-USERNAME="transcendingvictor" # your Hugging Face username
-TOKEN="hf_aaaaaaaaaaaaaaaaaaaaaaaaaa" # your Hugging Face API token
-
-
-# List of models
-declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k"
-                        "delphi-suite/delphi-llama2-200k"
-                        "delphi-suite/delphi-llama2-400k"
-                        "delphi-suite/delphi-llama2-800k"
-                        "delphi-suite/delphi-llama2-1.6m"
-                        "delphi-suite/delphi-llama2-3.2m"
-                        "delphi-suite/delphi-llama2-6.4m"
-                        "delphi-suite/delphi-llama2-12.8m"
-                        "delphi-suite/delphi-llama2-25.6m")
-
-# Loop through each model and generate log probabilities
-for MODEL_NAME in "${MODEL_NAMES[@]}"
-do
-  echo "Processing $MODEL_NAME"
-  python scripts/inference.py "$MODEL_NAME" --batch-size "$BATCH_SIZE" --dataset-name "$DATASET_NAME" --username "$USERNAME" --token "$TOKEN"
-done
-
-echo "All models processed."
diff --git a/scripts/get_next_logprobs.py b/scripts/get_next_logprobs.py
new file mode 100755
index 00000000..5cf0d26e
--- /dev/null
+++ b/scripts/get_next_logprobs.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+import argparse
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+from datasets import Dataset
+from tqdm.auto import trange
+from transformers import AutoModelForCausalLM
+
+from delphi import utils
+from delphi.eval.utils import get_all_and_next_logprobs
+
+torch.set_grad_enabled(False)
+
+
+def main(
+    in_model_repo_id: str,
+    branches: Iterable[str],
+    in_dataset_repo_id: str,
+    split: str,
+    feature: str,
+    batch_size: int,
+    out_repo_id: str,
+):
+    """
+    Outputs the log probabilities of the next token for each token in the dataset,
+    and uploads the resulting dataset to HuggingFace.
+    """
+    in_dataset_split = utils.load_dataset_split_sequence_int32_feature(
+        in_dataset_repo_id, split, feature
+    )
+    in_dataset_split.set_format("torch")
+    for branch in branches:
+        print(f"Loading model='{in_model_repo_id}', {branch=}")
+        model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch)
+        logprobs_dataset = get_logprobs_single_model(
+            model=model,
+            dataset=in_dataset_split,
+            feature=feature,
+            batch_size=batch_size,
+        )
+        logprobs_dataset.push_to_hub(
+            repo_id=out_repo_id,
+            split=utils.hf_split_to_split_name(split),
+            revision=branch,
+        )
+
+
+def get_logprobs_single_model(
+    model: AutoModelForCausalLM,
+    dataset: Dataset,
+    feature: str,
+    batch_size: int,
+) -> Dataset:
+    n_seq = len(dataset)
+    seq_len = len(dataset[0][feature])
+    logprobs = np.empty((n_seq, seq_len))
+    logprobs[:, 0] = float("nan")
+    print("Running inference...")
+    for i in trange(0, n_seq, batch_size):
+        batch_tokens = dataset[i : i + batch_size][feature]
+        logprobs[i : i + batch_size, 1:] = (
+            get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()  # type: ignore
+        )
+    return Dataset.from_dict({"logprobs": [row for row in logprobs]})
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run inference and generate log probabilities."
+    )
+    parser.add_argument(
+        "--in-model-repo-id",
+        "--im",
+        type=str,
+        required=True,
+        help="HuggingFace repo id of the model to run inference with",
+    )
+    parser.add_argument(
+        "--branches",
+        help="Comma-separated branches of the model to use, or 'ALL' to use all branches",
+        type=str,
+        default="main",
+        required=False,
+    )
+
+    parser.add_argument(
+        "--in-dataset-repo-id",
+        "--id",
+        type=str,
+        required=True,
+        help="HuggingFace repo id of the tokenized input dataset",
+    )
+    parser.add_argument(
+        "--feature",
+        "-f",
+        type=str,
+        required=True,
+        help="Name of the column containing token sequences in the input dataset",
+    )
+    parser.add_argument(
+        "--split",
+        "-s",
+        type=str,
+        required=True,
+        help="Split of the tokenized dataset, supports slicing like 'train[:10%%]'",
+    )
+    parser.add_argument(
+        "--out-repo-id",
+        "-o",
+        type=str,
+        required=True,
+        help="Where to upload the next logprobs",
+    )
+    parser.add_argument(
+        "--batch-size",
+        "-b",
+        type=int,
+        default=80,
+        help="How many sequences to evaluate at once",
+    )
+    # TODO
+    # parser.add_argument(
+    #     "--chunk-size",
+    #     "-c",
+    #     type=int,
+    #     default=200_000,
+    #     help="Size of the parquet chunks uploaded to HuggingFace",
+    # )
+    args = parser.parse_args()
+
+    branches = (
+        args.branches.split(",")
+        if args.branches != "ALL"
+        else utils.get_all_hf_branch_names(args.in_model_repo_id)
+    )
+
+    main(
+        in_model_repo_id=args.in_model_repo_id,
+        branches=branches,
+        in_dataset_repo_id=args.in_dataset_repo_id,
+        split=args.split,
+        feature=args.feature,
+        batch_size=args.batch_size,
+        out_repo_id=args.out_repo_id,
+    )
diff --git a/scripts/inference.py b/scripts/inference.py
deleted file mode 100644
index abe9b48b..00000000
--- a/scripts/inference.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import argparse
-import os
-
-import numpy as np
-import pandas as pd
-import torch
-from datasets import Dataset, load_dataset
-from jaxtyping import Int
-from tqdm.auto import tqdm
-from transformers import AutoModelForCausalLM
-
-from delphi.eval.utils import get_all_and_next_logprobs, load_validation_dataset
-
-torch.set_grad_enabled(False)
-
-
-def main(
-    model_name: str,
-    batch_size: Int,
-    dataset_name: str,
-    username: str,
-    funct_test: bool = False,
-):
-    """
-    Outputs the log probabilities of the next token for each token in the validation dataset.
-    And uploads the resulting dataset to huggingface.
-    Args:
-    - model_name: The name of the model to use for inference
-    - batch_size: The batch size for processing. 80 worked well in CPU.
-    - dataset_name: The name of the dataset from which validation set will be loaded
-    - username: Hugging Face API username
-    """
-    val_ds = load_validation_dataset(dataset_name)
-
-    model = AutoModelForCausalLM.from_pretrained(model_name)
-
-    total_sequences = (
-        len(val_ds) if not funct_test else 320
-    )  # Use only 320 sequences if funct_test is True
-
-    logprobs = np.empty((total_sequences, 513))
-    logprobs[:, 0] = float("nan")
-    for i in tqdm(range(0, total_sequences, batch_size)):
-        batch_end = min(i + batch_size, total_sequences)
-        batch_sequences = [val_ds[j]["tokens"] for j in range(i, batch_end)]
-        batch_sequences_tensor = torch.tensor(batch_sequences)
-
-        logprobs_tensor = get_all_and_next_logprobs(model, batch_sequences_tensor)[1]
-        logprobs[i:batch_end, 1:] = logprobs_tensor.cpu().numpy()
-
-    df_dataset = pd.DataFrame({"logprobs": [row for row in logprobs]})
-    hf_dataset = Dataset.from_pandas(df_dataset)
-
-    # change the repo_id to your hf username in generate_logprobs.sh
-    # change the yout hf token in generate_logprobs.sh
-
-    repo_id = f"{username}/{model_name.rsplit('/', 1)[-1]}-validation-logprobs"
-    if funct_test:
-        repo_id += "-funct-test"
-    hf_dataset.push_to_hub(
-        repo_id=repo_id,
-        split="validation",
-        private=False,
-    )
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Run inference and generate log probabilities."
-    )
-    parser.add_argument(
-        "model_name", type=str, help="Model name with or without delphi-suite/ prefix"
-    )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=80,
-        help="Batch size for processing (default: 80)",
-    )
-    parser.add_argument(
-        "--dataset-name",
-        type=str,
-        help="Dataset name with or without delphi-suite/ prefix",
-    )
-    parser.add_argument(
-        "--username",
-        type=str,
-        help="Hugging Face API username",
-    )
-    parser.add_argument(
-        "--test-funct", action="store_true", help="Enable test function mode"
-    )
-
-    args = parser.parse_args()
-
-    if "/" not in args.model_name:
-        args.model_name = "delphi-suite/" + args.model_name
-
-    main(
-        args.model_name,
-        args.batch_size,
-        args.dataset_name,
-        args.username,
-        args.test_funct,
-    )
diff --git a/src/delphi/utils.py b/src/delphi/utils.py
new file mode 100644
index 00000000..0ceb059a
--- /dev/null
+++ b/src/delphi/utils.py
@@ -0,0 +1,58 @@
+from typing import cast
+
+from datasets import Dataset, Features, Sequence, Value, load_dataset
+
+
+def hf_split_to_split_name(split: str) -> str:
+    return split.split("[")[0]
+
+
+# TODO: test load_dataset functions
+def load_dataset_split_features(
+    repo_id: str,
+    split: str,
+    features: Features,
+) -> Dataset:
+    dataset = load_dataset(
+        repo_id,
+        split=split,
+        features=features,
+    )
+    dataset = cast(Dataset, dataset)
+    return dataset
+
+
+def load_dataset_split_string_feature(
+    repo_id: str,
+    split: str,
+    feature_name: str,
+) -> Dataset:
+    print("Loading string dataset")
+    print(f"{repo_id=}, {split=}, {feature_name=}")
+    return load_dataset_split_features(
+        repo_id,
+        split,
+        Features({feature_name: Value("string")}),
+    )
+
+
+def load_dataset_split_sequence_int32_feature(
+    repo_id: str,
+    split: str,
+    feature_name: str,
+) -> Dataset:
+    print("Loading sequence int32 dataset")
+    print(f"{repo_id=}, {split=}, {feature_name=}")
+    return load_dataset_split_features(
+        repo_id,
+        split,
+        Features({feature_name: Sequence(Value("int32"))}),
+    )
+
+
+def get_all_hf_branch_names(repo_id: str) -> list[str]:
+    from huggingface_hub import HfApi
+
+    api = HfApi()
+    refs = api.list_repo_refs(repo_id)
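+    # refs.branches contains branch refs only; tags are returned separately by the API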
+    return [branch.name for branch in refs.branches]
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..597438ca
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,14 @@
+from delphi.utils import hf_split_to_split_name
+
+from .utils import random_string
+
+
+def test_hf_split_to_split_name():
+    random_split_name = random_string(5)
+    assert hf_split_to_split_name(random_split_name) == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
+    assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..ed81b58a
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,6 @@
+import random
+import string
+
+
+def random_string(length: int) -> str:
+    return "".join(random.choices(string.ascii_lowercase, k=length))
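Usage note: below is a hypothetical invocation of the new script for a single model, mirroring one iteration of the loop in the deleted generate_logprobs.sh. The --out-repo-id value is illustrative (patterned on the old {username}/<model>-validation-logprobs convention); --feature tokens matches the column the deleted inference.py read, and --split validation matches the split it loaded.

    python scripts/get_next_logprobs.py \
        --in-model-repo-id delphi-suite/delphi-llama2-100k \
        --branches main \
        --in-dataset-repo-id delphi-suite/tinystories-v2-clean-tokenized \
        --split validation \
        --feature tokens \
        --batch-size 80 \
        --out-repo-id transcendingvictor/delphi-llama2-100k-validation-logprobs

Passing --branches ALL would instead iterate over every branch of the model repo via get_all_hf_branch_names, pushing the logprobs for each checkpoint as a separate revision of the output dataset.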