From cad78483756fce5a38c07239d367c9be078d7580 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:57:06 +0200 Subject: [PATCH] HumanEval: Rename new args to match other scripts --- eval/humaneval.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/eval/humaneval.py b/eval/humaneval.py index 4319019d..9c2a3b26 100644 --- a/eval/humaneval.py +++ b/eval/humaneval.py @@ -1,8 +1,6 @@ from __future__ import annotations -import os -import sys - +import os, sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from human_eval.data import write_jsonl, read_problems from exllamav2 import model_init @@ -25,9 +23,9 @@ parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating") parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling") parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6", default = 0.6) -parser.add_argument("--top_k", type = int, help = "Top-k sampling, default: 50", default = 50) -parser.add_argument("--top_p", type = float, help = "Top-p sampling, default: 0.6", default = 0.6) -parser.add_argument("-trp", "--token_repetition_penalty", type = float, help = "Token repetition penalty, default: 1.0", default = 1.0) +parser.add_argument("-topk", "--top_k", type = int, help = "Top-k sampling, default: 50", default = 50) +parser.add_argument("-topp", "--top_p", type = float, help = "Top-p sampling, default: 0.6", default = 0.6) +parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Token repetition penalty, default: 1.0", default = 1.0) model_init.add_args(parser) args = parser.parse_args() @@ -124,10 +122,10 @@ ) gen_settings = ExLlamaV2Sampler.Settings( - token_repetition_penalty=args.token_repetition_penalty, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p + token_repetition_penalty = args.repetition_penalty, + temperature = args.temperature, + top_k = args.top_k, + top_p = args.top_p ) # Get problems