Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add adapter_name arg #199

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions eval/api_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from utils.reporting import upload_results


def mk_vllm_json(prompt, num_beams, logprobs=False, sql_lora_path=None):
def mk_vllm_json(
prompt, num_beams, logprobs=False, sql_lora_path=None, sql_lora_name=None
):
payload = {
"prompt": prompt,
"n": 1,
Expand All @@ -25,6 +27,7 @@ def mk_vllm_json(prompt, num_beams, logprobs=False, sql_lora_path=None):
"max_tokens": 4000,
"seed": 42,
"sql_lora_path": sql_lora_path,
"sql_lora_name": sql_lora_name,
}
if logprobs:
payload["logprobs"] = 2
Expand Down Expand Up @@ -53,12 +56,15 @@ def process_row(
decimal_points: int,
logprobs: bool = False,
sql_lora_path: Optional[str] = None,
sql_lora_name: Optional[str] = None,
):
start_time = time()
if api_type == "tgi":
json_data = mk_tgi_json(row["prompt"], num_beams)
elif api_type == "vllm":
json_data = mk_vllm_json(row["prompt"], num_beams, logprobs, sql_lora_path)
json_data = mk_vllm_json(
row["prompt"], num_beams, logprobs, sql_lora_path, sql_lora_name
)
else:
# add any custom JSON data here, e.g. for a custom API
json_data = {
Expand Down Expand Up @@ -189,6 +195,7 @@ def run_api_eval(args):
logprobs = args.logprobs
cot_table_alias = args.cot_table_alias
sql_lora_path = args.adapter if args.adapter else None
sql_lora_name = args.adapter_name if args.adapter_name else None
run_name = args.run_name if args.run_name else None
if sql_lora_path:
print("Using LoRA adapter at:", sql_lora_path)
Expand Down Expand Up @@ -258,6 +265,7 @@ def run_api_eval(args):
decimal_points,
logprobs,
sql_lora_path,
sql_lora_name,
)
)

Expand Down
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# model-related parameters
parser.add_argument("-g", "--model_type", type=str, required=True)
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-a", "--adapter", type=str)
parser.add_argument("-a", "--adapter", type=str) # path to adapter
parser.add_argument(
"-an", "--adapter_name", type=str, default=None
) # only for use with production server
parser.add_argument("--api_url", type=str)
parser.add_argument("--api_type", type=str)
# inference-technique-related parameters
Expand Down
Loading