diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index ad1f4806..cfff212c 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -54,7 +54,7 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs): self.downloaded_path = local_path self.logprobs = kwargs.get('logprobs', None) - self.top_logprobs = kwargs.get('top_logprobs', 3 if self.logprobs else None) + self.top_logprobs = kwargs.get('top_logprobs', None) if self.downloaded_path is None: self.downloaded_path, run_type = pull_model(self.model_path) @@ -331,17 +331,11 @@ def run_streamlit(self, model_path: str): action="store_true", help="Run the inference in Streamlit UI", ) - parser.add_argument( - "-lps", - "--logprobs", - action="store_true", - help="Whether to return log probabilities of the output tokens", - ) parser.add_argument( "-tlps", "--top_logprobs", type=int, - default=None, + default=None, # -tlps 5 help="Number of most likely tokens to return at each token position", ) args = parser.parse_args() @@ -349,10 +343,6 @@ def run_streamlit(self, model_path: str): model_path = kwargs.pop("model_path") stop_words = kwargs.pop("stop_words", []) - # set top_logprobs to 3 if logprobs is True and top_logprobs is not specified: - if kwargs.get("logprobs") and kwargs.get("top_logprobs") is None: - kwargs["top_logprobs"] = 3 - inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs) if args.streamlit: inference.run_streamlit(model_path)