ensure logprobs is being passed as a boolean throughout the call chain
qmeng222 committed Sep 5, 2024
1 parent 9ba5684 commit 49615ec
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion nexa/gguf/llama/llama.py
@@ -732,7 +732,7 @@ def sample(
             apply_grammar=grammar is not None,
         )

-        if logprobs:
+        if logprobs is not None and (top_logprobs is not None and top_logprobs > 0):
             sampled_logprobs = self.logits_to_logprobs(logits)
             token_logprob = float(sampled_logprobs[id])
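The changed condition can be read in isolation: log probabilities are computed only when the caller both enabled `logprobs` and requested a positive `top_logprobs` count. A minimal standalone sketch of that gating logic (the helper name `should_compute_logprobs` is hypothetical, not part of the codebase):

```python
def should_compute_logprobs(logprobs, top_logprobs):
    # Mirrors the condition in the diff above: both parameters must be
    # set, and top_logprobs must be a positive count, before the sampler
    # spends time converting logits to log-probabilities.
    return logprobs is not None and (top_logprobs is not None and top_logprobs > 0)
```

Note that this guards against `top_logprobs=0` and `top_logprobs=None` as well as an unset `logprobs`, whereas the old `if logprobs:` check looked at only one of the two parameters.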
2 changes: 2 additions & 0 deletions nexa/gguf/nexa_inference_text.py
@@ -332,11 +332,13 @@ def run_streamlit(self, model_path: str):
         help="Run the inference in Streamlit UI",
     )
     parser.add_argument(
+        "-lps",
         "--logprobs",
         action="store_true",
         help="Whether to return log probabilities of the output tokens",
     )
     parser.add_argument(
+        "-tlps",
         "--top_logprobs",
         type=int,
         default=None,
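The second diff only adds short aliases (`-lps`, `-tlps`) for the two existing flags. A minimal sketch of how the resulting parser behaves (the help text for `--top_logprobs` is truncated in the diff, so it is omitted here):

```python
import argparse

# Hypothetical standalone reconstruction of the two arguments from the diff
# above, to show what the new short aliases parse to.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-lps",
    "--logprobs",
    action="store_true",
    help="Whether to return log probabilities of the output tokens",
)
parser.add_argument(
    "-tlps",
    "--top_logprobs",
    type=int,
    default=None,
)

# The short aliases behave identically to the long flags:
args = parser.parse_args(["-lps", "-tlps", "5"])
```

`action="store_true"` makes `--logprobs` a boolean switch (matching the commit's intent of passing `logprobs` as a boolean), while `--top_logprobs` stays an optional integer defaulting to `None`.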
