Commit

cr
Jai committed Feb 9, 2024
1 parent ef666c6 commit 1242506
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/delphi/vis/compare_models.py
@@ -1,35 +1,35 @@
 from dataclasses import dataclass

-import torch as t
+import torch
 import torch.nn as nn
-from jaxtyping import Int
+from jaxtyping import Int # Add the missing import statement

 from delphi.vis.utils import get_correct_and_all_probs


 @dataclass
 class ModelComparison:
-    correct_prob_base_model: t.Tensor
-    correct_prob_lift_model: t.Tensor
-    top_k_tokens_lift_model: t.Tensor
-    top_k_probs_base_model: t.Tensor
-    top_k_probs_lift_model: t.Tensor
+    correct_prob_base_model: torch.Tensor
+    correct_prob_lift_model: torch.Tensor
+    top_k_tokens_lift_model: torch.Tensor
+    top_k_probs_base_model: torch.Tensor
+    top_k_probs_lift_model: torch.Tensor


-def _pad_start(tensor: t.Tensor) -> t.Tensor:
+def _pad_start(tensor: torch.Tensor) -> torch.Tensor:
     value_to_prepend = -1
     if len(tensor.shape) == 1:
-        return t.cat((t.tensor([value_to_prepend]), tensor))
+        return torch.cat((torch.tensor([value_to_prepend]), tensor))
     else:
         # input: 2D tensor of shape [seq_len - 1, top_k]
-        pre = t.full((1, tensor.size()[-1]), value_to_prepend)
-        return t.cat((pre, tensor), dim=0)
+        pre = torch.full((1, tensor.size()[-1]), value_to_prepend)
+        return torch.cat((pre, tensor), dim=0)


 def compare_models(
     model_a: nn.Module,
     model_b: nn.Module,
-    sample_tok: Int[t.Tensor, "pos"],
+    sample_tok: Int[torch.Tensor, "pos"],
     top_k: int = 3,
 ) -> ModelComparison:
     """
@@ -46,7 +46,7 @@ def compare_models(
     - The top k token predictions according to model B
     - The probabilities of these tokens according to model A
     - The probabilities of these tokens according to model B
-    Tensors are aligned to the token they are predicting (by prepending a -1 to the start of the tensor)
+    Tensors are aligned to the token they are predicting (by prepending a -1 to the start of the tensor)
     """
     assert (
         model_a.device == model_b.device
@@ … @@ def compare_models(
     next_probs_a, probs_a = get_correct_and_all_probs(model_a, sample_tok)
     next_probs_b, probs_b = get_correct_and_all_probs(model_b, sample_tok)

-    top_k_b = t.topk(probs_b, top_k, dim=-1)
-    top_k_a_probs = t.gather(probs_a, 1, top_k_b.indices)
+    top_k_b = torch.topk(probs_b, top_k, dim=-1)
+    top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)

     next_probs_a = _pad_start(next_probs_a)
     next_probs_b = _pad_start(next_probs_b)
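
For context, here is a minimal standalone sketch (not part of this commit; it uses only plain torch and made-up tensor values, not delphi's helpers) of the two ideas the changed lines implement: prepending a -1 so per-position probabilities line up with the tokens they predict, and using topk/gather so model A's probabilities are read out at model B's top-k token indices.

import torch

# Same idea as _pad_start in the diff above: prepend a -1 so that entry i of the
# result lines up with token i of the input sequence (position 0 has no prediction).
def pad_start(tensor: torch.Tensor) -> torch.Tensor:
    value_to_prepend = -1
    if tensor.dim() == 1:
        return torch.cat((torch.tensor([value_to_prepend], dtype=tensor.dtype), tensor))
    pre = torch.full((1, tensor.size(-1)), value_to_prepend, dtype=tensor.dtype)
    return torch.cat((pre, tensor), dim=0)

# Hypothetical per-position probabilities over a 3-token vocabulary for two models
# (shape [seq_len - 1, vocab]); the values are illustrative only.
probs_a = torch.tensor([[0.1, 0.2, 0.7],
                        [0.5, 0.3, 0.2]])
probs_b = torch.tensor([[0.6, 0.3, 0.1],
                        [0.2, 0.3, 0.5]])

top_k_b = torch.topk(probs_b, k=2, dim=-1)                 # model B's top-2 tokens per position
top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)  # model A's probs for those same tokens

print(top_k_b.indices)   # tensor([[0, 1], [2, 1]])
print(top_k_a_probs)     # tensor([[0.1000, 0.2000], [0.2000, 0.3000]])

# Aligning a 1-D tensor of correct-token probabilities with the tokens they predict:
correct_probs_a = torch.tensor([0.40, 0.15, 0.80])  # probs of tokens 1..3 given the prefix
print(pad_start(correct_probs_a))                   # tensor([-1.0000, 0.4000, 0.1500, 0.8000])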
