diff --git a/notebooks/eval_notebook.ipynb b/notebooks/eval_notebook.ipynb
index 864a1f1e..daed1f52 100644
--- a/notebooks/eval_notebook.ipynb
+++ b/notebooks/eval_notebook.ipynb
@@ -68,12 +68,12 @@
    "data": {
     "application/vnd.holoviews_exec.v0+json": "",
     "text/html": [
     [rendered holoviews/panel output cell: container div ids regenerated on re-execution; markup not preserved in this excerpt]
\n", " + """ + display(HTML(html_str)) + + +def token_selector( + vocab_map: dict[str, int] +) -> tuple[pn.widgets.MultiChoice, list[int]]: + tokens = list(vocab_map.keys()) + token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens) + token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)] + + def update_tokens(event): + token_ids.clear() + token_ids.extend([vocab_map[token] for token in event.new]) + + token_selector_.param.watch(update_tokens, "value") + return token_selector_, token_ids + + +def calc_model_group_stats( + tokenized_corpus_dataset: Dataset, + logprobs_by_dataset: dict[str, torch.Tensor], + selected_tokens: list[int], +) -> dict[str, dict[str, float]]: + """ + For each (model, token group) pair, calculate useful stats (for visualization) + + args: + - tokenized_corpus_dataset: a list of the tokenized corpus datasets, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"] + - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]} + - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...] + + returns: a dict of model names as keys and stats dict as values + e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...} + + Stats calculated: mean, median, min, max, 25th percentile, 75th percentile + """ + model_group_stats = {} + for model in logprobs_by_dataset: + model_logprobs = [] + print(f"Processing model {model}") + dataset = logprobs_by_dataset[model] + for ix_doc_lp, document_lps in enumerate(dataset): + tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"] + for ix_token, token in enumerate(tokens): + if ix_token == 0: # skip the first token, which isn't predicted + continue + logprob = document_lps[ix_token].item() + if token in selected_tokens: + model_logprobs.append(logprob) + + if model_logprobs: + model_group_stats[model] = { + "mean": np.mean(model_logprobs), + "median": np.median(model_logprobs), + "min": np.min(model_logprobs), + "max": np.max(model_logprobs), + "25th": np.percentile(model_logprobs, 25), + "75th": np.percentile(model_logprobs, 75), + } + return model_group_stats + + +def dict_filter_quantile( + d: dict[Any, float], q_start: float, q_end: float +) -> dict[Any, float]: + if not (0 <= q_start < q_end <= 1): + raise ValueError("Invalid quantile range") + q_start_val = np.nanquantile(list(d.values()), q_start) + q_end_val = np.nanquantile(list(d.values()), q_end) + return { + k: v for k, v in d.items() if q_start_val <= v <= q_end_val and not np.isnan(v) + } + + +def get_all_tok_metrics_in_label( + token_ids: Int[torch.Tensor, "prompt pos"], + selected_tokens: list[int], + metrics: torch.Tensor, + q_start: Optional[float] = None, + q_end: Optional[float] = None, +) -> dict[tuple[int, int], float]: + """ + From the token_map, get all the positions of the tokens that have a certain label. + We don't use the token_map because for sampling purposes, iterating through token_ids is more efficient. + Optionally, filter the tokens based on the quantile range of the metrics. + + Args: + - token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]} + - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...] + - metrics (torch.Tensor): tensor of metrics to search through e.g. 
torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...]) + - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1 + - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9 + + Returns: + - tok_positions (dict[tuple[int, int], Number]): dictionary of token positions and their corresponding metrics + """ + + # check if metrics have the same dimensions as token_ids + if metrics.shape != token_ids.shape: + raise ValueError( + f"Expected metrics to have the same shape as token_ids, but got {metrics.shape} and {token_ids.shape} instead." + ) + + tok_positions = {} + for prompt_pos, prompt in enumerate(token_ids.numpy()): + for tok_pos, tok in enumerate(prompt): + if tok in selected_tokens: + tok_positions[(prompt_pos, tok_pos)] = metrics[ + prompt_pos, tok_pos + ].item() + + if q_start is not None and q_end is not None: + tok_positions = dict_filter_quantile(tok_positions, q_start, q_end) + + return tok_positions + + +def visualize_selected_tokens( + input: dict[str | int, tuple[float, float, float]], + log_scale=False, + line_metric="Means", + checkpoint_mode=True, + shade_color="rgba(68, 68, 68, 0.3)", + line_color="rgb(31, 119, 180)", + bar_color="purple", + marker_color="SkyBlue", + background_color="AliceBlue", +) -> go.FigureWidget: + input_x = list(input.keys()) + + def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: + return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] + + def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + x = np.array([input[x] for x in input_x]).T + means, err_lo, err_hi = x[0], x[1], x[2] + return means, err_lo, err_hi + + means, err_lo, err_hi = get_plot_values() + + if checkpoint_mode: + scatter_plot = go.Figure( + [ + go.Scatter( + name="Upper Bound", + x=input_x, + y=means + err_hi, + mode="lines", + marker=dict(color=shade_color), + line=dict(width=0), + showlegend=False, + ), + go.Scatter( + name="Lower Bound", + x=input_x, + y=means - err_lo, + marker=dict(color=shade_color), + line=dict(width=0), + mode="lines", + fillcolor=shade_color, + fill="tonexty", + showlegend=False, + ), + go.Scatter( + name=line_metric, + x=input_x, + y=means, + mode="lines", + marker=dict( + color=line_color, + size=0, + line=dict(color=line_color, width=1), + ), + ), + ] + ) + else: + scatter_plot = go.Scatter( + x=input_x, + y=means, + error_y=dict( + type="data", + symmetric=False, + array=err_hi, + arrayminus=err_lo, + color=bar_color, + ), + marker=dict( + color=marker_color, + size=15, + line=dict(color=line_color, width=2), + ), + hovertext=get_hovertexts(means, err_lo, err_hi), + hoverinfo="text+x", + ) + g = go.FigureWidget( + data=scatter_plot, + layout=go.Layout( + yaxis=dict( + title="Loss", + type="log" if log_scale else "linear", + ), + plot_bgcolor=background_color, + ), + ) + + return g diff --git a/src/delphi/eval/__init__.py b/src/delphi/eval/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py deleted file mode 100644 index faab8a02..00000000 --- a/src/delphi/eval/calc_model_group_stats.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import torch -from datasets import Dataset -from jaxtyping import Float - - -def calc_model_group_stats( - tokenized_corpus_dataset: Dataset, - logprobs_by_dataset: dict[str, torch.Tensor], - selected_tokens: list[int], -) -> dict[str, dict[str, float]]: - """ - For each (model, token group) pair, 
calculate useful stats (for visualization) - - args: - - tokenized_corpus_dataset: a list of the tokenized corpus datasets, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"] - - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]} - - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...] - - returns: a dict of model names as keys and stats dict as values - e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...} - - Stats calculated: mean, median, min, max, 25th percentile, 75th percentile - """ - model_group_stats = {} - for model in logprobs_by_dataset: - model_logprobs = [] - print(f"Processing model {model}") - dataset = logprobs_by_dataset[model] - for ix_doc_lp, document_lps in enumerate(dataset): - tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"] - for ix_token, token in enumerate(tokens): - if ix_token == 0: # skip the first token, which isn't predicted - continue - logprob = document_lps[ix_token].item() - if token in selected_tokens: - model_logprobs.append(logprob) - - if model_logprobs: - model_group_stats[model] = { - "mean": np.mean(model_logprobs), - "median": np.median(model_logprobs), - "min": np.min(model_logprobs), - "max": np.max(model_logprobs), - "25th": np.percentile(model_logprobs, 25), - "75th": np.percentile(model_logprobs, 75), - } - return model_group_stats diff --git a/src/delphi/eval/token_positions.py b/src/delphi/eval/token_positions.py deleted file mode 100644 index a98af761..00000000 --- a/src/delphi/eval/token_positions.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Optional - -import torch -from jaxtyping import Int - -from delphi.eval.utils import dict_filter_quantile - - -def get_all_tok_metrics_in_label( - token_ids: Int[torch.Tensor, "prompt pos"], - selected_tokens: list[int], - metrics: torch.Tensor, - q_start: Optional[float] = None, - q_end: Optional[float] = None, -) -> dict[tuple[int, int], float]: - """ - From the token_map, get all the positions of the tokens that have a certain label. - We don't use the token_map because for sampling purposes, iterating through token_ids is more efficient. - Optionally, filter the tokens based on the quantile range of the metrics. - - Args: - - token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]} - - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...] - - metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...]) - - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1 - - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9 - - Returns: - - tok_positions (dict[tuple[int, int], Number]): dictionary of token positions and their corresponding metrics - """ - - # check if metrics have the same dimensions as token_ids - if metrics.shape != token_ids.shape: - raise ValueError( - f"Expected metrics to have the same shape as token_ids, but got {metrics.shape} and {token_ids.shape} instead." 
- ) - - tok_positions = {} - for prompt_pos, prompt in enumerate(token_ids.numpy()): - for tok_pos, tok in enumerate(prompt): - if tok in selected_tokens: - tok_positions[(prompt_pos, tok_pos)] = metrics[ - prompt_pos, tok_pos - ].item() - - if q_start is not None and q_end is not None: - tok_positions = dict_filter_quantile(tok_positions, q_start, q_end) - - return tok_positions diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py deleted file mode 100644 index 0026e7a7..00000000 --- a/src/delphi/eval/utils.py +++ /dev/null @@ -1,76 +0,0 @@ -from collections.abc import Callable -from typing import Any - -import numpy as np -import torch -from jaxtyping import Float, Int -from transformers import PreTrainedModel - - -def get_all_logprobs( - model: Callable, input_ids: Int[torch.Tensor, "batch seq"] -) -> Float[torch.Tensor, "batch seq vocab"]: - # batch, seq, vocab - logits = model(input_ids).logits - return torch.log_softmax(logits, dim=-1) - - -# convenience wrapper for calling on a single sample -def get_single_logprobs( - model: Callable, input_ids: Int[torch.Tensor, "seq"] -) -> Float[torch.Tensor, "seq vocab"]: - return get_all_logprobs(model, input_ids.unsqueeze(0))[0] - - -def gather_logprobs( - logprobs: Float[torch.Tensor, "batch seq vocab"], - tokens: Int[torch.Tensor, "batch seq"], -) -> Float[torch.Tensor, "batch seq"]: - return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1) - - -def get_all_and_next_logprobs( - model: Callable, - input_ids: Int[torch.Tensor, "batch seq"], -) -> tuple[ - Float[torch.Tensor, "batch shorter_seq vocab"], - Float[torch.Tensor, "batch shorter_seq"], -]: - logprobs = get_all_logprobs(model, input_ids[:, :-1]) - next_tokens = input_ids[:, 1:] - return logprobs, gather_logprobs(logprobs, next_tokens) - - -def get_all_and_next_logprobs_single( - model: Callable, - input_ids: Int[torch.Tensor, "seq"], -) -> tuple[ - Float[torch.Tensor, "shorter_seq vocab"], - Float[torch.Tensor, "shorter_seq"], -]: - all_logprobs, next_logprobs = get_all_and_next_logprobs( - model, input_ids.unsqueeze(0) - ) - return all_logprobs[0], next_logprobs[0] - - -def get_next_and_top_k_probs( - model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3 -) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk,]: - all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids) - all_probs = torch.exp(all_logprobs) - next_probs = torch.exp(next_logprobs) - top_k = torch.topk(all_probs, k, dim=-1) - return next_probs, top_k - - -def dict_filter_quantile( - d: dict[Any, float], q_start: float, q_end: float -) -> dict[Any, float]: - if not (0 <= q_start < q_end <= 1): - raise ValueError("Invalid quantile range") - q_start_val = np.nanquantile(list(d.values()), q_start) - q_end_val = np.nanquantile(list(d.values()), q_end) - return { - k: v for k, v in d.items() if q_start_val <= v <= q_end_val and not np.isnan(v) - } diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py deleted file mode 100644 index cea76e88..00000000 --- a/src/delphi/eval/vis.py +++ /dev/null @@ -1,192 +0,0 @@ -import math -import random -import uuid -from typing import cast - -import numpy as np -import panel as pn -import torch -from IPython.core.display import HTML -from IPython.core.display_functions import display -from jaxtyping import Float, Int -from transformers import PreTrainedTokenizerBase - - -def single_loss_diff_to_color(loss_diff: float) -> str: - # if loss_diff is negative, we want the color to be red - # if loss_diff is 
positive, we want the color to be green - # if loss_diff is 0, we want the color to be white - # the color should be more intense the larger the absolute value of loss_diff - - def sigmoid(x: float) -> float: - return 1 / (1 + math.exp(-x)) - - scaled_loss_diff = sigmoid(loss_diff) # scale to 0-1 - - if scaled_loss_diff < 0.5: # red - red_val = 255 - green_blue_val = min(int(255 * 2 * scaled_loss_diff), 255) - return f"rgb({red_val}, {green_blue_val}, {green_blue_val})" - else: # green - green_val = 255 - red_blue_val = min(int(255 * 2 * (1 - scaled_loss_diff)), 255) - return f"rgb({red_blue_val}, {green_val}, {red_blue_val})" - - -def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str: - tok_str = tokenizer.decode(tok).replace(" ", " ").replace("\n", r"\n") - prob_str = f"{prob:.2%}" - return f"{prob_str:>6} |{tok_str}|" - - -def token_to_html( - token: int, - tokenizer: PreTrainedTokenizerBase, - bg_color: str, - data: dict, - class_name: str = "token", -) -> str: - data = data or {} # equivalent to if not data: data = {} - # non-breakable space, w/o it leading spaces wouldn't be displayed - str_token = tokenizer.decode(token).replace(" ", " ") - - # background or user-select (for \n) goes here - specific_styles = {} - # for now just adds line break or doesn't - br = "" - - if bg_color: - specific_styles["background-color"] = bg_color - if str_token == "\n": - # replace new line character with two characters: \ and n - str_token = r"\n" - # add line break in html - br += "
" - # this is so we can copy the prompt without "\n"s - specific_styles["user-select"] = "none" - str_token = str_token.replace("<", "<").replace(">", ">") - - style_str = data_str = "" - # converting style dict into the style attribute - if specific_styles: - inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items()) - style_str = f" style='{inside_style_str}'" - if data: - data_str = "".join( - f" data-{k}='{v.replace(' ', ' ')}'" for k, v in data.items() - ) - return f"
<div class='{class_name}'{style_str}{data_str}>{str_token}</div>
{br}" - - -_token_style = { - "border": "1px solid #888", - "display": "inline-block", - # each character of the same width, so we can easily spot a space - "font-family": "monospace", - "font-size": "14px", - "color": "black", - "background-color": "white", - "margin": "1px 0px 1px 1px", - "padding": "0px 1px 1px 1px", -} -_token_emphasized_style = { - "border": "3px solid #888", - "display": "inline-block", - "font-family": "monospace", - "font-size": "14px", - "color": "black", - "background-color": "white", - "margin": "1px 0px 1px 1px", - "padding": "0px 1px 1px 1px", -} -_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()]) -_token_emphasized_style_str = " ".join( - [f"{k}: {v};" for k, v in _token_emphasized_style.items()] -) - - -def vis_pos_map( - pos_list: list[tuple[int, int]], - selected_tokens: list[int], - metrics: Float[torch.Tensor, "prompt pos"], - token_ids: Int[torch.Tensor, "prompt pos"], - tokenizer: PreTrainedTokenizerBase, -): - """ - Randomly sample from pos_map and visualize the loss diff at the corresponding position. - """ - - token_htmls = [] - unique_id = str(uuid.uuid4()) - token_class = f"pretoken_{unique_id}" - selected_token_class = f"token_{unique_id}" - hover_div_id = f"hover_info_{unique_id}" - - # choose a random keys from pos_map - key = random.choice(pos_list) - - prompt, pos = key - all_toks = token_ids[prompt][: pos + 1] - - for i in range(all_toks.shape[0]): - token_id = cast(int, all_toks[i].item()) - value = metrics[prompt][i].item() - token_htmls.append( - token_to_html( - token_id, - tokenizer, - bg_color="white" - if np.isnan(value) - else single_loss_diff_to_color(value), - data={"loss-diff": f"{value:.2f}"}, - class_name=token_class - if token_id not in selected_tokens - else selected_token_class, - ) - ) - - # add break line - token_htmls.append("
<br><br>
") - - html_str = f""" - - {"".join(token_htmls)}
- - """ - display(HTML(html_str)) - - -def token_selector( - vocab_map: dict[str, int] -) -> tuple[pn.widgets.MultiChoice, list[int]]: - tokens = list(vocab_map.keys()) - token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens) - token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)] - - def update_tokens(event): - token_ids.clear() - token_ids.extend([vocab_map[token] for token in event.new]) - - token_selector_.param.watch(update_tokens, "value") - return token_selector_, token_ids diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py deleted file mode 100644 index e5d735f4..00000000 --- a/src/delphi/eval/vis_per_token_model.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Union - -import numpy as np -import plotly.graph_objects as go - - -def visualize_selected_tokens( - input: dict[Union[str, int], tuple[float, float, float]], - log_scale=False, - line_metric="Means", - checkpoint_mode=True, - shade_color="rgba(68, 68, 68, 0.3)", - line_color="rgb(31, 119, 180)", - bar_color="purple", - marker_color="SkyBlue", - background_color="AliceBlue", -) -> go.FigureWidget: - input_x = list(input.keys()) - - def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: - return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] - - def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]: - x = np.array([input[x] for x in input_x]).T - means, err_lo, err_hi = x[0], x[1], x[2] - return means, err_lo, err_hi - - means, err_lo, err_hi = get_plot_values() - - if checkpoint_mode: - scatter_plot = go.Figure( - [ - go.Scatter( - name="Upper Bound", - x=input_x, - y=means + err_hi, - mode="lines", - marker=dict(color=shade_color), - line=dict(width=0), - showlegend=False, - ), - go.Scatter( - name="Lower Bound", - x=input_x, - y=means - err_lo, - marker=dict(color=shade_color), - line=dict(width=0), - mode="lines", - fillcolor=shade_color, - fill="tonexty", - showlegend=False, - ), - go.Scatter( - name=line_metric, - x=input_x, - y=means, - mode="lines", - marker=dict( - color=line_color, - size=0, - line=dict(color=line_color, width=1), - ), - ), - ] - ) - else: - scatter_plot = go.Scatter( - x=input_x, - y=means, - error_y=dict( - type="data", - symmetric=False, - array=err_hi, - arrayminus=err_lo, - color=bar_color, - ), - marker=dict( - color=marker_color, - size=15, - line=dict(color=line_color, width=2), - ), - hovertext=get_hovertexts(means, err_lo, err_hi), - hoverinfo="text+x", - ) - g = go.FigureWidget( - data=scatter_plot, - layout=go.Layout( - yaxis=dict( - title="Loss", - type="log" if log_scale else "linear", - ), - plot_bgcolor=background_color, - ), - ) - - return g diff --git a/src/delphi/dataset/tokenization.py b/src/delphi/tokenization.py similarity index 100% rename from src/delphi/dataset/tokenization.py rename to src/delphi/tokenization.py diff --git a/src/delphi/utils.py b/src/delphi/utils.py index 30325cb2..7ffbfc88 100644 --- a/src/delphi/utils.py +++ b/src/delphi/utils.py @@ -1,6 +1,9 @@ +from collections.abc import Callable from typing import cast +import torch from datasets import Dataset, Features, Sequence, Value, load_dataset +from jaxtyping import Float, Int def hf_split_to_split_name(split: str) -> str: @@ -55,3 +58,30 @@ def get_all_hf_branch_names(repo_id: str) -> list[str]: api = HfApi() refs = api.list_repo_refs(repo_id) return [branch.name for branch in refs.branches] + + +def gather_logprobs( + logprobs: Float[torch.Tensor, "batch 
seq vocab"], + tokens: Int[torch.Tensor, "batch seq"], +) -> Float[torch.Tensor, "batch seq"]: + return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1) + + +def get_all_logprobs( + model: Callable, input_ids: Int[torch.Tensor, "batch seq"] +) -> Float[torch.Tensor, "batch seq vocab"]: + # batch, seq, vocab + logits = model(input_ids).logits + return torch.log_softmax(logits, dim=-1) + + +def get_all_and_next_logprobs( + model: Callable, + input_ids: Int[torch.Tensor, "batch seq"], +) -> tuple[ + Float[torch.Tensor, "batch shorter_seq vocab"], + Float[torch.Tensor, "batch shorter_seq"], +]: + logprobs = get_all_logprobs(model, input_ids[:, :-1]) + next_tokens = input_ids[:, 1:] + return logprobs, gather_logprobs(logprobs, next_tokens) diff --git a/tests/eval/test_token_positions.py b/tests/eval/test_token_positions.py deleted file mode 100644 index c584b931..00000000 --- a/tests/eval/test_token_positions.py +++ /dev/null @@ -1,54 +0,0 @@ -from math import isclose -from typing import cast - -import pytest -from datasets import Dataset - -from delphi.eval.token_positions import * - - -@pytest.fixture -def mock_data(): - token_ids = Dataset.from_dict( - {"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]} - ).with_format("torch") - selected_tokens = [2, 4, 6, 8] - metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]]) - return token_ids, selected_tokens, metrics - - -def test_get_all_tok_metrics_in_label(mock_data): - token_ids, selected_tokens, metrics = mock_data - result = get_all_tok_metrics_in_label( - token_ids["tokens"], - selected_tokens, - metrics, - ) - # key: (prompt_pos, tok_pos), value: logprob - expected = { - (0, 1): 0.45, - (1, 0): -1.31, - (1, 2): 0.6, - (2, 1): 0.8, - } - - # compare keys - assert result.keys() == expected.keys() - # compare values - for k in result: - assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore - - # test with quantile filtering - result_q = get_all_tok_metrics_in_label( - token_ids["tokens"], selected_tokens, metrics, q_start=0.6, q_end=1.0 - ) - expected_q = { - (1, 2): 0.6, - (2, 1): 0.8, - } - - # compare keys - assert result_q.keys() == expected_q.keys() - # compare values - for k in result_q: - assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore diff --git a/tests/eval/test_utils_eval.py b/tests/eval/test_utils_eval.py deleted file mode 100644 index 54e0034a..00000000 --- a/tests/eval/test_utils_eval.py +++ /dev/null @@ -1,84 +0,0 @@ -from math import isclose - -import pytest -import torch - -from delphi.eval.utils import dict_filter_quantile, gather_logprobs - - -def test_gather_logprobs(): - # vocab size = 3 - logprobs = torch.tensor( - [ - # batch 0 - [ - # seq 0 - [0.00, 0.01, 0.02], - # seq 1 - [0.10, 0.11, 0.12], - ], - # batch 1 - [ - # seq 0 - [1.00, 1.01, 1.02], - # seq 1 - [1.10, 1.11, 1.12], - ], - ] - ) - tokens = torch.tensor( - [ - # batch 0 - [0, 2], - # batch 1 - [1, 2], - ] - ) - expected_output = torch.tensor( - [ - # batch 0 - [0.00, 0.12], - # batch 1 - [1.01, 1.12], - ] - ) - result = gather_logprobs(logprobs, tokens) - assert torch.allclose(result, expected_output) - - -@pytest.mark.filterwarnings( - "ignore::RuntimeWarning" -) # ignore warnings from numpy empty slice -def test_dict_filter_quantile(): - d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5} - result = dict_filter_quantile(d, 0.2, 0.6) - expected = {2: 0.2, 3: 0.3} - - # compare keys - assert result.keys() == expected.keys() - # compare values - for k in result: - assert 
isclose(result[k], expected[k], rel_tol=1e-6) - - # test with negative values - d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5} - result = dict_filter_quantile(d, 0.2, 0.6) - expected = {3: -0.3, 4: -0.4} - - # compare keys - assert result.keys() == expected.keys() - # compare values - for k in result: - assert isclose(result[k], expected[k], rel_tol=1e-6) - - # test invalid quantile range - with pytest.raises(ValueError): - dict_filter_quantile(d, 0.6, 0.2) - with pytest.raises(ValueError): - dict_filter_quantile(d, 0.1, 1.1) - with pytest.raises(ValueError): - dict_filter_quantile(d, -0.1, 0.6) - - # test empty dict, will raise a warning - result = dict_filter_quantile({}, 0.2, 0.6) - assert result == {} diff --git a/tests/test_eval.py b/tests/test_eval.py new file mode 100644 index 00000000..cdf88413 --- /dev/null +++ b/tests/test_eval.py @@ -0,0 +1,91 @@ +from math import isclose +from typing import cast + +import pytest +import torch +from datasets import Dataset + +from delphi.eval import dict_filter_quantile, get_all_tok_metrics_in_label + + +@pytest.mark.filterwarnings( + "ignore::RuntimeWarning" +) # ignore warnings from numpy empty slice +def test_dict_filter_quantile(): + d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5} + result = dict_filter_quantile(d, 0.2, 0.6) + expected = {2: 0.2, 3: 0.3} + + # compare keys + assert result.keys() == expected.keys() + # compare values + for k in result: + assert isclose(result[k], expected[k], rel_tol=1e-6) + + # test with negative values + d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5} + result = dict_filter_quantile(d, 0.2, 0.6) + expected = {3: -0.3, 4: -0.4} + + # compare keys + assert result.keys() == expected.keys() + # compare values + for k in result: + assert isclose(result[k], expected[k], rel_tol=1e-6) + + # test invalid quantile range + with pytest.raises(ValueError): + dict_filter_quantile(d, 0.6, 0.2) + with pytest.raises(ValueError): + dict_filter_quantile(d, 0.1, 1.1) + with pytest.raises(ValueError): + dict_filter_quantile(d, -0.1, 0.6) + + # test empty dict, will raise a warning + result = dict_filter_quantile({}, 0.2, 0.6) + assert result == {} + + +def test_get_all_tok_metrics_in_label(): + token_ids = Dataset.from_dict( + {"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]} + ).with_format("torch") + selected_tokens = [2, 4, 6, 8] + metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]]) + result = get_all_tok_metrics_in_label( + token_ids["tokens"], # type: ignore + selected_tokens, + metrics, + ) + # key: (prompt_pos, tok_pos), value: logprob + expected = { + (0, 1): 0.45, + (1, 0): -1.31, + (1, 2): 0.6, + (2, 1): 0.8, + } + + # compare keys + assert result.keys() == expected.keys() + # compare values + for k in result: + assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore + + # test with quantile filtering + result_q = get_all_tok_metrics_in_label( + token_ids["tokens"], # type: ignore + selected_tokens, + metrics, + q_start=0.6, + q_end=1.0, + ) + expected_q = { + (1, 2): 0.6, + (2, 1): 0.8, + } + + # compare keys + assert result_q.keys() == expected_q.keys() + # compare values + for k in result_q: + assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore diff --git a/tests/dataset/test_tokeniation.py b/tests/test_tokeniation.py similarity index 97% rename from tests/dataset/test_tokeniation.py rename to tests/test_tokeniation.py index bb4180ba..cc9494b2 100644 --- a/tests/dataset/test_tokeniation.py +++ b/tests/test_tokeniation.py @@ -5,7 
+5,7 @@ from datasets import Dataset from transformers import AutoTokenizer -from delphi.dataset.tokenization import extend_deque, make_new_sample, tokenize_dataset +from delphi.tokenization import extend_deque, make_new_sample, tokenize_dataset @pytest.fixture diff --git a/tests/test_utils.py b/tests/test_utils.py index 597438ca..79b639ad 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,13 @@ -from delphi.utils import hf_split_to_split_name +import random +import string -from .utils import random_string +import torch + +from delphi.utils import gather_logprobs, hf_split_to_split_name + + +def random_string(length: int) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) def test_hf_split_to_split_name(): @@ -12,3 +19,43 @@ def test_hf_split_to_split_name(): assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name + + +def test_gather_logprobs(): + # vocab size = 3 + logprobs = torch.tensor( + [ + # batch 0 + [ + # seq 0 + [0.00, 0.01, 0.02], + # seq 1 + [0.10, 0.11, 0.12], + ], + # batch 1 + [ + # seq 0 + [1.00, 1.01, 1.02], + # seq 1 + [1.10, 1.11, 1.12], + ], + ] + ) + tokens = torch.tensor( + [ + # batch 0 + [0, 2], + # batch 1 + [1, 2], + ] + ) + expected_output = torch.tensor( + [ + # batch 0 + [0.00, 0.12], + # batch 1 + [1.01, 1.12], + ] + ) + result = gather_logprobs(logprobs, tokens) + assert torch.allclose(result, expected_output) diff --git a/tests/train/test_train_step.py b/tests/train/test_train_step.py index 9a7f3456..e06fa1af 100644 --- a/tests/train/test_train_step.py +++ b/tests/train/test_train_step.py @@ -8,7 +8,6 @@ from transformers import PreTrainedModel from delphi import TEST_CONFIGS_DIR -from delphi.eval.utils import get_all_and_next_logprobs from delphi.train.config import TrainingConfig from delphi.train.config.utils import build_config_from_files_and_overrides from delphi.train.train_step import accumulate_gradients, train_step @@ -18,6 +17,7 @@ init_model, setup_determinism, ) +from delphi.utils import get_all_and_next_logprobs def load_test_config(preset_name: str) -> TrainingConfig: diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index ed81b58a..00000000 --- a/tests/utils.py +++ /dev/null @@ -1,6 +0,0 @@ -import random -import string - - -def random_string(length: int) -> str: - return "".join(random.choices(string.ascii_lowercase, k=length))
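For reviewers who want to sanity-check the consolidated helpers outside the package, the sketch below restates the two functions exercised by the new tests (gather_logprobs, now in delphi.utils, and dict_filter_quantile, now in delphi.eval) using only torch and numpy, with the same toy data as tests/test_utils.py and tests/test_eval.py. It is an illustration of the intended behavior, not part of this diff.

# Standalone sketch (illustration only, not part of the patch above): mirrors the
# behavior of delphi.utils.gather_logprobs and delphi.eval.dict_filter_quantile
# as exercised by the new tests, without importing the delphi package.
from typing import Any

import numpy as np
import torch


def gather_logprobs(logprobs: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Select, at every (batch, seq) position, the logprob assigned to the token
    # that actually occurred there.
    return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)


def dict_filter_quantile(
    d: dict[Any, float], q_start: float, q_end: float
) -> dict[Any, float]:
    # Keep entries whose value lies within the [q_start, q_end] quantile band.
    if not (0 <= q_start < q_end <= 1):
        raise ValueError("Invalid quantile range")
    lo = np.nanquantile(list(d.values()), q_start)
    hi = np.nanquantile(list(d.values()), q_end)
    return {k: v for k, v in d.items() if lo <= v <= hi and not np.isnan(v)}


# Toy data copied from the tests added in this diff.
logprobs = torch.tensor(
    [
        [[0.00, 0.01, 0.02], [0.10, 0.11, 0.12]],  # batch 0
        [[1.00, 1.01, 1.02], [1.10, 1.11, 1.12]],  # batch 1
    ]
)
tokens = torch.tensor([[0, 2], [1, 2]])
print(gather_logprobs(logprobs, tokens))  # tensor([[0.0000, 0.1200], [1.0100, 1.1200]])

metrics = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5}
print(dict_filter_quantile(metrics, 0.2, 0.6))  # {2: 0.2, 3: 0.3}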