Skip to content

Commit

Permalink
remove redundant logprobs CLI argument
Browse files Browse the repository at this point in the history
  • Loading branch information
qmeng222 committed Sep 5, 2024
1 parent 1c77636 commit a37f501
Showing 1 changed file with 2 additions and 12 deletions.
14 changes: 2 additions & 12 deletions nexa/gguf/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
self.downloaded_path = local_path

self.logprobs = kwargs.get('logprobs', None)
self.top_logprobs = kwargs.get('top_logprobs', 3 if self.logprobs else None)
self.top_logprobs = kwargs.get('top_logprobs', None)

if self.downloaded_path is None:
self.downloaded_path, run_type = pull_model(self.model_path)
Expand Down Expand Up @@ -331,28 +331,18 @@ def run_streamlit(self, model_path: str):
action="store_true",
help="Run the inference in Streamlit UI",
)
parser.add_argument(
"-lps",
"--logprobs",
action="store_true",
help="Whether to return log probabilities of the output tokens",
)
parser.add_argument(
"-tlps",
"--top_logprobs",
type=int,
default=None,
default=None, # -tlps 5
help="Number of most likely tokens to return at each token position",
)
args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
stop_words = kwargs.pop("stop_words", [])

# set top_logprobs to 3 if logprobs is True and top_logprobs is not specified:
if kwargs.get("logprobs") and kwargs.get("top_logprobs") is None:
kwargs["top_logprobs"] = 3

inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs)
if args.streamlit:
inference.run_streamlit(model_path)
Expand Down

0 comments on commit a37f501

Please sign in to comment.