get_next_logprobs revamp #130

Merged
merged 3 commits on Apr 27, 2024
28 changes: 0 additions & 28 deletions scripts/generate_logprobs.sh

This file was deleted.

147 changes: 147 additions & 0 deletions scripts/get_next_logprobs.py
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
import argparse
from collections.abc import Iterable

import numpy as np
import torch
from datasets import Dataset
from tqdm.auto import trange
from transformers import AutoModelForCausalLM

from delphi import utils
from delphi.eval.utils import get_all_and_next_logprobs

torch.set_grad_enabled(False)  # inference only: disable autograd globally


def main(
    in_model_repo_id: str,
    branches: Iterable[str],
    in_dataset_repo_id: str,
    split: str,
    feature: str,
    batch_size: int,
    out_repo_id: str,
):
    """
    For each requested model branch, computes the log probability of the next
    token for every token in the dataset and uploads the resulting dataset
    to HuggingFace.
    """
    in_dataset_split = utils.load_dataset_split_sequence_int32_feature(
        in_dataset_repo_id, split, feature
    )
    in_dataset_split.set_format("torch")
    for branch in branches:
        print(f"Loading model='{in_model_repo_id}', {branch=}")
        model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch)
        logprobs_dataset = get_logprobs_single_model(
            model=model,
            dataset=in_dataset_split,
            feature=feature,
            batch_size=batch_size,
        )
        logprobs_dataset.push_to_hub(
            repo_id=out_repo_id,
            split=utils.hf_split_to_split_name(split),
            revision=branch,
        )


def get_logprobs_single_model(
    model: AutoModelForCausalLM,
    dataset: Dataset,
    feature: str,
    batch_size: int,
) -> Dataset:
    n_seq = len(dataset)
    seq_len = len(dataset[0][feature])
    logprobs = np.empty((n_seq, seq_len))
    # the first position has no preceding token to predict it, so it gets NaN
    logprobs[:, 0] = float("nan")
    print("Running inference...")
    for i in trange(0, n_seq, batch_size):
        batch_tokens = dataset[i : i + batch_size][feature]
        logprobs[i : i + batch_size, 1:] = (
            get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()  # type: ignore
        )
    return Dataset.from_dict({"logprobs": [row for row in logprobs]})


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run inference and generate log probabilities."
    )
    parser.add_argument(
        "--in-model-repo-id",
        "--im",
        type=str,
        required=True,
        help="The input model repo id",
    )
    parser.add_argument(
        "--branches",
        help="comma-separated branches of the model to use, or 'ALL' to use all branches",
        type=str,
        default="main",
        required=False,
    )

    parser.add_argument(
        "--in-dataset-repo-id",
        "--id",
        type=str,
        required=True,
        help="The input tokenized dataset repo id",
    )
    parser.add_argument(
        "--feature",
        "-f",
        type=str,
        required=True,
        help="Name of the column containing token sequences in the input dataset",
    )
    parser.add_argument(
        "--split",
        "-s",
        type=str,
        required=True,
        help="Split of the tokenized dataset, supports slicing like 'train[:10%%]'",
    )
    parser.add_argument(
        "--out-repo-id",
        "-o",
        type=str,
        required=True,
        help="Where to upload the next logprobs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=80,
        help="How many sequences to evaluate at once",
    )
    # TODO
    # parser.add_argument(
    #     "--chunk-size",
    #     "-c",
    #     type=int,
    #     default=200_000,
    #     help="Size of the parquet chunks uploaded to HuggingFace",
    # )
    args = parser.parse_args()

    branches = (
        args.branches.split(",")
        if args.branches != "ALL"
        else utils.get_all_hf_branch_names(args.in_model_repo_id)
    )

    main(
        in_model_repo_id=args.in_model_repo_id,
        branches=branches,
        in_dataset_repo_id=args.in_dataset_repo_id,
        split=args.split,
        feature=args.feature,
        batch_size=args.batch_size,
        out_repo_id=args.out_repo_id,
    )
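
For reference, a hypothetical invocation of the new script (the repo ids below are placeholders, not part of this PR):

    python scripts/get_next_logprobs.py \
        --in-model-repo-id org/model \
        --in-dataset-repo-id org/tokenized-corpus \
        --feature tokens \
        --split "train[:10%]" \
        --out-repo-id org/next-logprobs

Passing --branches ALL instead runs inference once per branch of the model repo and pushes each result to the matching revision of the output repo.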
105 changes: 0 additions & 105 deletions scripts/inference.py

This file was deleted.

58 changes: 58 additions & 0 deletions src/delphi/utils.py
@@ -0,0 +1,58 @@
from typing import cast

from datasets import Dataset, Features, Sequence, Value, load_dataset


def hf_split_to_split_name(split: str) -> str:
    return split.split("[")[0]


# TODO: test load_dataset functions
def load_dataset_split_features(
    repo_id: str,
    split: str,
    features: Features,
) -> Dataset:
    dataset = load_dataset(
        repo_id,
        split=split,
        features=features,
    )
    dataset = cast(Dataset, dataset)
    return dataset


def load_dataset_split_string_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading string dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Value("string")}),
    )


def load_dataset_split_sequence_int32_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading sequence int32 dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Sequence(Value("int32"))}),
    )


def get_all_hf_branch_names(repo_id: str) -> list[str]:
    from huggingface_hub import HfApi

    api = HfApi()
    refs = api.list_repo_refs(repo_id)
    return [branch.name for branch in refs.branches]
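
For context, a minimal sketch of how these helpers compose (the repo ids and the "tokens" feature name are placeholders, not part of this PR):

    from delphi import utils

    # load the first 10% of a tokenized dataset; repo id and feature name are hypothetical
    ds = utils.load_dataset_split_sequence_int32_feature(
        "org/tokenized-corpus", "train[:10%]", "tokens"
    )
    # slicing syntax is stripped when choosing the output split name
    assert utils.hf_split_to_split_name("train[:10%]") == "train"
    # enumerate every branch of a model repo, as used by --branches ALL
    branches = utils.get_all_hf_branch_names("org/model")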
Empty file added tests/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,14 @@
from delphi.utils import hf_split_to_split_name

from .utils import random_string


def test_hf_split_to_split_name():
    random_split_name = random_string(5)
    assert hf_split_to_split_name(random_split_name) == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
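
The TODO in utils.py notes that the load_dataset helpers are still untested; one possible approach, mocking load_dataset to avoid network access, might look like this (a sketch, not part of this PR):

    from unittest.mock import patch

    from datasets import Dataset, Features, Sequence, Value

    import delphi.utils


    def test_load_dataset_split_sequence_int32_feature():
        # a tiny in-memory stand-in for the hub dataset
        fake = Dataset.from_dict(
            {"tokens": [[1, 2, 3]]},
            features=Features({"tokens": Sequence(Value("int32"))}),
        )
        # patch the load_dataset symbol imported inside delphi.utils
        with patch.object(delphi.utils, "load_dataset", return_value=fake):
            ds = delphi.utils.load_dataset_split_sequence_int32_feature(
                "org/repo", "train", "tokens"
            )
        assert ds[0]["tokens"] == [1, 2, 3]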
6 changes: 6 additions & 0 deletions tests/utils.py
@@ -0,0 +1,6 @@
import random
import string


def random_string(length: int) -> str:
    return "".join(random.choices(string.ascii_lowercase, k=length))