Skip to content

Commit

Permalink
beartype fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Apr 24, 2024
1 parent 75ea6e7 commit 7799601
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/delphi/eval/calc_model_group_stats.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import numpy as np
import torch
from datasets import Dataset
from jaxtyping import Float


def calc_model_group_stats(
tokenized_corpus_dataset: Dataset,
logprobs_by_dataset: dict[str, list[list[float]]],
logprobs_by_dataset: dict[str, torch.Tensor],
selected_tokens: list[int],
) -> dict[str, dict[str, float]]:
"""
Expand All @@ -31,7 +32,7 @@ def calc_model_group_stats(
for ix_token, token in enumerate(tokens):
if ix_token == 0: # skip the first token, which isn't predicted
continue
logprob = document_lps[ix_token]
logprob = document_lps[ix_token].item()
if token in selected_tokens:
model_logprobs.append(logprob)

Expand Down

0 comments on commit 7799601

Please sign in to comment.