diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index e20ccd24..89a889fa 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '*' + - "*" permissions: actions: write @@ -38,4 +38,4 @@ jobs: - name: isort run: isort --profile black --check . - name: pytest - run: pytest \ No newline at end of file + run: pytest diff --git a/.vscode/settings.json b/.vscode/settings.json index e94f3e9b..b5f4dfd1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,11 +7,8 @@ "source.organizeImports": "explicit" }, "python.analysis.typeCheckingMode": "basic", - "isort.args": [ "--profile black" ], - "black-formatter.importStrategy": "fromEnvironment", - } \ No newline at end of file diff --git a/notebooks/vis_demo.ipynb b/notebooks/vis_demo.ipynb new file mode 100644 index 00000000..842804d0 --- /dev/null +++ b/notebooks/vis_demo.ipynb @@ -0,0 +1,148 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch; torch.set_grad_enabled(False)\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "\n", + "from delphi.eval.utils import tokenize, get_next_and_top_k_probs, load_validation_dataset\n", + "from delphi.eval.vis import vis_sample_prediction_probs\n", + "\n", + "model_name = \"roneneldan/TinyStories-1M\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "model = AutoModelForCausalLM.from_pretrained(model_name)\n", + "ds = load_validation_dataset(\"tinystories-v2-clean\")\n", + "ds_txt = ds[\"story\"][:100]\n", + "ds_tok = [tokenize(tokenizer, txt) for txt in ds_txt]\n", + "sample_tok = ds_tok[0]\n", + "\n", + "correct_probs, top_3_probs = get_next_and_top_k_probs(model, sample_tok, k=3)\n", + "_, top_5_probs = get_next_and_top_k_probs(model, sample_tok, k=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### collect top k predictions" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + "
[~140 colored token <span> elements elided: rendered per-token prediction-probability output (top-3 hover data) for the first validation story, "Once upon a time, there was a kind girl named Lily. …"]
\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_ = vis_sample_prediction_probs(sample_tok, correct_probs, top_3_probs, tokenizer)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + "
[the same story rendered a second time with top-5 hover data: token <span> elements elided]
\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_ = vis_sample_prediction_probs(sample_tok, correct_probs, top_5_probs, tokenizer)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/delphi/eval/compare_models.py b/src/delphi/eval/compare_models.py new file mode 100644 index 00000000..e03b300c --- /dev/null +++ b/src/delphi/eval/compare_models.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass + +import torch +from jaxtyping import Int +from transformers import PreTrainedModel + +from delphi.eval.utils import get_all_and_next_logprobs_single + + +def identify_model(model: PreTrainedModel) -> str: + return model.config.name_or_path + + +@dataclass +class TokenPrediction: + token: int + base_model_prob: float + lift_model_prob: float + + +@dataclass +class NextTokenStats: + base_model: str + lift_model: str + next_prediction: TokenPrediction + topk: list[TokenPrediction] + + +def compare_models( + model_a: PreTrainedModel, + model_b: PreTrainedModel, + sample_tok: Int[torch.Tensor, "seq"], + top_k: int = 3, +) -> list[NextTokenStats | None]: + """ + Compare the probabilities of the next token for two models and get the top k token predictions according to model B. + Args: + - model_a: The first model (assumed to be the base model) + - model_b: The second model (assumed to be the improved model) + - sample_tok: The tokenized prompt + - top_k: The number of top token predictions to retrieve (default is 5) + Returns: + A list of NextTokenStats objects, one for each token in the prompt. + Tensors are aligned to the token they are predicting (by prepending a -1 to the start of the tensor) + """ + assert ( + model_a.device == model_b.device + ), "Both models must be on the same device for comparison." 
+ + device = model_a.device + sample_tok = sample_tok.to(device) + + logprobs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok) + logprobs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok) + + probs_a = torch.exp(logprobs_a) + probs_b = torch.exp(logprobs_b) + + top_k_b = torch.topk(probs_b, top_k, dim=-1) + top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices) + + top_k_b_tokens = top_k_b.indices + top_k_b_probs = top_k_b.values + + comparisons = [] + # ignore first token when evaluating predictions + comparisons.append(None) + + for next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip( + next_probs_a, next_probs_b, top_k_b_tokens, top_k_a_probs, top_k_b_probs + ): + nts = NextTokenStats( + base_model=identify_model(model_a), + lift_model=identify_model(model_b), + next_prediction=TokenPrediction( + token=int(next_p_a.item()), + base_model_prob=next_p_a.item(), + lift_model_prob=next_p_b.item(), + ), + topk=[ + TokenPrediction( + token=int(top_toks_b[i].item()), + base_model_prob=top_probs_a[i].item(), + lift_model_prob=top_probs_b[i].item(), + ) + for i in range(top_k) + ], + ) + comparisons.append(nts) + + return comparisons diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index ee9893d9..16b26e4c 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -4,6 +4,7 @@ import torch from datasets import Dataset, load_dataset from jaxtyping import Float, Int +from transformers import PreTrainedModel, PreTrainedTokenizerBase def get_all_logprobs( @@ -14,6 +15,13 @@ def get_all_logprobs( return torch.log_softmax(logits, dim=-1) +# convenience wrapper for calling on a single sample +def get_single_logprobs( + model: Callable, input_ids: Int[torch.Tensor, "seq"] +) -> Float[torch.Tensor, "seq vocab"]: + return get_all_logprobs(model, input_ids.unsqueeze(0))[0] + + def gather_logprobs( logprobs: Float[torch.Tensor, "batch seq vocab"], tokens: Int[torch.Tensor, "batch seq"], @@ -21,12 +29,39 @@ def gather_logprobs( return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1) -def get_next_logprobs( - model: Callable, input_ids: Int[torch.Tensor, "batch seq"] -) -> Float[torch.Tensor, "batch shorter_seq"]: +def get_all_and_next_logprobs( + model: Callable, + input_ids: Int[torch.Tensor, "batch seq"], +) -> tuple[ + Float[torch.Tensor, "batch shorter_seq vocab"], + Float[torch.Tensor, "batch shorter_seq"], +]: logprobs = get_all_logprobs(model, input_ids[:, :-1]) next_tokens = input_ids[:, 1:] - return gather_logprobs(logprobs, next_tokens) + return logprobs, gather_logprobs(logprobs, next_tokens) + + +def get_all_and_next_logprobs_single( + model: Callable, + input_ids: Int[torch.Tensor, "seq"], +) -> tuple[ + Float[torch.Tensor, "shorter_seq vocab"], + Float[torch.Tensor, "shorter_seq"], +]: + all_logprobs, next_logprobs = get_all_and_next_logprobs( + model, input_ids.unsqueeze(0) + ) + return all_logprobs[0], next_logprobs[0] + + +def get_next_and_top_k_probs( + model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3 +) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk,]: + all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids) + all_probs = torch.exp(all_logprobs) + next_probs = torch.exp(next_logprobs) + top_k = torch.topk(all_probs, k, dim=-1) + return next_probs, top_k def load_validation_dataset(dataset_name: str) -> Dataset: @@ -42,3 +77,13 @@ def load_validation_dataset(dataset_name: str) -> Dataset: split="train", ) return 
cast(Dataset, dataset) + + +def tokenize( + tokenizer: PreTrainedTokenizerBase, sample_txt: str +) -> Int[torch.Tensor, "seq"]: + # supposedly this can be different than prepending the bos token id + return cast( + Int[torch.Tensor, "seq"], + tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0], + ) diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py new file mode 100644 index 00000000..5dd4fdb2 --- /dev/null +++ b/src/delphi/eval/vis.py @@ -0,0 +1,140 @@ +import uuid +from typing import cast + +import torch +from IPython.core.display import HTML +from IPython.core.display_functions import display +from jaxtyping import Float, Int +from transformers import PreTrainedTokenizerBase + + +def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]: + # for the endoftext token + # no prediction, no color + colors = ["white"] + for p in probs.tolist(): + red_gap = 150 # the higher it is, the less red the tokens will be + green_blue_val = red_gap + int((255 - red_gap) * (1 - p)) + colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})") + return colors + + +def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str: + tok_str = tokenizer.decode(tok).replace(" ", " ").replace("\n", r"\n") + prob_str = f"{prob:.2%}" + return f"{prob_str:>6} |{tok_str}|" + + +def token_to_html( + token: int, + tokenizer: PreTrainedTokenizerBase, + bg_color: str, + data: dict, +) -> str: + data = data or {} # equivalent to if not data: data = {} + # non-breakable space, w/o it leading spaces wouldn't be displayed + str_token = tokenizer.decode(token).replace(" ", " ") + + # background or user-select (for \n) goes here + specific_styles = {} + # for now just adds line break or doesn't + br = "" + + if bg_color: + specific_styles["background-color"] = bg_color + if str_token == "\n": + # replace new line character with two characters: \ and n + str_token = r"\n" + # add line break in html + br += "
" + # this is so we can copy the prompt without "\n"s + specific_styles["user-select"] = "none" + + style_str = data_str = "" + # converting style dict into the style attribute + if specific_styles: + inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items()) + style_str = f" style='{inside_style_str}'" + if data: + data_str = "".join( + f" data-{k}='{v.replace(' ', ' ')}'" for k, v in data.items() + ) + return f"
{str_token}
{br}" + + +_token_style = { + "border": "1px solid #888", + "display": "inline-block", + # each character of the same width, so we can easily spot a space + "font-family": "monospace", + "font-size": "14px", + "color": "black", + "background-color": "white", + "margin": "1px 0px 1px 1px", + "padding": "0px 1px 1px 1px", +} +_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()]) + + +def vis_sample_prediction_probs( + sample_tok: Int[torch.Tensor, "pos"], + correct_probs: Float[torch.Tensor, "pos"], + top_k_probs: torch.return_types.topk, + tokenizer: PreTrainedTokenizerBase, +) -> str: + colors = probs_to_colors(correct_probs) + token_htmls = [] + + # Generate a unique ID for this instance (so we can have multiple instances on the same page) + unique_id = str(uuid.uuid4()) + + token_class = f"token_{unique_id}" + hover_div_id = f"hover_info_{unique_id}" + + for i in range(sample_tok.shape[0]): + tok = cast(int, sample_tok[i].item()) + data = {} + if i > 0: + correct_prob = correct_probs[i - 1].item() + data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer) + top_k_probs_tokens = top_k_probs.indices[i - 1] + top_k_probs_values = top_k_probs.values[i - 1] + for j in range(top_k_probs_tokens.shape[0]): + top_tok = top_k_probs_tokens[j].item() + top_tok = cast(int, top_tok) + top_prob = top_k_probs_values[j].item() + data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer) + + token_htmls.append( + token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace( + "class='token'", f"class='{token_class}'" + ) + ) + + html_str = f""" + + {"".join(token_htmls)}
+ + """ + display(HTML(html_str)) + return html_str diff --git a/tests/eval/test_compare_models.py b/tests/eval/test_compare_models.py new file mode 100644 index 00000000..0521b0cb --- /dev/null +++ b/tests/eval/test_compare_models.py @@ -0,0 +1,23 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from delphi.eval.compare_models import NextTokenStats, compare_models +from delphi.eval.utils import load_validation_dataset, tokenize + + +def test_compare_models(): + with torch.set_grad_enabled(False): + model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M") + model_instruct = AutoModelForCausalLM.from_pretrained( + "roneneldan/TinyStories-Instruct-1M" + ) + ds_txt = load_validation_dataset("tinystories-v2-clean")["story"] + tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M") + sample_tok = tokenize(tokenizer, ds_txt[0]) + K = 3 + model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K) + # ignore the first element comparison + assert model_comparison[0] is None + assert isinstance(model_comparison[1], NextTokenStats) + assert len(model_comparison) == sample_tok.shape[0] + assert len(model_comparison[1].topk) == K
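
A minimal usage sketch of the new `compare_models` helper added in this diff. It reuses the TinyStories checkpoints and the `tinystories-v2-clean` dataset from `tests/eval/test_compare_models.py`; the printing loop at the end is purely illustrative and not part of the changeset.

```python
# Sketch: compare the base and Instruct TinyStories models on one validation story,
# using only the helpers introduced in this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import compare_models
from delphi.eval.utils import load_validation_dataset, tokenize

torch.set_grad_enabled(False)

base = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
lift = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-Instruct-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

story = load_validation_dataset("tinystories-v2-clean")["story"][0]
sample_tok = tokenize(tokenizer, story)  # prepends the BOS token to the sample

stats = compare_models(base, lift, sample_tok, top_k=3)
# stats[0] is None: the first token has no preceding context to predict it from.
for s in stats[1:4]:
    # correct-token probability under the base model vs. the "lift" model
    print(f"{s.next_prediction.base_model_prob:.3f} -> {s.next_prediction.lift_model_prob:.3f}")
```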