diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index e20ccd24..89a889fa 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -6,7 +6,7 @@ on:
- main
pull_request:
branches:
- - '*'
+ - "*"
permissions:
actions: write
@@ -38,4 +38,4 @@ jobs:
- name: isort
run: isort --profile black --check .
- name: pytest
- run: pytest
\ No newline at end of file
+ run: pytest
diff --git a/.vscode/settings.json b/.vscode/settings.json
index e94f3e9b..b5f4dfd1 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -7,11 +7,8 @@
"source.organizeImports": "explicit"
},
"python.analysis.typeCheckingMode": "basic",
-
"isort.args": [
"--profile black"
],
-
"black-formatter.importStrategy": "fromEnvironment",
-
}
\ No newline at end of file
diff --git a/notebooks/vis_demo.ipynb b/notebooks/vis_demo.ipynb
new file mode 100644
index 00000000..842804d0
--- /dev/null
+++ b/notebooks/vis_demo.ipynb
@@ -0,0 +1,148 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch; torch.set_grad_enabled(False)\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "\n",
+ "from delphi.eval.utils import tokenize, get_next_and_top_k_probs, load_validation_dataset\n",
+ "from delphi.eval.vis import vis_sample_prediction_probs\n",
+ "\n",
+ "model_name = \"roneneldan/TinyStories-1M\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+ "ds = load_validation_dataset(\"tinystories-v2-clean\")\n",
+ "ds_txt = ds[\"story\"][:100]\n",
+ "ds_tok = [tokenize(tokenizer, txt) for txt in ds_txt]\n",
+ "sample_tok = ds_tok[0]\n",
+ "\n",
+ "correct_probs, top_3_probs = get_next_and_top_k_probs(model, sample_tok, k=3)\n",
+ "_, top_5_probs = get_next_and_top_k_probs(model, sample_tok, k=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### collect top k predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
<|endoftext|>
Once
upon
a
time
,
there
was
a
kind
girl
named
Lily
.
Lily
loved
to
mix
things
.
One
day
,
she
found
a
big
box
full
of
colors
.
Lily
was
very
happy
.
\\n
L
ily
took
out
a
strip
of
red
and
a
strip
of
blue
.
She
mixed
them
together
and
made
a
new
color
,
purple
!
Lily
was
so
excited
.
She
wanted
to
mix
more
colors
.
\\n
Next
,
Lily
took
a
strip
of
yellow
and
a
strip
of
green
.
She
mixed
them
together
and
made
a
new
color
,
orange
!
Lily
was
very
proud
of
herself
.
She
showed
her
new
colors
to
her
mom
and
dad
,
and
they
were
proud
of
her
too
.
They
all
lived
happily
ever
after
.
\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "_ = vis_sample_prediction_probs(sample_tok, correct_probs, top_3_probs, tokenizer)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " <|endoftext|>
Once
upon
a
time
,
there
was
a
kind
girl
named
Lily
.
Lily
loved
to
mix
things
.
One
day
,
she
found
a
big
box
full
of
colors
.
Lily
was
very
happy
.
\\n
L
ily
took
out
a
strip
of
red
and
a
strip
of
blue
.
She
mixed
them
together
and
made
a
new
color
,
purple
!
Lily
was
so
excited
.
She
wanted
to
mix
more
colors
.
\\n
Next
,
Lily
took
a
strip
of
yellow
and
a
strip
of
green
.
She
mixed
them
together
and
made
a
new
color
,
orange
!
Lily
was
very
proud
of
herself
.
She
showed
her
new
colors
to
her
mom
and
dad
,
and
they
were
proud
of
her
too
.
They
all
lived
happily
ever
after
.
\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "_ = vis_sample_prediction_probs(sample_tok, correct_probs, top_5_probs, tokenizer)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/delphi/eval/compare_models.py b/src/delphi/eval/compare_models.py
new file mode 100644
index 00000000..e03b300c
--- /dev/null
+++ b/src/delphi/eval/compare_models.py
@@ -0,0 +1,91 @@
+from dataclasses import dataclass
+
+import torch
+from jaxtyping import Int
+from transformers import PreTrainedModel
+
+from delphi.eval.utils import get_all_and_next_logprobs_single
+
+
+def identify_model(model: PreTrainedModel) -> str:
+ return model.config.name_or_path
+
+
+@dataclass
+class TokenPrediction:
+ token: int
+ base_model_prob: float
+ lift_model_prob: float
+
+
+@dataclass
+class NextTokenStats:
+ base_model: str
+ lift_model: str
+ next_prediction: TokenPrediction
+ topk: list[TokenPrediction]
+
+
+def compare_models(
+ model_a: PreTrainedModel,
+ model_b: PreTrainedModel,
+ sample_tok: Int[torch.Tensor, "seq"],
+ top_k: int = 3,
+) -> list[NextTokenStats | None]:
+ """
+ Compare the probabilities of the next token for two models and get the top k token predictions according to model B.
+ Args:
+ - model_a: The first model (assumed to be the base model)
+ - model_b: The second model (assumed to be the improved model)
+ - sample_tok: The tokenized prompt
+ - top_k: The number of top token predictions to retrieve (default is 3)
+ Returns:
+ A list of NextTokenStats objects, one for each token in the prompt.
+ The list is aligned to the prompt tokens; its first entry is None, since there is no prediction for the first token.
+ """
+ assert (
+ model_a.device == model_b.device
+ ), "Both models must be on the same device for comparison."
+
+ device = model_a.device
+ sample_tok = sample_tok.to(device)
+
+ logprobs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
+ logprobs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok)
+
+ probs_a = torch.exp(logprobs_a)
+ probs_b = torch.exp(logprobs_b)
+
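+ # model B's top-k token ids per position, and model A's probabilities for those same ids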
+ top_k_b = torch.topk(probs_b, top_k, dim=-1)
+ top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)
+
+ top_k_b_tokens = top_k_b.indices
+ top_k_b_probs = top_k_b.values
+
+ comparisons = []
+ # ignore first token when evaluating predictions
+ comparisons.append(None)
+
+ for next_tok, next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip(
+ sample_tok[1:],
+ next_probs_a,
+ next_probs_b,
+ top_k_b_tokens,
+ top_k_a_probs,
+ top_k_b_probs,
+ ):
+ nts = NextTokenStats(
+ base_model=identify_model(model_a),
+ lift_model=identify_model(model_b),
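+ # probability each model assigns to the token that actually came next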
+ next_prediction=TokenPrediction(
+ token=int(next_tok.item()),
+ base_model_prob=next_p_a.item(),
+ lift_model_prob=next_p_b.item(),
+ ),
+ topk=[
+ TokenPrediction(
+ token=int(top_toks_b[i].item()),
+ base_model_prob=top_probs_a[i].item(),
+ lift_model_prob=top_probs_b[i].item(),
+ )
+ for i in range(top_k)
+ ],
+ )
+ comparisons.append(nts)
+
+ return comparisons
diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py
index ee9893d9..16b26e4c 100644
--- a/src/delphi/eval/utils.py
+++ b/src/delphi/eval/utils.py
@@ -4,6 +4,7 @@
import torch
from datasets import Dataset, load_dataset
from jaxtyping import Float, Int
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
def get_all_logprobs(
@@ -14,6 +15,13 @@ def get_all_logprobs(
return torch.log_softmax(logits, dim=-1)
+# convenience wrapper for calling on a single sample
+def get_single_logprobs(
+ model: Callable, input_ids: Int[torch.Tensor, "seq"]
+) -> Float[torch.Tensor, "seq vocab"]:
+ return get_all_logprobs(model, input_ids.unsqueeze(0))[0]
+
+
def gather_logprobs(
logprobs: Float[torch.Tensor, "batch seq vocab"],
tokens: Int[torch.Tensor, "batch seq"],
@@ -21,12 +29,39 @@ def gather_logprobs(
return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
-def get_next_logprobs(
- model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
-) -> Float[torch.Tensor, "batch shorter_seq"]:
+def get_all_and_next_logprobs(
+ model: Callable,
+ input_ids: Int[torch.Tensor, "batch seq"],
+) -> tuple[
+ Float[torch.Tensor, "batch shorter_seq vocab"],
+ Float[torch.Tensor, "batch shorter_seq"],
+]:
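+ # logprobs at position i give the distribution over token i+1, hence "shorter_seq" (seq - 1)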
logprobs = get_all_logprobs(model, input_ids[:, :-1])
next_tokens = input_ids[:, 1:]
- return gather_logprobs(logprobs, next_tokens)
+ return logprobs, gather_logprobs(logprobs, next_tokens)
+
+
+def get_all_and_next_logprobs_single(
+ model: Callable,
+ input_ids: Int[torch.Tensor, "seq"],
+) -> tuple[
+ Float[torch.Tensor, "shorter_seq vocab"],
+ Float[torch.Tensor, "shorter_seq"],
+]:
+ all_logprobs, next_logprobs = get_all_and_next_logprobs(
+ model, input_ids.unsqueeze(0)
+ )
+ return all_logprobs[0], next_logprobs[0]
+
+
+def get_next_and_top_k_probs(
+ model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
+) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk]:
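+ """Return the prob of each actual next token and the model's top-k predictions per position."""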
+ all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
+ all_probs = torch.exp(all_logprobs)
+ next_probs = torch.exp(next_logprobs)
+ top_k = torch.topk(all_probs, k, dim=-1)
+ return next_probs, top_k
def load_validation_dataset(dataset_name: str) -> Dataset:
@@ -42,3 +77,13 @@ def load_validation_dataset(dataset_name: str) -> Dataset:
split="train",
)
return cast(Dataset, dataset)
+
+
+def tokenize(
+ tokenizer: PreTrainedTokenizerBase, sample_txt: str
+) -> Int[torch.Tensor, "seq"]:
+ # supposedly this can be different from just prepending the bos token id
+ return cast(
+ Int[torch.Tensor, "seq"],
+ tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
+ )
diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py
new file mode 100644
index 00000000..5dd4fdb2
--- /dev/null
+++ b/src/delphi/eval/vis.py
@@ -0,0 +1,140 @@
+import uuid
+from typing import cast
+
+import torch
+from IPython.core.display import HTML
+from IPython.core.display_functions import display
+from jaxtyping import Float, Int
+from transformers import PreTrainedTokenizerBase
+
+
+def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]:
+ # for the endoftext token
+ # no prediction, no color
+ colors = ["white"]
+ for p in probs.tolist():
+ red_gap = 150 # the higher it is, the less red the tokens will be
+ green_blue_val = red_gap + int((255 - red_gap) * (1 - p))
+ colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})")
+ return colors
+
+
+def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str:
+ tok_str = tokenizer.decode(tok).replace(" ", "&nbsp;").replace("\n", r"\n")
+ prob_str = f"{prob:.2%}"
+ return f"{prob_str:>6} |{tok_str}|"
+
+
+def token_to_html(
+ token: int,
+ tokenizer: PreTrainedTokenizerBase,
+ bg_color: str,
+ data: dict,
+) -> str:
+ data = data or {} # equivalent to if not data: data = {}
+ # non-breakable space, w/o it leading spaces wouldn't be displayed
+ str_token = tokenizer.decode(token).replace(" ", "&nbsp;")
+
+ # background or user-select (for \n) goes here
+ specific_styles = {}
+ # for now just adds line break or doesn't
+ br = ""
+
+ if bg_color:
+ specific_styles["background-color"] = bg_color
+ if str_token == "\n":
+ # replace new line character with two characters: \ and n
+ str_token = r"\n"
+ # add line break in html
+ br += "
"
+ # this is so we can copy the prompt without "\n"s
+ specific_styles["user-select"] = "none"
+
+ style_str = data_str = ""
+ # converting style dict into the style attribute
+ if specific_styles:
+ inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
+ style_str = f" style='{inside_style_str}'"
+ if data:
+ data_str = "".join(
+ f" data-{k}='{v.replace(' ', ' ')}'" for k, v in data.items()
+ )
+ return f"{str_token}
{br}"
+
+
+_token_style = {
+ "border": "1px solid #888",
+ "display": "inline-block",
+ # each character of the same width, so we can easily spot a space
+ "font-family": "monospace",
+ "font-size": "14px",
+ "color": "black",
+ "background-color": "white",
+ "margin": "1px 0px 1px 1px",
+ "padding": "0px 1px 1px 1px",
+}
+_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])
+
+
+def vis_sample_prediction_probs(
+ sample_tok: Int[torch.Tensor, "pos"],
+ correct_probs: Float[torch.Tensor, "pos"],
+ top_k_probs: torch.return_types.topk,
+ tokenizer: PreTrainedTokenizerBase,
+) -> str:
+ colors = probs_to_colors(correct_probs)
+ token_htmls = []
+
+ # Generate a unique ID for this instance (so we can have multiple instances on the same page)
+ unique_id = str(uuid.uuid4())
+
+ token_class = f"token_{unique_id}"
+ hover_div_id = f"hover_info_{unique_id}"
+
+ for i in range(sample_tok.shape[0]):
+ tok = cast(int, sample_tok[i].item())
+ data = {}
+ if i > 0:
+ correct_prob = correct_probs[i - 1].item()
+ data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer)
+ top_k_probs_tokens = top_k_probs.indices[i - 1]
+ top_k_probs_values = top_k_probs.values[i - 1]
+ for j in range(top_k_probs_tokens.shape[0]):
+ top_tok = top_k_probs_tokens[j].item()
+ top_tok = cast(int, top_tok)
+ top_prob = top_k_probs_values[j].item()
+ data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)
+
+ token_htmls.append(
+ token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
+ "class='token'", f"class='{token_class}'"
+ )
+ )
+
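+ # assemble the page: a style block for the tokens, the token divs, a hover box, and a script that fills the hover box from each token's data-* attributes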
+ html_str = f"""
+
+ {"".join(token_htmls)}
+
+ """
+ display(HTML(html_str))
+ return html_str
diff --git a/tests/eval/test_compare_models.py b/tests/eval/test_compare_models.py
new file mode 100644
index 00000000..0521b0cb
--- /dev/null
+++ b/tests/eval/test_compare_models.py
@@ -0,0 +1,23 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from delphi.eval.compare_models import NextTokenStats, compare_models
+from delphi.eval.utils import load_validation_dataset, tokenize
+
+
+def test_compare_models():
+ with torch.set_grad_enabled(False):
+ model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
+ model_instruct = AutoModelForCausalLM.from_pretrained(
+ "roneneldan/TinyStories-Instruct-1M"
+ )
+ ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
+ tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
+ sample_tok = tokenize(tokenizer, ds_txt[0])
+ K = 3
+ model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
+ # ignore the first element comparison
+ assert model_comparison[0] is None
+ assert isinstance(model_comparison[1], NextTokenStats)
+ assert len(model_comparison) == sample_tok.shape[0]
+ assert len(model_comparison[1].topk) == K