diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py index 66c8d430..1a3ca27d 100644 --- a/src/delphi/eval/vis.py +++ b/src/delphi/eval/vis.py @@ -1,3 +1,4 @@ +import typing import uuid import torch as t @@ -78,7 +79,7 @@ def token_to_html( # TODO: basic unit test for visualizations w/ Selenium def vis_sample_prediction_probs( sample_tok: Int[t.Tensor, "pos"], - correct_probs: Float[t.Tensor, "next_pos"], + correct_probs: Float[t.Tensor, "pos"], top_k_probs: t.return_types.topk, tokenizer: PreTrainedTokenizerBase, ): @@ -92,18 +93,18 @@ def vis_sample_prediction_probs( hover_div_id = f"hover_info_{unique_id}" for i in range(sample_tok.shape[0]): - tok = sample_tok[i].item() + tok = typing.cast(int, sample_tok[i].item()) data = {} if i > 0: correct_prob = correct_probs[i - 1].item() - # ignore type error (tok is a 'Number' by type, but an int in practice) - data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer) # type: ignore + data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer) top_k_probs_tokens = top_k_probs.indices[i - 1] top_k_probs_values = top_k_probs.values[i - 1] for j in range(top_k_probs_tokens.shape[0]): top_tok = top_k_probs_tokens[j].item() + top_tok = typing.cast(int, top_tok) top_prob = top_k_probs_values[j].item() - data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer) # type: ignore + data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer) token_htmls.append( token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(