Skip to content

Commit

Permalink
stale eval code purge
Browse files Browse the repository at this point in the history
  • Loading branch information
jettjaniak committed May 22, 2024
1 parent 330e301 commit 972fc81
Show file tree
Hide file tree
Showing 6 changed files with 3 additions and 284 deletions.
91 changes: 0 additions & 91 deletions src/delphi/eval/compare_models.py

This file was deleted.

26 changes: 0 additions & 26 deletions src/delphi/eval/constants.py

This file was deleted.

61 changes: 2 additions & 59 deletions src/delphi/eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import logging
from collections.abc import Callable
from typing import Any, cast
from typing import Any

import numpy as np
import torch
from datasets import Dataset, load_dataset
from jaxtyping import Float, Int
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from delphi.eval import constants
from transformers import PreTrainedModel


def get_all_logprobs(
Expand Down Expand Up @@ -68,59 +64,6 @@ def get_next_and_top_k_probs(
return next_probs, top_k


def load_delphi_dataset(dataset_name: str, split: str, slice: str = "") -> Dataset:
# check that split is either "train" or "validation"
if split not in ["train", "validation"]:
raise ValueError(f"Split must be either 'train' or 'validation', not {split}")
if "/" not in dataset_name:
dataset_name = f"delphi-suite/{dataset_name}"
data_files_str = f"data/{split}-*.parquet"
dataset = load_dataset(
dataset_name,
data_files=data_files_str,
verification_mode="no_checks",
# Currently, load_dataset returns a dataset dict *unless* a split is specified,
# EVEN IF NO SPLIT WITHIN THE DATA FILES SPECIFIED. If there's no split arg,
# huggingface just just says everything is in the "train" split and returns {"train": dataset}.
# In our case the data_files glob already specifies just the validation files, so we
# shouldn't need to specify a split. But we do need to specify a split to get a dataset object,
# or we'd get a Dataset dict. See https://github.com/huggingface/datasets/issues/5189
split=f"train{slice}",
)
dataset = cast(Dataset, dataset)
logging.info(f" Loaded {data_files_str} ({len(dataset)} entries)")
return dataset


def load_validation_dataset(dataset_name: str, slice: str = "") -> Dataset:
return load_delphi_dataset(dataset_name, "validation", slice)


def load_train_dataset(dataset_name: str, slice: str = "") -> Dataset:
return load_delphi_dataset(dataset_name, "train", slice)


def tokenize(
tokenizer: PreTrainedTokenizerBase, sample_txt: str
) -> Int[torch.Tensor, "seq"]:
# supposedly this can be different than prepending the bos token id
return cast(
Int[torch.Tensor, "seq"],
tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
)


def load_logprob_dataset(model: str):
return load_dataset(f"transcendingvictor/{model}-validation-logprobs")


def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
return {
model: cast(dict, load_logprob_dataset(model)[split])["logprobs"] # type: ignore
for model in constants.LLAMA2_MODELS
}


def dict_filter_quantile(
d: dict[Any, float], q_start: float, q_end: float
) -> dict[Any, float]:
Expand Down
75 changes: 0 additions & 75 deletions src/delphi/eval/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@
from transformers import PreTrainedTokenizerBase


def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]:
# for the endoftext token
# no prediction, no color
colors = ["white"]
for p in probs.tolist():
red_gap = 150 # the higher it is, the less red the tokens will be
green_blue_val = red_gap + int((255 - red_gap) * (1 - p))
colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})")
return colors


def single_loss_diff_to_color(loss_diff: float) -> str:
# if loss_diff is negative, we want the color to be red
# if loss_diff is positive, we want the color to be green
Expand Down Expand Up @@ -116,70 +105,6 @@ def token_to_html(
)


def vis_sample_prediction_probs(
sample_tok: Int[torch.Tensor, "pos"],
correct_probs: Float[torch.Tensor, "pos"],
top_k_probs: torch.return_types.topk,
tokenizer: PreTrainedTokenizerBase,
) -> str:
colors = probs_to_colors(correct_probs)
token_htmls = []

# Generate a unique ID for this instance (so we can have multiple instances on the same page)
unique_id = str(uuid.uuid4())

token_class = f"token_{unique_id}"
hover_div_id = f"hover_info_{unique_id}"

for i in range(sample_tok.shape[0]):
tok = cast(int, sample_tok[i].item())
data = {}
if i > 0:
correct_prob = correct_probs[i - 1].item()
data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer)
top_k_probs_tokens = top_k_probs.indices[i - 1]
top_k_probs_values = top_k_probs.values[i - 1]
for j in range(top_k_probs_tokens.shape[0]):
top_tok = top_k_probs_tokens[j].item()
top_tok = cast(int, top_tok)
top_prob = top_k_probs_values[j].item()
data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)

token_htmls.append(
token_to_html(
tok, tokenizer, bg_color=colors[i], data=data, class_name=token_class
)
)

html_str = f"""
<style>.{token_class} {{ {_token_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
{"".join(token_htmls)} <div id='{hover_div_id}'></div>
<script>
(function() {{
var token_divs = document.querySelectorAll('.{token_class}');
var hover_info = document.getElementById('{hover_div_id}');
token_divs.forEach(function(token_div) {{
token_div.addEventListener('mousemove', function(e) {{
hover_info.innerHTML = ""
for( var d in this.dataset) {{
hover_info.innerHTML += "<b>" + d + "</b> ";
hover_info.innerHTML += this.dataset[d] + "<br>";
}}
}});
token_div.addEventListener('mouseout', function(e) {{
hover_info.innerHTML = ""
}});
}});
}})();
</script>
"""
display(HTML(html_str))
return html_str


def vis_pos_map(
pos_list: list[tuple[int, int]],
selected_tokens: list[int],
Expand Down
23 changes: 0 additions & 23 deletions tests/eval/test_compare_models.py

This file was deleted.

11 changes: 1 addition & 10 deletions tests/eval/test_utils_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import pytest
import torch

from delphi.eval.utils import (
dict_filter_quantile,
gather_logprobs,
load_validation_dataset,
)
from delphi.eval.utils import dict_filter_quantile, gather_logprobs


def test_gather_logprobs():
Expand Down Expand Up @@ -50,11 +46,6 @@ def test_gather_logprobs():
assert torch.allclose(result, expected_output)


def test_load_validation_dataset():
text = load_validation_dataset("tinystories-v2-clean")
tokenized = load_validation_dataset("tinystories-v2-clean-tokenized-v0")


@pytest.mark.filterwarnings(
"ignore::RuntimeWarning"
) # ignore warnings from numpy empty slice
Expand Down

0 comments on commit 972fc81

Please sign in to comment.