Skip to content

Commit

Permalink
Fix log/probs bug in compare_models
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Feb 13, 2024
1 parent ecc86be commit 1b65f8b
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/delphi/eval/compare_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def compare_models(
device = model_a.device
sample_tok = sample_tok.to(device)

probs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
probs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok)
logprobs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
logprobs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok)

probs_a = torch.exp(logprobs_a)
probs_b = torch.exp(logprobs_b)

top_k_b = torch.topk(probs_b, top_k, dim=-1)
top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)
Expand Down

0 comments on commit 1b65f8b

Please sign in to comment.