diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py index d95c57f2..faab8a02 100644 --- a/src/delphi/eval/calc_model_group_stats.py +++ b/src/delphi/eval/calc_model_group_stats.py @@ -1,11 +1,12 @@ import numpy as np +import torch from datasets import Dataset from jaxtyping import Float def calc_model_group_stats( tokenized_corpus_dataset: Dataset, - logprobs_by_dataset: dict[str, list[list[float]]], + logprobs_by_dataset: dict[str, torch.Tensor], selected_tokens: list[int], ) -> dict[str, dict[str, float]]: """ @@ -31,7 +32,7 @@ def calc_model_group_stats( for ix_token, token in enumerate(tokens): if ix_token == 0: # skip the first token, which isn't predicted continue - logprob = document_lps[ix_token] + logprob = document_lps[ix_token].item() if token in selected_tokens: model_logprobs.append(logprob)