diff --git a/notebooks/eval_notebook.ipynb b/notebooks/eval_notebook.ipynb
index 864a1f1e..daed1f52 100644
--- a/notebooks/eval_notebook.ipynb
+++ b/notebooks/eval_notebook.ipynb
diff --git a/src/delphi/eval.py b/src/delphi/eval.py
new file mode 100644
--- /dev/null
+++ b/src/delphi/eval.py
+def token_selector(
+ vocab_map: dict[str, int]
+) -> tuple[pn.widgets.MultiChoice, list[int]]:
+ tokens = list(vocab_map.keys())
+ token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens)
+ token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)]
+
+ def update_tokens(event):
+ token_ids.clear()
+ token_ids.extend([vocab_map[token] for token in event.new])
+
+ token_selector_.param.watch(update_tokens, "value")
+ return token_selector_, token_ids
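+
+# Illustrative usage (a sketch, assuming a HuggingFace tokenizer whose
+# get_vocab() maps token strings to token ids):
+#   selector, selected_ids = token_selector(tokenizer.get_vocab())
+#   pn.Row(selector)  # selected_ids is mutated in place as tokens are (de)selected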
+
+
+def calc_model_group_stats(
+ tokenized_corpus_dataset: Dataset,
+ logprobs_by_dataset: dict[str, torch.Tensor],
+ selected_tokens: list[int],
+) -> dict[str, dict[str, float]]:
+ """
+ For each (model, token group) pair, calculate useful stats (for visualization)
+
+ args:
+ - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
+ - logprobs_by_dataset: a dict mapping model names to per-token logprob tensors, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
+ - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...]
+
+ returns: a dict with model names as keys and stats dicts as values,
+ e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
+
+ Stats calculated: mean, median, min, max, 25th percentile, 75th percentile
+ """
+ model_group_stats = {}
+ for model in logprobs_by_dataset:
+ model_logprobs = []
+ print(f"Processing model {model}")
+ dataset = logprobs_by_dataset[model]
+ for ix_doc_lp, document_lps in enumerate(dataset):
+ tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
+ for ix_token, token in enumerate(tokens):
+ if ix_token == 0: # skip the first token, which isn't predicted
+ continue
+ logprob = document_lps[ix_token].item()
+ if token in selected_tokens:
+ model_logprobs.append(logprob)
+
+ if model_logprobs:
+ model_group_stats[model] = {
+ "mean": np.mean(model_logprobs),
+ "median": np.median(model_logprobs),
+ "min": np.min(model_logprobs),
+ "max": np.max(model_logprobs),
+ "25th": np.percentile(model_logprobs, 25),
+ "75th": np.percentile(model_logprobs, 75),
+ }
+ return model_group_stats
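+
+# Illustrative usage (a sketch with a hypothetical model name and pre-loaded logprob tensors):
+#   stats = calc_model_group_stats(
+#       tokenized_corpus_dataset=tokenized_corpus["validation"],
+#       logprobs_by_dataset={"llama2-100k": llama2_100k_logprobs},
+#       selected_tokens=[46, 402],
+#   )
+#   stats["llama2-100k"]["median"]  # median logprob the model assigns to the selected tokens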
+
+
+def dict_filter_quantile(
+ d: dict[Any, float], q_start: float, q_end: float
+) -> dict[Any, float]:
+ if not (0 <= q_start < q_end <= 1):
+ raise ValueError("Invalid quantile range")
+ q_start_val = np.nanquantile(list(d.values()), q_start)
+ q_end_val = np.nanquantile(list(d.values()), q_end)
+ return {
+ k: v for k, v in d.items() if q_start_val <= v <= q_end_val and not np.isnan(v)
+ }
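+
+# For example, dict_filter_quantile({1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5}, 0.2, 0.6)
+# keeps only the entries whose values fall between the 0.2 and 0.6 quantiles of all
+# values, i.e. {2: 0.2, 3: 0.3}; NaN values are always dropped.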
+
+
+def get_all_tok_metrics_in_label(
+ token_ids: Int[torch.Tensor, "prompt pos"],
+ selected_tokens: list[int],
+ metrics: torch.Tensor,
+ q_start: Optional[float] = None,
+ q_end: Optional[float] = None,
+) -> dict[tuple[int, int], float]:
+ """
+ Get the position of every token in token_ids that belongs to selected_tokens, along with its metric.
+ We iterate through token_ids directly (rather than using a token_map) because that is more efficient for sampling.
+ Optionally, filter the positions to a quantile range of the metrics.
+
+ Args:
+ - token_ids (Int[torch.Tensor, "prompt pos"]): tensor of token IDs, one row per prompt
+ - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...]
+ - metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...])
+ - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1
+ - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9
+
+ Returns:
+ - tok_positions (dict[tuple[int, int], float]): dictionary mapping (prompt index, token position) to the corresponding metric
+ """
+
+ # check if metrics have the same dimensions as token_ids
+ if metrics.shape != token_ids.shape:
+ raise ValueError(
+ f"Expected metrics to have the same shape as token_ids, but got {metrics.shape} and {token_ids.shape} instead."
+ )
+
+ tok_positions = {}
+ for prompt_pos, prompt in enumerate(token_ids.numpy()):
+ for tok_pos, tok in enumerate(prompt):
+ if tok in selected_tokens:
+ tok_positions[(prompt_pos, tok_pos)] = metrics[
+ prompt_pos, tok_pos
+ ].item()
+
+ if q_start is not None and q_end is not None:
+ tok_positions = dict_filter_quantile(tok_positions, q_start, q_end)
+
+ return tok_positions
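+
+# Illustrative usage (mirrors tests/test_eval.py): token_ids and metrics are
+# (prompt, pos) tensors of the same shape; the result maps (prompt index, token
+# position) to the metric for every occurrence of a selected token, optionally
+# restricted to the q_start..q_end quantile band of those metrics:
+#   positions = get_all_tok_metrics_in_label(token_ids, [2, 4, 6, 8], metrics, q_start=0.6, q_end=1.0)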
+
+
+def visualize_selected_tokens(
+ input: dict[str | int, tuple[float, float, float]],
+ log_scale=False,
+ line_metric="Means",
+ checkpoint_mode=True,
+ shade_color="rgba(68, 68, 68, 0.3)",
+ line_color="rgb(31, 119, 180)",
+ bar_color="purple",
+ marker_color="SkyBlue",
+ background_color="AliceBlue",
+) -> go.FigureWidget:
+ input_x = list(input.keys())
+
+ def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
+ return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]
+
+ def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ x = np.array([input[x] for x in input_x]).T
+ means, err_lo, err_hi = x[0], x[1], x[2]
+ return means, err_lo, err_hi
+
+ means, err_lo, err_hi = get_plot_values()
+
+ if checkpoint_mode:
+ scatter_plot = go.Figure(
+ [
+ go.Scatter(
+ name="Upper Bound",
+ x=input_x,
+ y=means + err_hi,
+ mode="lines",
+ marker=dict(color=shade_color),
+ line=dict(width=0),
+ showlegend=False,
+ ),
+ go.Scatter(
+ name="Lower Bound",
+ x=input_x,
+ y=means - err_lo,
+ marker=dict(color=shade_color),
+ line=dict(width=0),
+ mode="lines",
+ fillcolor=shade_color,
+ fill="tonexty",
+ showlegend=False,
+ ),
+ go.Scatter(
+ name=line_metric,
+ x=input_x,
+ y=means,
+ mode="lines",
+ marker=dict(
+ color=line_color,
+ size=0,
+ line=dict(color=line_color, width=1),
+ ),
+ ),
+ ]
+ )
+ else:
+ scatter_plot = go.Scatter(
+ x=input_x,
+ y=means,
+ error_y=dict(
+ type="data",
+ symmetric=False,
+ array=err_hi,
+ arrayminus=err_lo,
+ color=bar_color,
+ ),
+ marker=dict(
+ color=marker_color,
+ size=15,
+ line=dict(color=line_color, width=2),
+ ),
+ hovertext=get_hovertexts(means, err_lo, err_hi),
+ hoverinfo="text+x",
+ )
+ g = go.FigureWidget(
+ data=scatter_plot,
+ layout=go.Layout(
+ yaxis=dict(
+ title="Loss",
+ type="log" if log_scale else "linear",
+ ),
+ plot_bgcolor=background_color,
+ ),
+ )
+
+ return g
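+
+
+# Illustrative usage (hypothetical values): keys are checkpoints (or model names),
+# values are (mean, err_below, err_above) tuples; with checkpoint_mode=True the
+# means are drawn as a line with a shaded error band:
+#   fig = visualize_selected_tokens({100: (-0.5, 0.1, 0.1), 200: (-0.4, 0.1, 0.1)})
+#   fig  # go.FigureWidget, renders inline in a notebook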
diff --git a/src/delphi/eval/__init__.py b/src/delphi/eval/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py
deleted file mode 100644
index faab8a02..00000000
--- a/src/delphi/eval/calc_model_group_stats.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import numpy as np
-import torch
-from datasets import Dataset
-from jaxtyping import Float
-
-
-def calc_model_group_stats(
- tokenized_corpus_dataset: Dataset,
- logprobs_by_dataset: dict[str, torch.Tensor],
- selected_tokens: list[int],
-) -> dict[str, dict[str, float]]:
- """
- For each (model, token group) pair, calculate useful stats (for visualization)
-
- args:
- - tokenized_corpus_dataset: a list of the tokenized corpus datasets, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"]
- - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
- - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...]
-
- returns: a dict of model names as keys and stats dict as values
- e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
-
- Stats calculated: mean, median, min, max, 25th percentile, 75th percentile
- """
- model_group_stats = {}
- for model in logprobs_by_dataset:
- model_logprobs = []
- print(f"Processing model {model}")
- dataset = logprobs_by_dataset[model]
- for ix_doc_lp, document_lps in enumerate(dataset):
- tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
- for ix_token, token in enumerate(tokens):
- if ix_token == 0: # skip the first token, which isn't predicted
- continue
- logprob = document_lps[ix_token].item()
- if token in selected_tokens:
- model_logprobs.append(logprob)
-
- if model_logprobs:
- model_group_stats[model] = {
- "mean": np.mean(model_logprobs),
- "median": np.median(model_logprobs),
- "min": np.min(model_logprobs),
- "max": np.max(model_logprobs),
- "25th": np.percentile(model_logprobs, 25),
- "75th": np.percentile(model_logprobs, 75),
- }
- return model_group_stats
diff --git a/src/delphi/eval/token_positions.py b/src/delphi/eval/token_positions.py
deleted file mode 100644
index a98af761..00000000
--- a/src/delphi/eval/token_positions.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import Optional
-
-import torch
-from jaxtyping import Int
-
-from delphi.eval.utils import dict_filter_quantile
-
-
-def get_all_tok_metrics_in_label(
- token_ids: Int[torch.Tensor, "prompt pos"],
- selected_tokens: list[int],
- metrics: torch.Tensor,
- q_start: Optional[float] = None,
- q_end: Optional[float] = None,
-) -> dict[tuple[int, int], float]:
- """
- From the token_map, get all the positions of the tokens that have a certain label.
- We don't use the token_map because for sampling purposes, iterating through token_ids is more efficient.
- Optionally, filter the tokens based on the quantile range of the metrics.
-
- Args:
- - token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]}
- - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...]
- - metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...])
- - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1
- - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9
-
- Returns:
- - tok_positions (dict[tuple[int, int], Number]): dictionary of token positions and their corresponding metrics
- """
-
- # check if metrics have the same dimensions as token_ids
- if metrics.shape != token_ids.shape:
- raise ValueError(
- f"Expected metrics to have the same shape as token_ids, but got {metrics.shape} and {token_ids.shape} instead."
- )
-
- tok_positions = {}
- for prompt_pos, prompt in enumerate(token_ids.numpy()):
- for tok_pos, tok in enumerate(prompt):
- if tok in selected_tokens:
- tok_positions[(prompt_pos, tok_pos)] = metrics[
- prompt_pos, tok_pos
- ].item()
-
- if q_start is not None and q_end is not None:
- tok_positions = dict_filter_quantile(tok_positions, q_start, q_end)
-
- return tok_positions
diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py
deleted file mode 100644
index 0026e7a7..00000000
--- a/src/delphi/eval/utils.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from collections.abc import Callable
-from typing import Any
-
-import numpy as np
-import torch
-from jaxtyping import Float, Int
-from transformers import PreTrainedModel
-
-
-def get_all_logprobs(
- model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
-) -> Float[torch.Tensor, "batch seq vocab"]:
- # batch, seq, vocab
- logits = model(input_ids).logits
- return torch.log_softmax(logits, dim=-1)
-
-
-# convenience wrapper for calling on a single sample
-def get_single_logprobs(
- model: Callable, input_ids: Int[torch.Tensor, "seq"]
-) -> Float[torch.Tensor, "seq vocab"]:
- return get_all_logprobs(model, input_ids.unsqueeze(0))[0]
-
-
-def gather_logprobs(
- logprobs: Float[torch.Tensor, "batch seq vocab"],
- tokens: Int[torch.Tensor, "batch seq"],
-) -> Float[torch.Tensor, "batch seq"]:
- return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
-
-
-def get_all_and_next_logprobs(
- model: Callable,
- input_ids: Int[torch.Tensor, "batch seq"],
-) -> tuple[
- Float[torch.Tensor, "batch shorter_seq vocab"],
- Float[torch.Tensor, "batch shorter_seq"],
-]:
- logprobs = get_all_logprobs(model, input_ids[:, :-1])
- next_tokens = input_ids[:, 1:]
- return logprobs, gather_logprobs(logprobs, next_tokens)
-
-
-def get_all_and_next_logprobs_single(
- model: Callable,
- input_ids: Int[torch.Tensor, "seq"],
-) -> tuple[
- Float[torch.Tensor, "shorter_seq vocab"],
- Float[torch.Tensor, "shorter_seq"],
-]:
- all_logprobs, next_logprobs = get_all_and_next_logprobs(
- model, input_ids.unsqueeze(0)
- )
- return all_logprobs[0], next_logprobs[0]
-
-
-def get_next_and_top_k_probs(
- model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
-) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk,]:
- all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
- all_probs = torch.exp(all_logprobs)
- next_probs = torch.exp(next_logprobs)
- top_k = torch.topk(all_probs, k, dim=-1)
- return next_probs, top_k
-
-
-def dict_filter_quantile(
- d: dict[Any, float], q_start: float, q_end: float
-) -> dict[Any, float]:
- if not (0 <= q_start < q_end <= 1):
- raise ValueError("Invalid quantile range")
- q_start_val = np.nanquantile(list(d.values()), q_start)
- q_end_val = np.nanquantile(list(d.values()), q_end)
- return {
- k: v for k, v in d.items() if q_start_val <= v <= q_end_val and not np.isnan(v)
- }
diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py
deleted file mode 100644
index cea76e88..00000000
--- a/src/delphi/eval/vis.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import math
-import random
-import uuid
-from typing import cast
-
-import numpy as np
-import panel as pn
-import torch
-from IPython.core.display import HTML
-from IPython.core.display_functions import display
-from jaxtyping import Float, Int
-from transformers import PreTrainedTokenizerBase
-
-
-def single_loss_diff_to_color(loss_diff: float) -> str:
- # if loss_diff is negative, we want the color to be red
- # if loss_diff is positive, we want the color to be green
- # if loss_diff is 0, we want the color to be white
- # the color should be more intense the larger the absolute value of loss_diff
-
- def sigmoid(x: float) -> float:
- return 1 / (1 + math.exp(-x))
-
- scaled_loss_diff = sigmoid(loss_diff) # scale to 0-1
-
- if scaled_loss_diff < 0.5: # red
- red_val = 255
- green_blue_val = min(int(255 * 2 * scaled_loss_diff), 255)
- return f"rgb({red_val}, {green_blue_val}, {green_blue_val})"
- else: # green
- green_val = 255
- red_blue_val = min(int(255 * 2 * (1 - scaled_loss_diff)), 255)
- return f"rgb({red_blue_val}, {green_val}, {red_blue_val})"
-
-
-def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str:
- tok_str = tokenizer.decode(tok).replace(" ", "&nbsp;").replace("\n", r"\n")
- prob_str = f"{prob:.2%}"
- return f"{prob_str:>6} |{tok_str}|"
-
-
-def token_to_html(
- token: int,
- tokenizer: PreTrainedTokenizerBase,
- bg_color: str,
- data: dict,
- class_name: str = "token",
-) -> str:
- data = data or {} # equivalent to if not data: data = {}
- # non-breakable space, w/o it leading spaces wouldn't be displayed
- str_token = tokenizer.decode(token).replace(" ", "&nbsp;")
-
- # background or user-select (for \n) goes here
- specific_styles = {}
- # for now just adds line break or doesn't
- br = ""
-
- if bg_color:
- specific_styles["background-color"] = bg_color
- if str_token == "\n":
- # replace new line character with two characters: \ and n
- str_token = r"\n"
- # add line break in html
- br += "
"
- # this is so we can copy the prompt without "\n"s
- specific_styles["user-select"] = "none"
- str_token = str_token.replace("<", "&lt;").replace(">", "&gt;")
-
- style_str = data_str = ""
- # converting style dict into the style attribute
- if specific_styles:
- inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
- style_str = f" style='{inside_style_str}'"
- if data:
- data_str = "".join(
- f" data-{k}='{v.replace(' ', ' ')}'" for k, v in data.items()
- )
- return f"
{str_token}
{br}"
-
-
-_token_style = {
- "border": "1px solid #888",
- "display": "inline-block",
- # each character of the same width, so we can easily spot a space
- "font-family": "monospace",
- "font-size": "14px",
- "color": "black",
- "background-color": "white",
- "margin": "1px 0px 1px 1px",
- "padding": "0px 1px 1px 1px",
-}
-_token_emphasized_style = {
- "border": "3px solid #888",
- "display": "inline-block",
- "font-family": "monospace",
- "font-size": "14px",
- "color": "black",
- "background-color": "white",
- "margin": "1px 0px 1px 1px",
- "padding": "0px 1px 1px 1px",
-}
-_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])
-_token_emphasized_style_str = " ".join(
- [f"{k}: {v};" for k, v in _token_emphasized_style.items()]
-)
-
-
-def vis_pos_map(
- pos_list: list[tuple[int, int]],
- selected_tokens: list[int],
- metrics: Float[torch.Tensor, "prompt pos"],
- token_ids: Int[torch.Tensor, "prompt pos"],
- tokenizer: PreTrainedTokenizerBase,
-):
- """
- Randomly sample from pos_map and visualize the loss diff at the corresponding position.
- """
-
- token_htmls = []
- unique_id = str(uuid.uuid4())
- token_class = f"pretoken_{unique_id}"
- selected_token_class = f"token_{unique_id}"
- hover_div_id = f"hover_info_{unique_id}"
-
- # choose a random keys from pos_map
- key = random.choice(pos_list)
-
- prompt, pos = key
- all_toks = token_ids[prompt][: pos + 1]
-
- for i in range(all_toks.shape[0]):
- token_id = cast(int, all_toks[i].item())
- value = metrics[prompt][i].item()
- token_htmls.append(
- token_to_html(
- token_id,
- tokenizer,
- bg_color="white"
- if np.isnan(value)
- else single_loss_diff_to_color(value),
- data={"loss-diff": f"{value:.2f}"},
- class_name=token_class
- if token_id not in selected_tokens
- else selected_token_class,
- )
- )
-
- # add break line
- token_htmls.append("
")
-
- html_str = f"""
-
- {"".join(token_htmls)}
-
- """
- display(HTML(html_str))
-
-
-def token_selector(
- vocab_map: dict[str, int]
-) -> tuple[pn.widgets.MultiChoice, list[int]]:
- tokens = list(vocab_map.keys())
- token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens)
- token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)]
-
- def update_tokens(event):
- token_ids.clear()
- token_ids.extend([vocab_map[token] for token in event.new])
-
- token_selector_.param.watch(update_tokens, "value")
- return token_selector_, token_ids
diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py
deleted file mode 100644
index e5d735f4..00000000
--- a/src/delphi/eval/vis_per_token_model.py
+++ /dev/null
@@ -1,96 +0,0 @@
-from typing import Union
-
-import numpy as np
-import plotly.graph_objects as go
-
-
-def visualize_selected_tokens(
- input: dict[Union[str, int], tuple[float, float, float]],
- log_scale=False,
- line_metric="Means",
- checkpoint_mode=True,
- shade_color="rgba(68, 68, 68, 0.3)",
- line_color="rgb(31, 119, 180)",
- bar_color="purple",
- marker_color="SkyBlue",
- background_color="AliceBlue",
-) -> go.FigureWidget:
- input_x = list(input.keys())
-
- def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
- return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]
-
- def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
- x = np.array([input[x] for x in input_x]).T
- means, err_lo, err_hi = x[0], x[1], x[2]
- return means, err_lo, err_hi
-
- means, err_lo, err_hi = get_plot_values()
-
- if checkpoint_mode:
- scatter_plot = go.Figure(
- [
- go.Scatter(
- name="Upper Bound",
- x=input_x,
- y=means + err_hi,
- mode="lines",
- marker=dict(color=shade_color),
- line=dict(width=0),
- showlegend=False,
- ),
- go.Scatter(
- name="Lower Bound",
- x=input_x,
- y=means - err_lo,
- marker=dict(color=shade_color),
- line=dict(width=0),
- mode="lines",
- fillcolor=shade_color,
- fill="tonexty",
- showlegend=False,
- ),
- go.Scatter(
- name=line_metric,
- x=input_x,
- y=means,
- mode="lines",
- marker=dict(
- color=line_color,
- size=0,
- line=dict(color=line_color, width=1),
- ),
- ),
- ]
- )
- else:
- scatter_plot = go.Scatter(
- x=input_x,
- y=means,
- error_y=dict(
- type="data",
- symmetric=False,
- array=err_hi,
- arrayminus=err_lo,
- color=bar_color,
- ),
- marker=dict(
- color=marker_color,
- size=15,
- line=dict(color=line_color, width=2),
- ),
- hovertext=get_hovertexts(means, err_lo, err_hi),
- hoverinfo="text+x",
- )
- g = go.FigureWidget(
- data=scatter_plot,
- layout=go.Layout(
- yaxis=dict(
- title="Loss",
- type="log" if log_scale else "linear",
- ),
- plot_bgcolor=background_color,
- ),
- )
-
- return g
diff --git a/src/delphi/dataset/tokenization.py b/src/delphi/tokenization.py
similarity index 100%
rename from src/delphi/dataset/tokenization.py
rename to src/delphi/tokenization.py
diff --git a/src/delphi/utils.py b/src/delphi/utils.py
index 30325cb2..7ffbfc88 100644
--- a/src/delphi/utils.py
+++ b/src/delphi/utils.py
@@ -1,6 +1,9 @@
+from collections.abc import Callable
from typing import cast
+import torch
from datasets import Dataset, Features, Sequence, Value, load_dataset
+from jaxtyping import Float, Int
def hf_split_to_split_name(split: str) -> str:
@@ -55,3 +58,30 @@ def get_all_hf_branch_names(repo_id: str) -> list[str]:
api = HfApi()
refs = api.list_repo_refs(repo_id)
return [branch.name for branch in refs.branches]
+
+
+def gather_logprobs(
+ logprobs: Float[torch.Tensor, "batch seq vocab"],
+ tokens: Int[torch.Tensor, "batch seq"],
+) -> Float[torch.Tensor, "batch seq"]:
+ return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
+
+
+def get_all_logprobs(
+ model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
+) -> Float[torch.Tensor, "batch seq vocab"]:
+ # batch, seq, vocab
+ logits = model(input_ids).logits
+ return torch.log_softmax(logits, dim=-1)
+
+
+def get_all_and_next_logprobs(
+ model: Callable,
+ input_ids: Int[torch.Tensor, "batch seq"],
+) -> tuple[
+ Float[torch.Tensor, "batch shorter_seq vocab"],
+ Float[torch.Tensor, "batch shorter_seq"],
+]:
+ logprobs = get_all_logprobs(model, input_ids[:, :-1])
+ next_tokens = input_ids[:, 1:]
+ return logprobs, gather_logprobs(logprobs, next_tokens)
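+
+
+# Illustrative usage (a sketch, assuming a causal LM whose forward pass returns .logits):
+#   all_logprobs, next_logprobs = get_all_and_next_logprobs(model, input_ids)
+#   # all_logprobs: (batch, seq-1, vocab); next_logprobs: (batch, seq-1), the logprob
+#   # assigned to each actual next token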
diff --git a/tests/eval/test_token_positions.py b/tests/eval/test_token_positions.py
deleted file mode 100644
index c584b931..00000000
--- a/tests/eval/test_token_positions.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from math import isclose
-from typing import cast
-
-import pytest
-from datasets import Dataset
-
-from delphi.eval.token_positions import *
-
-
-@pytest.fixture
-def mock_data():
- token_ids = Dataset.from_dict(
- {"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]}
- ).with_format("torch")
- selected_tokens = [2, 4, 6, 8]
- metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]])
- return token_ids, selected_tokens, metrics
-
-
-def test_get_all_tok_metrics_in_label(mock_data):
- token_ids, selected_tokens, metrics = mock_data
- result = get_all_tok_metrics_in_label(
- token_ids["tokens"],
- selected_tokens,
- metrics,
- )
- # key: (prompt_pos, tok_pos), value: logprob
- expected = {
- (0, 1): 0.45,
- (1, 0): -1.31,
- (1, 2): 0.6,
- (2, 1): 0.8,
- }
-
- # compare keys
- assert result.keys() == expected.keys()
- # compare values
- for k in result:
- assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore
-
- # test with quantile filtering
- result_q = get_all_tok_metrics_in_label(
- token_ids["tokens"], selected_tokens, metrics, q_start=0.6, q_end=1.0
- )
- expected_q = {
- (1, 2): 0.6,
- (2, 1): 0.8,
- }
-
- # compare keys
- assert result_q.keys() == expected_q.keys()
- # compare values
- for k in result_q:
- assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore
diff --git a/tests/eval/test_utils_eval.py b/tests/eval/test_utils_eval.py
deleted file mode 100644
index 54e0034a..00000000
--- a/tests/eval/test_utils_eval.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from math import isclose
-
-import pytest
-import torch
-
-from delphi.eval.utils import dict_filter_quantile, gather_logprobs
-
-
-def test_gather_logprobs():
- # vocab size = 3
- logprobs = torch.tensor(
- [
- # batch 0
- [
- # seq 0
- [0.00, 0.01, 0.02],
- # seq 1
- [0.10, 0.11, 0.12],
- ],
- # batch 1
- [
- # seq 0
- [1.00, 1.01, 1.02],
- # seq 1
- [1.10, 1.11, 1.12],
- ],
- ]
- )
- tokens = torch.tensor(
- [
- # batch 0
- [0, 2],
- # batch 1
- [1, 2],
- ]
- )
- expected_output = torch.tensor(
- [
- # batch 0
- [0.00, 0.12],
- # batch 1
- [1.01, 1.12],
- ]
- )
- result = gather_logprobs(logprobs, tokens)
- assert torch.allclose(result, expected_output)
-
-
-@pytest.mark.filterwarnings(
- "ignore::RuntimeWarning"
-) # ignore warnings from numpy empty slice
-def test_dict_filter_quantile():
- d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5}
- result = dict_filter_quantile(d, 0.2, 0.6)
- expected = {2: 0.2, 3: 0.3}
-
- # compare keys
- assert result.keys() == expected.keys()
- # compare values
- for k in result:
- assert isclose(result[k], expected[k], rel_tol=1e-6)
-
- # test with negative values
- d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5}
- result = dict_filter_quantile(d, 0.2, 0.6)
- expected = {3: -0.3, 4: -0.4}
-
- # compare keys
- assert result.keys() == expected.keys()
- # compare values
- for k in result:
- assert isclose(result[k], expected[k], rel_tol=1e-6)
-
- # test invalid quantile range
- with pytest.raises(ValueError):
- dict_filter_quantile(d, 0.6, 0.2)
- with pytest.raises(ValueError):
- dict_filter_quantile(d, 0.1, 1.1)
- with pytest.raises(ValueError):
- dict_filter_quantile(d, -0.1, 0.6)
-
- # test empty dict, will raise a warning
- result = dict_filter_quantile({}, 0.2, 0.6)
- assert result == {}
diff --git a/tests/test_eval.py b/tests/test_eval.py
new file mode 100644
index 00000000..cdf88413
--- /dev/null
+++ b/tests/test_eval.py
@@ -0,0 +1,91 @@
+from math import isclose
+from typing import cast
+
+import pytest
+import torch
+from datasets import Dataset
+
+from delphi.eval import dict_filter_quantile, get_all_tok_metrics_in_label
+
+
+@pytest.mark.filterwarnings(
+ "ignore::RuntimeWarning"
+) # ignore warnings from numpy empty slice
+def test_dict_filter_quantile():
+ d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5}
+ result = dict_filter_quantile(d, 0.2, 0.6)
+ expected = {2: 0.2, 3: 0.3}
+
+ # compare keys
+ assert result.keys() == expected.keys()
+ # compare values
+ for k in result:
+ assert isclose(result[k], expected[k], rel_tol=1e-6)
+
+ # test with negative values
+ d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5}
+ result = dict_filter_quantile(d, 0.2, 0.6)
+ expected = {3: -0.3, 4: -0.4}
+
+ # compare keys
+ assert result.keys() == expected.keys()
+ # compare values
+ for k in result:
+ assert isclose(result[k], expected[k], rel_tol=1e-6)
+
+ # test invalid quantile range
+ with pytest.raises(ValueError):
+ dict_filter_quantile(d, 0.6, 0.2)
+ with pytest.raises(ValueError):
+ dict_filter_quantile(d, 0.1, 1.1)
+ with pytest.raises(ValueError):
+ dict_filter_quantile(d, -0.1, 0.6)
+
+ # test empty dict, will raise a warning
+ result = dict_filter_quantile({}, 0.2, 0.6)
+ assert result == {}
+
+
+def test_get_all_tok_metrics_in_label():
+ token_ids = Dataset.from_dict(
+ {"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]}
+ ).with_format("torch")
+ selected_tokens = [2, 4, 6, 8]
+ metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]])
+ result = get_all_tok_metrics_in_label(
+ token_ids["tokens"], # type: ignore
+ selected_tokens,
+ metrics,
+ )
+ # key: (prompt_pos, tok_pos), value: logprob
+ expected = {
+ (0, 1): 0.45,
+ (1, 0): -1.31,
+ (1, 2): 0.6,
+ (2, 1): 0.8,
+ }
+
+ # compare keys
+ assert result.keys() == expected.keys()
+ # compare values
+ for k in result:
+ assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore
+
+ # test with quantile filtering
+ result_q = get_all_tok_metrics_in_label(
+ token_ids["tokens"], # type: ignore
+ selected_tokens,
+ metrics,
+ q_start=0.6,
+ q_end=1.0,
+ )
+ expected_q = {
+ (1, 2): 0.6,
+ (2, 1): 0.8,
+ }
+
+ # compare keys
+ assert result_q.keys() == expected_q.keys()
+ # compare values
+ for k in result_q:
+ assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore
diff --git a/tests/dataset/test_tokeniation.py b/tests/test_tokeniation.py
similarity index 97%
rename from tests/dataset/test_tokeniation.py
rename to tests/test_tokeniation.py
index bb4180ba..cc9494b2 100644
--- a/tests/dataset/test_tokeniation.py
+++ b/tests/test_tokeniation.py
@@ -5,7 +5,7 @@
from datasets import Dataset
from transformers import AutoTokenizer
-from delphi.dataset.tokenization import extend_deque, make_new_sample, tokenize_dataset
+from delphi.tokenization import extend_deque, make_new_sample, tokenize_dataset
@pytest.fixture
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 597438ca..79b639ad 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,13 @@
-from delphi.utils import hf_split_to_split_name
+import random
+import string
-from .utils import random_string
+import torch
+
+from delphi.utils import gather_logprobs, hf_split_to_split_name
+
+
+def random_string(length: int) -> str:
+ return "".join(random.choices(string.ascii_lowercase, k=length))
def test_hf_split_to_split_name():
@@ -12,3 +19,43 @@ def test_hf_split_to_split_name():
assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
+
+
+def test_gather_logprobs():
+ # vocab size = 3
+ logprobs = torch.tensor(
+ [
+ # batch 0
+ [
+ # seq 0
+ [0.00, 0.01, 0.02],
+ # seq 1
+ [0.10, 0.11, 0.12],
+ ],
+ # batch 1
+ [
+ # seq 0
+ [1.00, 1.01, 1.02],
+ # seq 1
+ [1.10, 1.11, 1.12],
+ ],
+ ]
+ )
+ tokens = torch.tensor(
+ [
+ # batch 0
+ [0, 2],
+ # batch 1
+ [1, 2],
+ ]
+ )
+ expected_output = torch.tensor(
+ [
+ # batch 0
+ [0.00, 0.12],
+ # batch 1
+ [1.01, 1.12],
+ ]
+ )
+ result = gather_logprobs(logprobs, tokens)
+ assert torch.allclose(result, expected_output)
diff --git a/tests/train/test_train_step.py b/tests/train/test_train_step.py
index 9a7f3456..e06fa1af 100644
--- a/tests/train/test_train_step.py
+++ b/tests/train/test_train_step.py
@@ -8,7 +8,6 @@
from transformers import PreTrainedModel
from delphi import TEST_CONFIGS_DIR
-from delphi.eval.utils import get_all_and_next_logprobs
from delphi.train.config import TrainingConfig
from delphi.train.config.utils import build_config_from_files_and_overrides
from delphi.train.train_step import accumulate_gradients, train_step
@@ -18,6 +17,7 @@
init_model,
setup_determinism,
)
+from delphi.utils import get_all_and_next_logprobs
def load_test_config(preset_name: str) -> TrainingConfig:
diff --git a/tests/utils.py b/tests/utils.py
deleted file mode 100644
index ed81b58a..00000000
--- a/tests/utils.py
+++ /dev/null
@@ -1,6 +0,0 @@
-import random
-import string
-
-
-def random_string(length: int) -> str:
- return "".join(random.choices(string.ascii_lowercase, k=length))