From 49615ecb5fec7d3809f086db8c360c8abd15c3b5 Mon Sep 17 00:00:00 2001
From: qmeng222
Date: Thu, 5 Sep 2024 19:20:21 +0000
Subject: [PATCH] ensure logprobs is being passed as a boolean throughout the
 call chain

---
 nexa/gguf/llama/llama.py         | 2 +-
 nexa/gguf/nexa_inference_text.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/nexa/gguf/llama/llama.py b/nexa/gguf/llama/llama.py
index 10120821..2bf97f4c 100644
--- a/nexa/gguf/llama/llama.py
+++ b/nexa/gguf/llama/llama.py
@@ -732,7 +732,7 @@ def sample(
             apply_grammar=grammar is not None,
         )
 
-        if logprobs:
+        if logprobs is not None and (top_logprobs is not None and top_logprobs > 0):
             sampled_logprobs = self.logits_to_logprobs(logits)
             token_logprob = float(sampled_logprobs[id])
 
diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py
index 1ad754d7..ad1f4806 100644
--- a/nexa/gguf/nexa_inference_text.py
+++ b/nexa/gguf/nexa_inference_text.py
@@ -332,11 +332,13 @@ def run_streamlit(self, model_path: str):
         help="Run the inference in Streamlit UI",
     )
     parser.add_argument(
+        "-lps",
         "--logprobs",
         action="store_true",
         help="Whether to return log probabilities of the output tokens",
    )
     parser.add_argument(
+        "-tlps",
         "--top_logprobs",
         type=int,
         default=None,
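
Note: the sketch below reproduces the patched guard and the new short flags
outside the Nexa codebase, for readers who want to trace the behavior without
checking out the repo. It is a minimal illustration under stated assumptions,
not the project's implementation: NumPy stands in for the sampler's logits
buffer, a greedy argmax replaces the real token sampler, and the
--top_logprobs help string is a hypothetical placeholder because the second
hunk is truncated before it.

    from __future__ import annotations

    import argparse

    import numpy as np


    def logits_to_logprobs(logits: np.ndarray) -> np.ndarray:
        # Numerically stable log-softmax: shift by the max before exponentiating.
        shifted = logits - np.max(logits)
        return shifted - np.log(np.sum(np.exp(shifted)))


    def sample(logits: np.ndarray, logprobs: bool, top_logprobs: int | None):
        token_id = int(np.argmax(logits))  # greedy pick stands in for the sampler
        token_logprob = None
        # The patched guard: since argparse's store_true always yields a strict
        # bool, `logprobs is not None` is always satisfied here, so in practice
        # the `top_logprobs` clause gates the extra log-softmax work.
        if logprobs is not None and (top_logprobs is not None and top_logprobs > 0):
            sampled_logprobs = logits_to_logprobs(logits)
            token_logprob = float(sampled_logprobs[token_id])
        return token_id, token_logprob


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-lps",
            "--logprobs",
            action="store_true",
            help="Whether to return log probabilities of the output tokens",
        )
        parser.add_argument(
            "-tlps",
            "--top_logprobs",
            type=int,
            default=None,
            help="How many top token logprobs to return",  # hypothetical wording
        )
        args = parser.parse_args(["-lps", "-tlps", "3"])  # simulates `-lps -tlps 3`
        logits = np.array([0.1, 2.5, -1.0, 0.7])
        print(sample(logits, args.logprobs, args.top_logprobs))

Because action="store_true" always produces True or False (never None), wiring
the flag this way is what makes logprobs a boolean throughout the call chain,
as the subject line says; the remaining decision about whether to run the extra
log-softmax then falls to top_logprobs.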