
Add utility functions for text processing and visualization #17

Merged
merged 18 commits on Feb 13, 2024
24 changes: 12 additions & 12 deletions notebooks/vis_demo.ipynb

Large diffs are not rendered by default.

42 changes: 12 additions & 30 deletions src/delphi/eval/compare_models.py
@@ -1,21 +1,14 @@
from dataclasses import dataclass
from typing import cast

import torch
import torch.nn as nn
from jaxtyping import Int
from transformers import PreTrainedModel

from delphi.eval.utils import get_correct_and_all_probs
from delphi.eval.utils import get_all_and_next_logprobs_single


@dataclass
class ModelId:
model_name: str


def identify_model(model: PreTrainedModel) -> ModelId:
return ModelId(model_name=model.config.name_or_path)
def identify_model(model: PreTrainedModel) -> str:
return model.config.name_or_path


@dataclass
@@ -27,28 +20,18 @@ class TokenPrediction:

@dataclass
class NextTokenStats:
base_model: ModelId
lift_model: ModelId
base_model: str
lift_model: str
next_prediction: TokenPrediction
topk: list[TokenPrediction]


def _pad_start(tensor: torch.Tensor) -> torch.Tensor:
value_to_prepend = -1
if len(tensor.shape) == 1:
return torch.cat((torch.tensor([value_to_prepend]), tensor))
else:
# input: 2D tensor of shape [seq_len - 1, top_k]
pre = torch.full((1, tensor.size()[-1]), value_to_prepend)
return torch.cat((pre, tensor), dim=0)


def compare_models(
model_a: PreTrainedModel,
model_b: PreTrainedModel,
sample_tok: Int[torch.Tensor, "seq"],
top_k: int = 3,
) -> list[NextTokenStats]:
) -> list[NextTokenStats | None]:
"""
Compare the probabilities of the next token for two models and get the top k token predictions according to model B.
Args:
@@ -67,19 +50,18 @@ def compare_models(
device = model_a.device
sample_tok = sample_tok.to(device)

next_probs_a, probs_a = get_correct_and_all_probs(model_a, sample_tok)
next_probs_b, probs_b = get_correct_and_all_probs(model_b, sample_tok)
probs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
probs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok)
Contributor: Is this probs or logprobs?

Collaborator (author): Good catch!


top_k_b = torch.topk(probs_b, top_k, dim=-1)
top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)

next_probs_a = _pad_start(next_probs_a)
next_probs_b = _pad_start(next_probs_b)
top_k_b_tokens = _pad_start(top_k_b.indices)
top_k_a_probs = _pad_start(top_k_a_probs)
top_k_b_probs = _pad_start(top_k_b.values)
top_k_b_tokens = top_k_b.indices
top_k_b_probs = top_k_b.values

comparisons = []
# ignore first token when evaluating predictions
comparisons.append(None)

for next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip(
next_probs_a, next_probs_b, top_k_b_tokens, top_k_a_probs, top_k_b_probs
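For context on the review exchange above, a minimal illustrative sketch (not code from this PR; the tensors and names are made up) of why the probs-vs-logprobs naming matters: the utils helpers return log-probabilities, and because exp is monotonic the top-k indices come out the same either way, but the stored values do not.

import torch

# Dummy [seq, vocab] logits; log_softmax gives log-probabilities per position.
logits = torch.randn(5, 100)
logprobs = torch.log_softmax(logits, dim=-1)
probs = torch.exp(logprobs)
assert torch.allclose(probs.sum(dim=-1), torch.ones(5))

# exp is monotonic, so topk selects the same token indices on logprobs and probs,
# but the .values differ; that mismatch is what the "probs or logprobs?" comment flags.
top_lp = torch.topk(logprobs, k=3, dim=-1)
top_p = torch.topk(probs, k=3, dim=-1)
assert torch.equal(top_lp.indices, top_p.indices)

# compare_models then reads model A's values at model B's top-k indices via
# torch.gather(probs_a, 1, top_k_b.indices).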
70 changes: 32 additions & 38 deletions src/delphi/eval/utils.py
@@ -1,15 +1,11 @@
from collections.abc import Callable
from typing import List, cast
from typing import cast

import torch
from datasets import Dataset, load_dataset
from jaxtyping import Float, Int
from transformers import PreTrainedModel, PreTrainedTokenizerBase

ALLOWED_CHARS = set(
" \n\"'(),.:?!0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
)


def get_all_logprobs(
model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
@@ -33,12 +29,39 @@ def gather_logprobs(
return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)


def get_next_logprobs(
model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
) -> Float[torch.Tensor, "batch shorter_seq"]:
def get_all_and_next_logprobs(
model: Callable,
input_ids: Int[torch.Tensor, "batch seq"],
) -> tuple[
Float[torch.Tensor, "batch shorter_seq vocab"],
Float[torch.Tensor, "batch shorter_seq"],
]:
logprobs = get_all_logprobs(model, input_ids[:, :-1])
next_tokens = input_ids[:, 1:]
return gather_logprobs(logprobs, next_tokens)
return logprobs, gather_logprobs(logprobs, next_tokens)


def get_all_and_next_logprobs_single(
model: Callable,
input_ids: Int[torch.Tensor, "seq"],
) -> tuple[
Float[torch.Tensor, "shorter_seq vocab"],
Float[torch.Tensor, "shorter_seq"],
]:
all_logprobs, next_logprobs = get_all_and_next_logprobs(
model, input_ids.unsqueeze(0)
)
return all_logprobs[0], next_logprobs[0]


def get_next_and_top_k_probs(
model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk,]:
all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
all_probs = torch.exp(all_logprobs)
next_probs = torch.exp(next_logprobs)
top_k = torch.topk(all_probs, k, dim=-1)
return next_probs, top_k
Comment on lines +57 to +64
Contributor: Are you using this function?

Collaborator (author): In the demo notebook.


def load_validation_dataset(dataset_name: str) -> Dataset:
@@ -56,16 +79,6 @@ def load_validation_dataset(dataset_name: str) -> Dataset:
return cast(Dataset, dataset)


def load_text_from_dataset(dataset: Dataset) -> list[str]:
text = []
for sample_txt in dataset["story"]:
# encoding issues and rare weird prompts
if not set(sample_txt).issubset(ALLOWED_CHARS):
continue
text.append(sample_txt)
return text


def tokenize(
tokenizer: PreTrainedTokenizerBase, sample_txt: str
) -> Int[torch.Tensor, "seq"]:
@@ -74,22 +87,3 @@ def tokenize(
Int[torch.Tensor, "seq"],
tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
)


def get_correct_and_all_probs(
model: PreTrainedModel, sample_tok: Int[torch.Tensor, "seq"]
) -> tuple[Float[torch.Tensor, "next_seq"], Float[torch.Tensor, "next_seq vocab"]]:
"""Get probabilities for the actual next token and for all predictions"""
# remove the first token (the bos token)
probs = get_single_logprobs(model, sample_tok)[1:]
correct_probs = probs[range(len(probs)), sample_tok[1:]]
return correct_probs, probs


def get_correct_and_top_probs(
model: PreTrainedModel, sample_tok: Int[torch.Tensor, "seq"], top_k: int = 3
) -> tuple[Float[torch.Tensor, "next_seq"], torch.return_types.topk]:
"""Get probabilities for the actual next token and for top k predictions"""
correct_probs, probs = get_correct_and_all_probs(model, sample_tok)
top_k_probs = torch.topk(probs, top_k, dim=-1)
return correct_probs, top_k_probs
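Taken together, a minimal usage sketch of the new utils helpers, roughly mirroring the demo notebook and the test below (illustrative only, not code from this PR; the checkpoint names are the TinyStories models used in the test):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import (
    get_next_and_top_k_probs,
    load_validation_dataset,
    tokenize,
)

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

story = load_validation_dataset("tinystories-v2-clean")["story"][0]
sample_tok = tokenize(tokenizer, story)  # 1D tensor of token ids, BOS-prefixed

with torch.no_grad():
    # next_probs[i]: probability the model gave to the actual token at position i + 1.
    # top_k: the k most likely tokens (indices and values) at each position.
    next_probs, top_k = get_next_and_top_k_probs(model, sample_tok, k=3)

print(next_probs.shape)     # torch.Size([seq_len - 1])
print(top_k.indices.shape)  # torch.Size([seq_len - 1, 3])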
1 change: 0 additions & 1 deletion src/delphi/eval/vis.py
@@ -76,7 +76,6 @@ def token_to_html(
_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])


# TODO: basic unit test for visualizations w/ Selenium
def vis_sample_prediction_probs(
sample_tok: Int[torch.Tensor, "pos"],
correct_probs: Float[torch.Tensor, "pos"],
23 changes: 23 additions & 0 deletions tests/eval/test_compare_models.py
@@ -0,0 +1,23 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import NextTokenStats, compare_models
from delphi.eval.utils import load_validation_dataset, tokenize


def test_compare_models():
with torch.set_grad_enabled(False):
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
model_instruct = AutoModelForCausalLM.from_pretrained(
"roneneldan/TinyStories-Instruct-1M"
)
ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
sample_tok = tokenize(tokenizer, ds_txt[0])
K = 3
model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
# ignore the first element comparison
assert model_comparison[0] is None
assert isinstance(model_comparison[1], NextTokenStats)
assert len(model_comparison) == sample_tok.shape[0]
assert len(model_comparison[1].topk) == K
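Note on the shape assertions above: compare_models prepends a single None for the first position (there is no prediction to score for the BOS token), so the returned list has exactly sample_tok.shape[0] entries, one None followed by seq_len - 1 NextTokenStats records, each carrying top_k = 3 predictions.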
55 changes: 0 additions & 55 deletions tests/test_vis.py
jaidhyani marked this conversation as resolved.

This file was deleted.