From 36bccb2a778ddac820e6a4bfef500a52204b944a Mon Sep 17 00:00:00 2001
From: Jett
Date: Mon, 5 Feb 2024 22:15:41 -0800
Subject: [PATCH] logprobs utils & tests (#24)

---
 src/delphi/eval/utils.py | 27 +++++++++++++++++++++++++
 tests/eval/test_utils.py | 43 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)
 create mode 100644 src/delphi/eval/utils.py
 create mode 100644 tests/eval/test_utils.py

diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py
new file mode 100644
index 00000000..1ad7c256
--- /dev/null
+++ b/src/delphi/eval/utils.py
@@ -0,0 +1,27 @@
+from collections.abc import Callable
+
+import torch
+from jaxtyping import Float, Int
+
+
+def get_all_logprobs(
+    model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
+) -> Float[torch.Tensor, "batch seq vocab"]:
+    # batch, seq, vocab
+    logits = model(input_ids).logits
+    return torch.log_softmax(logits, dim=-1)
+
+
+def gather_logprobs(
+    logprobs: Float[torch.Tensor, "batch seq vocab"],
+    tokens: Int[torch.Tensor, "batch seq"],
+) -> Float[torch.Tensor, "batch seq"]:
+    return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
+
+
+def get_next_logprobs(
+    model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
+) -> Float[torch.Tensor, "batch shorter_seq"]:
+    logprobs = get_all_logprobs(model, input_ids[:, :-1])
+    next_tokens = input_ids[:, 1:]
+    return gather_logprobs(logprobs, next_tokens)
diff --git a/tests/eval/test_utils.py b/tests/eval/test_utils.py
new file mode 100644
index 00000000..cefae455
--- /dev/null
+++ b/tests/eval/test_utils.py
@@ -0,0 +1,43 @@
+import torch
+
+from delphi.eval.utils import gather_logprobs
+
+
+def test_gather_logprobs():
+    # vocab size = 3
+    logprobs = torch.tensor(
+        [
+            # batch 0
+            [
+                # seq 0
+                [0.00, 0.01, 0.02],
+                # seq 1
+                [0.10, 0.11, 0.12],
+            ],
+            # batch 1
+            [
+                # seq 0
+                [1.00, 1.01, 1.02],
+                # seq 1
+                [1.10, 1.11, 1.12],
+            ],
+        ]
+    )
+    tokens = torch.tensor(
+        [
+            # batch 0
+            [0, 2],
+            # batch 1
+            [1, 2],
+        ]
+    )
+    expected_output = torch.tensor(
+        [
+            # batch 0
+            [0.00, 0.12],
+            # batch 1
+            [1.01, 1.12],
+        ]
+    )
+    result = gather_logprobs(logprobs, tokens)
+    assert torch.allclose(result, expected_output)
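
Usage sketch (not part of the patch): a minimal example of how the new helpers compose, assuming a Hugging Face causal LM, since the utilities only require that model(input_ids) return an object with a .logits tensor of shape (batch, seq, vocab). The "gpt2" checkpoint name and the sample prompt are arbitrary placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import get_next_logprobs

# hypothetical checkpoint; any causal LM whose forward pass exposes .logits works
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer(["the quick brown fox"], return_tensors="pt").input_ids
with torch.no_grad():
    # log-probability the model assigns to each actual next token, shape (batch, seq - 1)
    next_logprobs = get_next_logprobs(model, input_ids)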