Add utility functions for text processing and visualization #17

Merged: 18 commits, Feb 13, 2024
6 changes: 3 additions & 3 deletions .github/workflows/checks.yml
@@ -6,7 +6,7 @@ on:
       - main
   pull_request:
     branches:
-      - '*'
+      - "*"
 
 permissions:
   actions: write
@@ -29,6 +29,6 @@ jobs:
       - name: black
         run: black --check .
       - name: isort
-        run: isort --check .
+        run: isort --profile black --check .
      - name: pytest
-        run: pytest
+        run: pytest
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -8,4 +8,5 @@ repos:
     rev: 5.13.2
     hooks:
       - id: isort
-        name: isort (python)
+        name: isort (python)
+        args: ["--profile", "black"]
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -7,4 +7,7 @@
     "source.organizeImports": "explicit"
   },
   "python.analysis.typeCheckingMode": "basic",
+  "isort.args": [
+    "--profile black"
+  ],
 }
148 changes: 148 additions & 0 deletions notebooks/vis_demo.ipynb

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions src/delphi/eval/compare_models.py
@@ -0,0 +1,88 @@
from dataclasses import dataclass

import torch
from jaxtyping import Int
from transformers import PreTrainedModel

from delphi.eval.utils import get_all_and_next_logprobs_single


def identify_model(model: PreTrainedModel) -> str:
    return model.config.name_or_path


@dataclass
class TokenPrediction:
    token: int
    base_model_prob: float
    lift_model_prob: float


@dataclass
class NextTokenStats:
    base_model: str
    lift_model: str
    next_prediction: TokenPrediction
    topk: list[TokenPrediction]


def compare_models(
    model_a: PreTrainedModel,
    model_b: PreTrainedModel,
    sample_tok: Int[torch.Tensor, "seq"],
    top_k: int = 3,
) -> list[NextTokenStats | None]:
    """
    Compare the next-token probabilities of two models and collect the top-k
    token predictions according to model B.

    Args:
    - model_a: the first model (assumed to be the base model)
    - model_b: the second model (assumed to be the improved model)
    - sample_tok: the tokenized prompt
    - top_k: the number of top token predictions to retrieve (default is 3)

    Returns:
    A list of NextTokenStats objects, one for each token in the prompt,
    aligned to the token being predicted; the first element is None because
    nothing predicts the first token.
    """
    assert (
        model_a.device == model_b.device
    ), "Both models must be on the same device for comparison."

    device = model_a.device
    sample_tok = sample_tok.to(device)

    logprobs_a, next_logprobs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
    logprobs_b, next_logprobs_b = get_all_and_next_logprobs_single(model_b, sample_tok)
    # exponentiate so the *_prob fields below hold probabilities, not log probabilities
    probs_a, next_probs_a = logprobs_a.exp(), next_logprobs_a.exp()
    probs_b, next_probs_b = logprobs_b.exp(), next_logprobs_b.exp()

    top_k_b = torch.topk(probs_b, top_k, dim=-1)
    # model A's probabilities for the tokens model B ranks highest
    top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)

    top_k_b_tokens = top_k_b.indices
    top_k_b_probs = top_k_b.values

    comparisons = []
    # ignore first token when evaluating predictions
    comparisons.append(None)

    for i, (next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b) in enumerate(
        zip(next_probs_a, next_probs_b, top_k_b_tokens, top_k_a_probs, top_k_b_probs)
    ):
        nts = NextTokenStats(
            base_model=identify_model(model_a),
            lift_model=identify_model(model_b),
            next_prediction=TokenPrediction(
                # the token actually observed next in the sample;
                # next_p_a / next_p_b are each model's probability for it
                token=int(sample_tok[i + 1].item()),
                base_model_prob=next_p_a.item(),
                lift_model_prob=next_p_b.item(),
            ),
            topk=[
                TokenPrediction(
                    token=int(top_toks_b[j].item()),
                    base_model_prob=top_probs_a[j].item(),
                    lift_model_prob=top_probs_b[j].item(),
                )
                for j in range(top_k)
            ],
        )
        comparisons.append(nts)

    return comparisons
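
For reference, a minimal sketch of how compare_models is meant to be called. It mirrors test_compare_models at the bottom of this diff (model and dataset names are taken from there); it is not part of the file above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import compare_models
from delphi.eval.utils import load_validation_dataset, tokenize

base = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
lift = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-Instruct-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

story = load_validation_dataset("tinystories-v2-clean")["story"][0]
sample_tok = tokenize(tokenizer, story)

with torch.no_grad():
    stats = compare_models(base, lift, sample_tok, top_k=3)

# stats[0] is None (nothing predicts the first token); inspect the rest
print(stats[1].next_prediction)
print(stats[1].topk)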
53 changes: 49 additions & 4 deletions src/delphi/eval/utils.py
@@ -4,6 +4,7 @@
import torch
from datasets import Dataset, load_dataset
from jaxtyping import Float, Int
from transformers import PreTrainedModel, PreTrainedTokenizerBase


def get_all_logprobs(
@@ -14,19 +15,53 @@ def get_all_logprobs(
    return torch.log_softmax(logits, dim=-1)


# convenience wrapper for calling on a single sample
def get_single_logprobs(
    model: Callable, input_ids: Int[torch.Tensor, "seq"]
) -> Float[torch.Tensor, "seq vocab"]:
    return get_all_logprobs(model, input_ids.unsqueeze(0))[0]


def gather_logprobs(
    logprobs: Float[torch.Tensor, "batch seq vocab"],
    tokens: Int[torch.Tensor, "batch seq"],
) -> Float[torch.Tensor, "batch seq"]:
    # out[b, s] = logprobs[b, s, tokens[b, s]]
    return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)


-def get_next_logprobs(
-    model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
-) -> Float[torch.Tensor, "batch shorter_seq"]:
+def get_all_and_next_logprobs(
+    model: Callable,
+    input_ids: Int[torch.Tensor, "batch seq"],
+) -> tuple[
+    Float[torch.Tensor, "batch shorter_seq vocab"],
+    Float[torch.Tensor, "batch shorter_seq"],
+]:
     logprobs = get_all_logprobs(model, input_ids[:, :-1])
     next_tokens = input_ids[:, 1:]
-    return gather_logprobs(logprobs, next_tokens)
+    return logprobs, gather_logprobs(logprobs, next_tokens)


def get_all_and_next_logprobs_single(
    model: Callable,
    input_ids: Int[torch.Tensor, "seq"],
) -> tuple[
    Float[torch.Tensor, "shorter_seq vocab"],
    Float[torch.Tensor, "shorter_seq"],
]:
    all_logprobs, next_logprobs = get_all_and_next_logprobs(
        model, input_ids.unsqueeze(0)
    )
    return all_logprobs[0], next_logprobs[0]


def get_next_and_top_k_probs(
    model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk]:
    all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
    all_probs = torch.exp(all_logprobs)
    next_probs = torch.exp(next_logprobs)
    top_k = torch.topk(all_probs, k, dim=-1)
    return next_probs, top_k
Comment on lines +57 to +64
Contributor:
Are you using this function?

Collaborator (Author):
In the demo notebook



def load_validation_dataset(dataset_name: str) -> Dataset:
@@ -42,3 +77,13 @@ def load_validation_dataset(dataset_name: str) -> Dataset:
        split="train",
    )
    return cast(Dataset, dataset)


def tokenize(
    tokenizer: PreTrainedTokenizerBase, sample_txt: str
) -> Int[torch.Tensor, "seq"]:
    # note: encoding bos_token + text can give different ids than manually
    # prepending bos_token_id to the encoded text
    return cast(
        Int[torch.Tensor, "seq"],
        tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
    )
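
A quick way to check the caveat in the comment above for a given tokenizer; a sketch only, with the model name borrowed from the tests below. Whether the two encodings agree depends on the tokenizer.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

txt = "Once upon a time"
via_bos_string = tokenizer.encode(tokenizer.bos_token + txt)
via_bos_id = [tokenizer.bos_token_id] + tokenizer.encode(txt)
# for GPT-2-style BPE tokenizers these usually agree, but it is not guaranteed
print(via_bos_string == via_bos_id)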
140 changes: 140 additions & 0 deletions src/delphi/eval/vis.py
@@ -0,0 +1,140 @@
import uuid
from typing import cast

import torch
from IPython.core.display import HTML
from IPython.core.display_functions import display
from jaxtyping import Float, Int
from transformers import PreTrainedTokenizerBase


def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]:
    # the first token (endoftext) has no prediction, so it gets no color
    colors = ["white"]
    red_gap = 150  # the higher it is, the less red the tokens will be
    for p in probs.tolist():
        # p = 0 -> white, p = 1 -> rgb(255, 150, 150): higher prob, more red
        green_blue_val = red_gap + int((255 - red_gap) * (1 - p))
        colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})")
    return colors


def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str:
    # use non-breaking spaces and escape newlines so the string renders cleanly
    tok_str = tokenizer.decode(tok).replace(" ", "\u00a0").replace("\n", r"\n")
    prob_str = f"{prob:.2%}"
    return f"{prob_str:>6} |{tok_str}|"


def token_to_html(
    token: int,
    tokenizer: PreTrainedTokenizerBase,
    bg_color: str,
    data: dict,
) -> str:
    data = data or {}  # equivalent to `if not data: data = {}`
    # non-breaking space; without it, leading spaces wouldn't be displayed
    str_token = tokenizer.decode(token).replace(" ", "\u00a0")

    # background or user-select (for \n) goes here
    specific_styles = {}
    # for now just adds a line break or doesn't
    br = ""

    if bg_color:
        specific_styles["background-color"] = bg_color
    if str_token == "\n":
        # replace the newline character with the two characters \ and n
        str_token = r"\n"
        # add a line break in the html
        br += "<br>"
        # this is so we can copy the prompt without the "\n"s
        specific_styles["user-select"] = "none"

    style_str = data_str = ""
    # convert the style dict into the style attribute
    if specific_styles:
        inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
        style_str = f" style='{inside_style_str}'"
    if data:
        data_str = "".join(
            f" data-{k}='{v.replace(' ', '&nbsp;')}'" for k, v in data.items()
        )
    return f"<div class='token'{style_str}{data_str}>{str_token}</div>{br}"


_token_style = {
    "border": "1px solid #888",
    "display": "inline-block",
    # each character of the same width, so we can easily spot a space
    "font-family": "monospace",
    "font-size": "14px",
    "color": "black",
    "background-color": "white",
    "margin": "1px 0px 1px 1px",
    "padding": "0px 1px 1px 1px",
}
_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])


def vis_sample_prediction_probs(
    sample_tok: Int[torch.Tensor, "pos"],
    correct_probs: Float[torch.Tensor, "shorter_seq"],
    top_k_probs: torch.return_types.topk,
    tokenizer: PreTrainedTokenizerBase,
) -> str:
    colors = probs_to_colors(correct_probs)
    token_htmls = []

    # generate a unique ID for this instance (so we can have multiple instances on the same page)
    unique_id = str(uuid.uuid4())

    token_class = f"token_{unique_id}"
    hover_div_id = f"hover_info_{unique_id}"

    for i in range(sample_tok.shape[0]):
        tok = cast(int, sample_tok[i].item())
        data = {}
        if i > 0:
            # probs are aligned to the token being predicted, hence the i - 1 offset
            correct_prob = correct_probs[i - 1].item()
            data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer)
            top_k_probs_tokens = top_k_probs.indices[i - 1]
            top_k_probs_values = top_k_probs.values[i - 1]
            for j in range(top_k_probs_tokens.shape[0]):
                top_tok = cast(int, top_k_probs_tokens[j].item())
                top_prob = top_k_probs_values[j].item()
                data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)

        token_htmls.append(
            token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
                "class='token'", f"class='{token_class}'"
            )
        )

    html_str = f"""
    <style>.{token_class} {{ {_token_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
    {"".join(token_htmls)} <div id='{hover_div_id}'></div>
    <script>
    (function() {{
        var token_divs = document.querySelectorAll('.{token_class}');
        var hover_info = document.getElementById('{hover_div_id}');

        token_divs.forEach(function(token_div) {{
            token_div.addEventListener('mousemove', function(e) {{
                hover_info.innerHTML = "";
                for (var d in this.dataset) {{
                    hover_info.innerHTML += "<b>" + d + "</b> ";
                    hover_info.innerHTML += this.dataset[d] + "<br>";
                }}
            }});

            token_div.addEventListener('mouseout', function(e) {{
                hover_info.innerHTML = "";
            }});
        }});
    }})();
    </script>
    """
    display(HTML(html_str))
    return html_str
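
Per the review thread above, get_next_and_top_k_probs and this function are exercised in notebooks/vis_demo.ipynb, which is not rendered in this diff. A minimal sketch of how the pieces presumably fit together there, with model and dataset names borrowed from the tests below:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import get_next_and_top_k_probs, load_validation_dataset, tokenize
from delphi.eval.vis import vis_sample_prediction_probs

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

story = load_validation_dataset("tinystories-v2-clean")["story"][0]
sample_tok = tokenize(tokenizer, story)

with torch.no_grad():
    next_probs, top_k = get_next_and_top_k_probs(model, sample_tok, k=3)

# renders colored tokens with per-token hover info (inside a notebook)
vis_sample_prediction_probs(sample_tok, next_probs, top_k, tokenizer)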
23 changes: 23 additions & 0 deletions tests/eval/test_compare_models.py
@@ -0,0 +1,23 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import NextTokenStats, compare_models
from delphi.eval.utils import load_validation_dataset, tokenize


def test_compare_models():
    with torch.set_grad_enabled(False):
        model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
        model_instruct = AutoModelForCausalLM.from_pretrained(
            "roneneldan/TinyStories-Instruct-1M"
        )
        ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
        tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
        sample_tok = tokenize(tokenizer, ds_txt[0])
        K = 3
        model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
        # ignore the first element of the comparison
        assert model_comparison[0] is None
        assert isinstance(model_comparison[1], NextTokenStats)
        assert len(model_comparison) == sample_tok.shape[0]
        assert len(model_comparison[1].topk) == K