Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
Jai committed Feb 9, 2024
1 parent 06242e1 commit 9ad1bc0
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 32 deletions.
75 changes: 52 additions & 23 deletions src/delphi/eval/compare_models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
from dataclasses import dataclass
from typing import cast

import torch
import torch.nn as nn
from jaxtyping import Int # Add the missing import statement
from jaxtyping import Int
from transformers import PreTrainedModel

from delphi.eval.utils import get_correct_and_all_probs


@dataclass
class ModelComparison:
    """Tensor-form result of comparing two models on one tokenized sample.

    The base model is the first model passed to the comparison and the lift
    model is the second (assumed improved) one.
    """

    # Probabilities of the actual next token according to the base model.
    correct_prob_base_model: torch.Tensor
    # Probabilities of the actual next token according to the lift model.
    correct_prob_lift_model: torch.Tensor
    # Top-k token predictions according to the lift model.
    top_k_tokens_lift_model: torch.Tensor
    # Probabilities of those top-k tokens according to the base model.
    top_k_probs_base_model: torch.Tensor
    # Probabilities of those top-k tokens according to the lift model.
    top_k_probs_lift_model: torch.Tensor
@dataclass
class ModelId:
    """Identifier for a model taking part in a comparison.

    Declared as a dataclass so it can be built with keyword arguments,
    e.g. ``ModelId(model_name=...)``.
    """

    # Name used to identify the model.
    model_name: str


def identify_model(model: PreTrainedModel) -> ModelId:
    """Build a ModelId for *model*.

    Args:
        model: The model to identify.

    Returns:
        A ModelId whose ``model_name`` is ``model.config.name_or_path``.
    """
    return ModelId(model_name=model.config.name_or_path)


@dataclass
class TokenPrediction:
    """A single token together with the probability each model assigns it."""

    # Integer id of the token.
    token: int
    # Probability of this token according to the base model.
    base_model_prob: float
    # Probability of this token according to the lift model.
    lift_model_prob: float


@dataclass
class NextTokenStats:
    """Comparison of two models' next-token predictions at one position."""

    # Identifier of the base model.
    base_model: ModelId
    # Identifier of the lift model.
    lift_model: ModelId
    # Probabilities both models assign to the next token.
    next_prediction: TokenPrediction
    # Top-k candidate tokens, each with both models' probabilities.
    topk: list[TokenPrediction]


def _pad_start(tensor: torch.Tensor) -> torch.Tensor:
Expand All @@ -30,23 +46,18 @@ def _pad_start(tensor: torch.Tensor) -> torch.Tensor:
def compare_models(
model_a: PreTrainedModel,
model_b: PreTrainedModel,
sample_tok: Int[torch.Tensor, "pos"],
sample_tok: Int[torch.Tensor, "seq"],
top_k: int = 3,
) -> ModelComparison:
) -> list[NextTokenStats]:
"""
Compare the probabilities of the next token for two models and get the top k token predictions according to model B.
Args:
- model_a: The first model (assumed to be the base model)
- model_b: The second model (assumed to be the improved model)
- tokens: The tokenized prompt
- sample_tok: The tokenized prompt
- top_k: The number of top token predictions to retrieve (default is 3)
Returns:
- A ModelComparison with tensors for:
- The probabilities of the actual next token according to model A
- The probabilities of the actual next token according to model B
- The top k token predictions according to model B
- The probabilities of these tokens according to model A
- The probabilities of these tokens according to model B
A list of NextTokenStats objects, one for each token in the prompt.
Tensors are aligned to the token they are predicting (by prepending a -1 to the start of the tensor)
"""
assert (
Expand All @@ -68,10 +79,28 @@ def compare_models(
top_k_a_probs = _pad_start(top_k_a_probs)
top_k_b_probs = _pad_start(top_k_b.values)

return ModelComparison(
correct_prob_base_model=next_probs_a,
correct_prob_lift_model=next_probs_b,
top_k_tokens_lift_model=top_k_b_tokens,
top_k_probs_base_model=top_k_a_probs,
top_k_probs_lift_model=top_k_b_probs,
)
comparisons = []

for next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip(
next_probs_a, next_probs_b, top_k_b_tokens, top_k_a_probs, top_k_b_probs
):
nts = NextTokenStats(
base_model=identify_model(model_a),
lift_model=identify_model(model_b),
next_prediction=TokenPrediction(
token=int(next_p_a.item()),
base_model_prob=next_p_a.item(),
lift_model_prob=next_p_b.item(),
),
topk=[
TokenPrediction(
token=int(top_toks_b[i].item()),
base_model_prob=top_probs_a[i].item(),
lift_model_prob=top_probs_b[i].item(),
)
for i in range(top_k)
],
)
comparisons.append(nts)

return comparisons
16 changes: 7 additions & 9 deletions tests/test_vis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List

import pytest
import torch
from beartype.roar import BeartypeCallHintViolation
from IPython.display import HTML
from jaxtyping import Int
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import ModelComparison, compare_models
from delphi.eval.compare_models import ModelComparison, NextTokenStats, compare_models
from delphi.eval.utils import load_text_from_dataset, load_validation_dataset, tokenize

torch.set_grad_enabled(False)
Expand Down Expand Up @@ -52,10 +54,6 @@ def test_compare_models(model, sample_tok):
)
K = 3
model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
assert isinstance(model_comparison, ModelComparison)

assert model_comparison.correct_prob_base_model.shape == sample_tok.shape
assert model_comparison.correct_prob_lift_model.shape == sample_tok.shape
assert model_comparison.top_k_tokens_lift_model.shape == (sample_tok.shape[0], K)
assert model_comparison.top_k_probs_base_model.shape == (sample_tok.shape[0], K)
assert model_comparison.top_k_probs_lift_model.shape == (sample_tok.shape[0], K)
assert isinstance(model_comparison[0], NextTokenStats)
assert len(model_comparison) == sample_tok.shape[0]
assert len(model_comparison[0].topk) == K

0 comments on commit 9ad1bc0

Please sign in to comment.