Add utility functions for text processing and visualization #17

Merged
merged 18 commits on Feb 13, 2024
6 changes: 3 additions & 3 deletions .github/workflows/checks.yml
@@ -6,7 +6,7 @@ on:
       - main
   pull_request:
     branches:
-      - '*'
+      - "*"

 permissions:
   actions: write
@@ -29,6 +29,6 @@ jobs:
       - name: black
         run: black --check .
       - name: isort
-        run: isort --check .
+        run: isort --profile black --check .
       - name: pytest
-        run: pytest
+        run: pytest
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -8,4 +8,5 @@ repos:
     rev: 5.13.2
     hooks:
       - id: isort
-        name: isort (python)
+        name: isort (python)
+        args: ["--profile", "black"]
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -7,4 +7,7 @@
     "source.organizeImports": "explicit"
   },
   "python.analysis.typeCheckingMode": "basic",
+  "isort.args": [
+    "--profile black"
+  ],
 }
148 changes: 148 additions & 0 deletions notebooks/vis_demo.ipynb

Large diffs are not rendered by default.

106 changes: 106 additions & 0 deletions src/delphi/eval/compare_models.py
@@ -0,0 +1,106 @@
from dataclasses import dataclass

import torch
from jaxtyping import Int
from transformers import PreTrainedModel

from delphi.eval.utils import get_correct_and_all_probs


@dataclass
class ModelId:
    model_name: str


def identify_model(model: PreTrainedModel) -> ModelId:
    return ModelId(model_name=model.config.name_or_path)


@dataclass
class TokenPrediction:
    token: int
    base_model_prob: float
    lift_model_prob: float


@dataclass
class NextTokenStats:
    base_model: ModelId
    lift_model: ModelId
    next_prediction: TokenPrediction
    topk: list[TokenPrediction]


def _pad_start(tensor: torch.Tensor) -> torch.Tensor:
    value_to_prepend = -1
    if len(tensor.shape) == 1:
        return torch.cat((torch.tensor([value_to_prepend]), tensor))
    else:
        # input: 2D tensor of shape [seq_len - 1, top_k]
        pre = torch.full((1, tensor.size()[-1]), value_to_prepend)
        return torch.cat((pre, tensor), dim=0)


def compare_models(
    model_a: PreTrainedModel,
    model_b: PreTrainedModel,
    sample_tok: Int[torch.Tensor, "seq"],
    top_k: int = 3,
) -> list[NextTokenStats]:
    """
    Compare the probabilities of the next token for two models and get the
    top k token predictions according to model B.

    Args:
    - model_a: The first model (assumed to be the base model)
    - model_b: The second model (assumed to be the improved model)
    - sample_tok: The tokenized prompt
    - top_k: The number of top token predictions to retrieve (default is 3)

    Returns:
        A list of NextTokenStats objects, one for each token in the prompt.
        Tensors are aligned to the token they are predicting (by prepending
        a -1 to the start of each tensor).
    """
    assert (
        model_a.device == model_b.device
    ), "Both models must be on the same device for comparison."

    device = model_a.device
    sample_tok = sample_tok.to(device)

    next_probs_a, probs_a = get_correct_and_all_probs(model_a, sample_tok)
    next_probs_b, probs_b = get_correct_and_all_probs(model_b, sample_tok)

    top_k_b = torch.topk(probs_b, top_k, dim=-1)
    top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)

    next_probs_a = _pad_start(next_probs_a)
    next_probs_b = _pad_start(next_probs_b)
    top_k_b_tokens = _pad_start(top_k_b.indices)
    top_k_a_probs = _pad_start(top_k_a_probs)
    top_k_b_probs = _pad_start(top_k_b.values)

    comparisons = []

    for tok, next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip(
        sample_tok,
        next_probs_a,
        next_probs_b,
        top_k_b_tokens,
        top_k_a_probs,
        top_k_b_probs,
    ):
        nts = NextTokenStats(
            base_model=identify_model(model_a),
            lift_model=identify_model(model_b),
            next_prediction=TokenPrediction(
                # the token at this position, with each model's probability
                # of having predicted it from the previous position
                token=int(tok.item()),
                base_model_prob=next_p_a.item(),
                lift_model_prob=next_p_b.item(),
            ),
            topk=[
                TokenPrediction(
                    token=int(top_toks_b[i].item()),
                    base_model_prob=top_probs_a[i].item(),
                    lift_model_prob=top_probs_b[i].item(),
                )
                for i in range(top_k)
            ],
        )
        comparisons.append(nts)

    return comparisons
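
A minimal usage sketch for compare_models, not part of the diff; the TinyStories checkpoints and the prompt are illustrative placeholders, and tokenize is the helper added to delphi.eval.utils in this PR:

from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import compare_models
from delphi.eval.utils import tokenize

# hypothetical checkpoints; any two HF causal LMs sharing a tokenizer work
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
base = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
lift = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-8M")

sample_tok = tokenize(tokenizer, "Once upon a time, there was a little girl.")
stats = compare_models(base, lift, sample_tok, top_k=3)
# stats[i].next_prediction holds both models' probabilities for token i;
# stats[i].topk lists model B's top-k tokens with both models' probabilities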
53 changes: 52 additions & 1 deletion src/delphi/eval/utils.py
@@ -1,9 +1,14 @@
 from collections.abc import Callable
-from typing import cast
+from typing import List, cast

 import torch
 from datasets import Dataset, load_dataset
 from jaxtyping import Float, Int
 from transformers import PreTrainedModel, PreTrainedTokenizerBase

+ALLOWED_CHARS = set(
+    " \n\"'(),.:?!0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+)
+

 def get_all_logprobs(
@@ -14,6 +19,13 @@ def get_all_logprobs(
     return torch.log_softmax(logits, dim=-1)


+# convenience wrapper for calling on a single sample
+def get_single_logprobs(
+    model: Callable, input_ids: Int[torch.Tensor, "seq"]
+) -> Float[torch.Tensor, "seq vocab"]:
+    return get_all_logprobs(model, input_ids.unsqueeze(0))[0]
+
+
 def gather_logprobs(
     logprobs: Float[torch.Tensor, "batch seq vocab"],
     tokens: Int[torch.Tensor, "batch seq"],
@@ -42,3 +54,42 @@ def load_validation_dataset(dataset_name: str) -> Dataset:
         split="train",
     )
     return cast(Dataset, dataset)
+
+
+def load_text_from_dataset(dataset: Dataset) -> list[str]:
+    text = []
+    for sample_txt in dataset["story"]:
+        # skip samples with encoding issues and rare weird prompts
+        if not set(sample_txt).issubset(ALLOWED_CHARS):
+            continue
+        text.append(sample_txt)
+    return text
+
+
+def tokenize(
+    tokenizer: PreTrainedTokenizerBase, sample_txt: str
+) -> Int[torch.Tensor, "seq"]:
+    # supposedly this can be different from just prepending the bos token id
+    return cast(
+        Int[torch.Tensor, "seq"],
+        tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
+    )
+
+
+def get_correct_and_all_probs(
+    model: PreTrainedModel, sample_tok: Int[torch.Tensor, "seq"]
+) -> tuple[Float[torch.Tensor, "next_seq"], Float[torch.Tensor, "next_seq vocab"]]:
+    """Get probabilities for the actual next token and for all predictions"""
+    # get_single_logprobs returns log-probs; exponentiate so downstream
+    # consumers (color shading, percent strings) receive probabilities in [0, 1]
+    # remove the first token (the bos token)
+    probs = get_single_logprobs(model, sample_tok)[1:].exp()
+    correct_probs = probs[range(len(probs)), sample_tok[1:]]
+    return correct_probs, probs
+
+
+def get_correct_and_top_probs(
+    model: PreTrainedModel, sample_tok: Int[torch.Tensor, "seq"], top_k: int = 3
+) -> tuple[Float[torch.Tensor, "next_seq"], torch.return_types.topk]:
+    """Get probabilities for the actual next token and for top k predictions"""
+    correct_probs, probs = get_correct_and_all_probs(model, sample_tok)
+    top_k_probs = torch.topk(probs, top_k, dim=-1)
+    return correct_probs, top_k_probs
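
An end-to-end sketch of how the new utils compose, not part of the diff; the dataset and model names are placeholders, assuming a dataset with a "story" column as load_text_from_dataset expects:

from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import (
    get_correct_and_top_probs,
    load_text_from_dataset,
    load_validation_dataset,
    tokenize,
)

dataset = load_validation_dataset("tinystories-v2-clean")  # placeholder name
texts = load_text_from_dataset(dataset)  # drops samples with disallowed chars

tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")  # placeholder
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")

sample_tok = tokenize(tokenizer, texts[0])
correct_probs, top_k = get_correct_and_top_probs(model, sample_tok, top_k=3)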
141 changes: 141 additions & 0 deletions src/delphi/eval/vis.py
@@ -0,0 +1,141 @@
import uuid
from typing import cast

import torch
from IPython.core.display import HTML
from IPython.core.display_functions import display
from jaxtyping import Float, Int
from transformers import PreTrainedTokenizerBase


def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]:
    # for the endoftext token
    # no prediction, no color
    colors = ["white"]
    for p in probs.tolist():
        red_gap = 150  # the higher it is, the less red the tokens will be
        green_blue_val = red_gap + int((255 - red_gap) * (1 - p))
        colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})")
    return colors


def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str:
    # "\xa0" is a non-breaking space, so leading spaces stay visible
    tok_str = tokenizer.decode(tok).replace(" ", "\xa0").replace("\n", r"\n")
    prob_str = f"{prob:.2%}"
    return f"{prob_str:>6} |{tok_str}|"


def token_to_html(
    token: int,
    tokenizer: PreTrainedTokenizerBase,
    bg_color: str,
    data: dict,
) -> str:
    data = data or {}  # equivalent to if not data: data = {}
    # non-breaking space ("\xa0"); w/o it leading spaces wouldn't be displayed
    str_token = tokenizer.decode(token).replace(" ", "\xa0")

    # background or user-select (for \n) goes here
    specific_styles = {}
    # for now just adds line break or doesn't
    br = ""

    if bg_color:
        specific_styles["background-color"] = bg_color
    if str_token == "\n":
        # replace new line character with two characters: \ and n
        str_token = r"\n"
        # add line break in html
        br += "<br>"
        # this is so we can copy the prompt without "\n"s
        specific_styles["user-select"] = "none"

    style_str = data_str = ""
    # converting style dict into the style attribute
    if specific_styles:
        inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
        style_str = f" style='{inside_style_str}'"
    if data:
        data_str = "".join(
            f" data-{k}='{v.replace(' ', '&nbsp;')}'" for k, v in data.items()
        )
    return f"<div class='token'{style_str}{data_str}>{str_token}</div>{br}"


_token_style = {
    "border": "1px solid #888",
    "display": "inline-block",
    # each character of the same width, so we can easily spot a space
    "font-family": "monospace",
    "font-size": "14px",
    "color": "black",
    "background-color": "white",
    "margin": "1px 0px 1px 1px",
    "padding": "0px 1px 1px 1px",
}
_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])


# TODO: basic unit test for visualizations w/ Selenium
def vis_sample_prediction_probs(
    sample_tok: Int[torch.Tensor, "pos"],
    correct_probs: Float[torch.Tensor, "pos"],
    top_k_probs: torch.return_types.topk,
    tokenizer: PreTrainedTokenizerBase,
) -> str:
    colors = probs_to_colors(correct_probs)
    token_htmls = []

    # generate a unique id for this instance
    # (so we can have multiple instances on the same page)
    unique_id = str(uuid.uuid4())

    token_class = f"token_{unique_id}"
    hover_div_id = f"hover_info_{unique_id}"

    for i in range(sample_tok.shape[0]):
        tok = cast(int, sample_tok[i].item())
        data = {}
        if i > 0:
            correct_prob = correct_probs[i - 1].item()
            data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer)
            top_k_probs_tokens = top_k_probs.indices[i - 1]
            top_k_probs_values = top_k_probs.values[i - 1]
            for j in range(top_k_probs_tokens.shape[0]):
                top_tok = cast(int, top_k_probs_tokens[j].item())
                top_prob = top_k_probs_values[j].item()
                data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)

        token_htmls.append(
            token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
                "class='token'", f"class='{token_class}'"
            )
        )

    html_str = f"""
    <style>.{token_class} {{ {_token_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
    {"".join(token_htmls)} <div id='{hover_div_id}'></div>
    <script>
        (function() {{
            var token_divs = document.querySelectorAll('.{token_class}');
            var hover_info = document.getElementById('{hover_div_id}');

            token_divs.forEach(function(token_div) {{
                token_div.addEventListener('mousemove', function(e) {{
                    hover_info.innerHTML = "";
                    for (var d in this.dataset) {{
                        hover_info.innerHTML += "<b>" + d + "</b> ";
                        hover_info.innerHTML += this.dataset[d] + "<br>";
                    }}
                }});

                token_div.addEventListener('mouseout', function(e) {{
                    hover_info.innerHTML = "";
                }});
            }});
        }})();
    </script>
    """
    display(HTML(html_str))
    return html_str
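
A sketch of rendering the visualization in a notebook cell, not part of the diff and in the spirit of notebooks/vis_demo.ipynb; the model name and prompt are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import get_correct_and_top_probs, tokenize
from delphi.eval.vis import vis_sample_prediction_probs

tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")  # placeholder
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")

sample_tok = tokenize(tokenizer, "Once upon a time, there was a dog.")
correct_probs, top_k = get_correct_and_top_probs(model, sample_tok, top_k=3)

# shades each token by the model's probability of predicting it and attaches
# hover data with the top-k alternatives
vis_sample_prediction_probs(sample_tok, correct_probs, top_k, tokenizer)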