
Commit

Add compare_models.py (w/ test)
Jai committed Feb 1, 2024
1 parent 517d545 commit ef666c6
Showing 3 changed files with 144 additions and 0 deletions.
76 changes: 76 additions & 0 deletions src/delphi/vis/compare_models.py
@@ -0,0 +1,76 @@
from dataclasses import dataclass

import torch as t
import torch.nn as nn
from jaxtyping import Int

from delphi.vis.utils import get_correct_and_all_probs


@dataclass
class ModelComparison:
    correct_prob_base_model: t.Tensor
    correct_prob_lift_model: t.Tensor
    top_k_tokens_lift_model: t.Tensor
    top_k_probs_base_model: t.Tensor
    top_k_probs_lift_model: t.Tensor


def _pad_start(tensor: t.Tensor) -> t.Tensor:
    value_to_prepend = -1
    if len(tensor.shape) == 1:
        return t.cat((t.tensor([value_to_prepend]), tensor))
    else:
        # input: 2D tensor of shape [seq_len - 1, top_k]
        pre = t.full((1, tensor.size()[-1]), value_to_prepend)
        return t.cat((pre, tensor), dim=0)


def compare_models(
    model_a: nn.Module,
    model_b: nn.Module,
    sample_tok: Int[t.Tensor, "pos"],
    top_k: int = 3,
) -> ModelComparison:
    """
    Compare the next-token probabilities of two models and get the top k
    token predictions according to model B.

    Args:
    - model_a: The first model (assumed to be the base model)
    - model_b: The second model (assumed to be the improved model)
    - sample_tok: The tokenized prompt
    - top_k: The number of top token predictions to retrieve (default is 3)

    Returns:
    - A ModelComparison with tensors for:
        - The probabilities of the actual next token according to model A
        - The probabilities of the actual next token according to model B
        - The top k token predictions according to model B
        - The probabilities of these tokens according to model A
        - The probabilities of these tokens according to model B

    Tensors are aligned to the token they are predicting (by prepending a -1
    to the start of the tensor).
    """
    assert (
        model_a.device == model_b.device
    ), "Both models must be on the same device for comparison."

    device = model_a.device
    sample_tok = sample_tok.to(device)

    next_probs_a, probs_a = get_correct_and_all_probs(model_a, sample_tok)
    next_probs_b, probs_b = get_correct_and_all_probs(model_b, sample_tok)

    top_k_b = t.topk(probs_b, top_k, dim=-1)
    top_k_a_probs = t.gather(probs_a, 1, top_k_b.indices)

    next_probs_a = _pad_start(next_probs_a)
    next_probs_b = _pad_start(next_probs_b)
    top_k_b_tokens = _pad_start(top_k_b.indices)
    top_k_a_probs = _pad_start(top_k_a_probs)
    top_k_b_probs = _pad_start(top_k_b.values)

    return ModelComparison(
        correct_prob_base_model=next_probs_a,
        correct_prob_lift_model=next_probs_b,
        top_k_tokens_lift_model=top_k_b_tokens,
        top_k_probs_base_model=top_k_a_probs,
        top_k_probs_lift_model=top_k_b_probs,
    )
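
For context, a minimal usage sketch (not part of this commit; it mirrors the test added below and assumes the same TinyStories checkpoints and the load_orig_ds_txt/tokenize helpers from delphi.vis.utils that the test relies on):

# Sketch only: mirrors tests/test_vis.py; checkpoint names, split slice, and
# helper functions are assumptions taken from the test, not from this module.
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.vis.compare_models import compare_models
from delphi.vis.utils import load_orig_ds_txt, tokenize

base = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
lift = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-Instruct-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

# Tokenize one validation story and compare next-token probabilities.
sample_tok = tokenize(tokenizer, load_orig_ds_txt("validation[:1]")[0])
comparison = compare_models(base, lift, sample_tok, top_k=3)

# Each tensor is aligned to the token being predicted; position 0 is padded with -1.
print(comparison.correct_prob_base_model.shape)  # same shape as sample_tok
print(comparison.top_k_tokens_lift_model.shape)  # (len(sample_tok), 3)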
7 changes: 7 additions & 0 deletions src/delphi/vis/utils.py
@@ -47,6 +47,13 @@ def get_correct_probs(model, sample_tok):
    return probs[range(len(probs)), sample_tok[1:]]


def get_correct_and_all_probs(model, sample_tok):
    """Get probabilities for the actual next token and for all predictions"""
    probs = get_probs(model, sample_tok)
    correct_probs = probs[range(len(probs)), sample_tok[1:]]
    return correct_probs, probs


def get_correct_and_top_probs(model, sample_tok, top_k=3):
    """Get probabilities for the actual next token and for top k predictions"""
    probs = get_probs(model, sample_tok)
61 changes: 61 additions & 0 deletions tests/test_vis.py
@@ -0,0 +1,61 @@
import pytest
import torch
from beartype.roar import BeartypeCallHintViolation
from IPython.display import HTML
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.vis.compare_models import ModelComparison, compare_models
from delphi.vis.utils import load_orig_ds_txt, tokenize

torch.set_grad_enabled(False)


# define a pytest fixture for the model name
@pytest.fixture
def model_name():
    return "roneneldan/TinyStories-1M"


# define a pytest fixture for a default tokenizer using the model_name fixture
@pytest.fixture
def tokenizer(model_name):
    return AutoTokenizer.from_pretrained(model_name)


# define a pytest fixture for a default model using the model_name fixture
@pytest.fixture
def model(model_name):
    return AutoModelForCausalLM.from_pretrained(model_name)


# define a pytest fixture for the raw dataset
@pytest.fixture
def ds_txt():
    return load_orig_ds_txt("validation[:100]")


# define a pytest fixture for the tokenized dataset
@pytest.fixture
def ds_tok(tokenizer, ds_txt):
    return [tokenize(tokenizer, txt) for txt in ds_txt]


# define a pytest fixture for a tokenized sample
@pytest.fixture
def sample_tok(ds_tok):
    return ds_tok[0]


def test_compare_models(model, sample_tok):
    model_instruct = AutoModelForCausalLM.from_pretrained(
        "roneneldan/TinyStories-Instruct-1M"
    )
    K = 3
    model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
    assert isinstance(model_comparison, ModelComparison)

    assert model_comparison.correct_prob_base_model.shape == sample_tok.shape
    assert model_comparison.correct_prob_lift_model.shape == sample_tok.shape
    assert model_comparison.top_k_tokens_lift_model.shape == (sample_tok.shape[0], K)
    assert model_comparison.top_k_probs_base_model.shape == (sample_tok.shape[0], K)
    assert model_comparison.top_k_probs_lift_model.shape == (sample_tok.shape[0], K)
