Skip to content

Commit

Permalink
add a run name parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
rishsriv committed Jun 22, 2024
1 parent 5ff0832 commit f1d2bc6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
21 changes: 13 additions & 8 deletions eval/api_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from time import time
import requests
from utils.reporting import upload_results
from uuid import uuid4


def mk_vllm_json(prompt, num_beams, logprobs=False, sql_lora_path=None):
Expand Down Expand Up @@ -189,6 +190,7 @@ def run_api_eval(args):
logprobs = args.logprobs
cot_table_alias = args.cot_table_alias
sql_lora_path = args.adapter if args.adapter else None
run_name = args.run_name if args.run_name else None
if sql_lora_path:
print("Using LoRA adapter at:", sql_lora_path)
if logprobs:
Expand Down Expand Up @@ -302,11 +304,14 @@ def run_api_eval(args):
# with open(prompt_file, "r") as f:
# prompt = f.read()

# if args.upload_url is not None:
# upload_results(
# results=results,
# url=args.upload_url,
# runner_type="api_runner",
# prompt=prompt,
# args=args,
# )
if run_name is None:
run_name = uuid4().hex

if args.upload_url is not None:
upload_results(
results=results,
url=args.upload_url,
runner_type="api_runner",
args=args,
run_name=run_name,
)
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-l", "--logprobs", action="store_true")
parser.add_argument("--upload_url", type=str)
parser.add_argument("--run_name", type=str, required=False)
parser.add_argument(
"-qz", "--quantized", default=False, action=argparse.BooleanOptionalAction
)
Expand Down
14 changes: 2 additions & 12 deletions utils/reporting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import json
import requests
from uuid import uuid4
from datetime import datetime
import os
import hashlib


# get the GPU name this is running on
Expand Down Expand Up @@ -81,6 +79,7 @@ def num_gpus():
def upload_results(
results: list,
url: str,
run_name: str,
runner_type: str,
prompt: str,
args: dict,
Expand All @@ -94,26 +93,17 @@ def upload_results(
# Create a unique id for the request
run_id = uuid4().hex

# Create a unique id for the prompt, based on a hash of the prompt
prompt_id = hashlib.md5(prompt.encode()).hexdigest()

# Create a dictionary with the request id and the results
data = {
"run_id": run_id,
"results": results,
"timestamp": datetime.now().isoformat(),
"runner_type": runner_type,
"prompt": prompt,
"prompt_id": prompt_id,
"model": args.model,
"num_beams": args.num_beams,
"db_type": args.db_type,
"gpu_name": get_gpu_name(),
"gpu_memory": get_gpu_memory(),
"gpu_driver_version": get_gpu_driver_version(),
"gpu_cuda_version": get_gpu_cuda_version(),
"num_gpus": num_gpus(),
"run_args": vars(args),
"run_name": run_name,
}
# Send the data to the server
response = requests.post(url, json=data)
Expand Down

0 comments on commit f1d2bc6

Please sign in to comment.