From 7799601fa8f571a9752540062b84f572209365d6 Mon Sep 17 00:00:00 2001 From: JaiDhyani Date: Wed, 24 Apr 2024 10:09:31 -0700 Subject: [PATCH] beartype fix --- src/delphi/eval/calc_model_group_stats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py index d95c57f2..faab8a02 100644 --- a/src/delphi/eval/calc_model_group_stats.py +++ b/src/delphi/eval/calc_model_group_stats.py @@ -1,11 +1,12 @@ import numpy as np +import torch from datasets import Dataset from jaxtyping import Float def calc_model_group_stats( tokenized_corpus_dataset: Dataset, - logprobs_by_dataset: dict[str, list[list[float]]], + logprobs_by_dataset: dict[str, torch.Tensor], selected_tokens: list[int], ) -> dict[str, dict[str, float]]: """ @@ -31,7 +32,7 @@ def calc_model_group_stats( for ix_token, token in enumerate(tokens): if ix_token == 0: # skip the first token, which isn't predicted continue - logprob = document_lps[ix_token] + logprob = document_lps[ix_token].item() if token in selected_tokens: model_logprobs.append(logprob)