forked from defog-ai/sql-eval
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
143 lines (123 loc) · 5.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
import os
# Maps every supported --model_type to the (module, function) pair that runs it.
# The modules are imported lazily inside run_eval so that selecting one backend
# does not import any of the others.
MODEL_RUNNERS = {
    "oa": ("eval.openai_runner", "run_openai_eval"),
    "anthropic": ("eval.anthropic_runner", "run_anthropic_eval"),
    "vllm": ("eval.vllm_runner", "run_vllm_eval"),
    "hf": ("eval.hf_runner", "run_hf_eval"),
    "api": ("eval.api_runner", "run_api_eval"),
    "llama_cpp": ("eval.llama_cpp_runner", "run_llama_cpp_eval"),
    "mlx": ("eval.mlx_runner", "run_mlx_eval"),
    "gemini": ("eval.gemini_runner", "run_gemini_eval"),
    "mistral": ("eval.mistral_runner", "run_mistral_eval"),
    "bedrock": ("eval.bedrock_runner", "run_bedrock_eval"),
    "together": ("eval.together_runner", "run_together_eval"),
}

# Model name used when --model is omitted, for the model types that have one.
DEFAULT_MODELS = {
    "oa": "gpt-3.5-turbo-0613",
    "anthropic": "claude-2",
}

# Values accepted for --api_type when --model_type is "api".
API_TYPES = ("vllm", "tgi")


def build_parser():
    """Return the argument parser describing the script's command line."""
    parser = argparse.ArgumentParser()
    # data-related parameters
    parser.add_argument("-q", "--questions_file", nargs="+", type=str, required=True)
    parser.add_argument("-n", "--num_questions", type=int, default=None)
    parser.add_argument("-db", "--db_type", type=str, required=True)
    parser.add_argument("-d", "--use_private_data", action="store_true")
    parser.add_argument("-dp", "--decimal_points", type=int, default=None)
    # model-related parameters
    parser.add_argument("-g", "--model_type", type=str, required=True)
    parser.add_argument("-m", "--model", type=str)
    parser.add_argument("-a", "--adapter", type=str)  # path to adapter
    parser.add_argument(
        "-an", "--adapter_name", type=str, default=None
    )  # only for use with production server
    parser.add_argument("--api_url", type=str)
    parser.add_argument("--api_type", type=str)
    # inference-technique-related parameters
    parser.add_argument("-f", "--prompt_file", nargs="+", type=str, required=True)
    parser.add_argument("-b", "--num_beams", type=int, default=1)
    parser.add_argument(
        "-bs", "--batch_size", type=int, default=4
    )  # batch size, only relevant for the hf runner
    parser.add_argument("-c", "--num_columns", type=int, default=0)
    parser.add_argument("-s", "--shuffle_metadata", action="store_true")
    parser.add_argument("-k", "--k_shot", action="store_true")
    parser.add_argument(
        "--cot_table_alias", type=str, choices=["instruct", "pregen", ""], default=""
    )
    # execution-related parameters
    parser.add_argument("-o", "--output_file", nargs="+", type=str, required=True)
    parser.add_argument("-p", "--parallel_threads", type=int, default=5)
    parser.add_argument("-t", "--timeout_gen", type=float, default=30.0)
    parser.add_argument("-u", "--timeout_exec", type=float, default=10.0)
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-l", "--logprobs", action="store_true")
    parser.add_argument("--upload_url", type=str)
    parser.add_argument("--run_name", type=str, required=False)
    parser.add_argument(
        "-qz", "--quantized", default=False, action=argparse.BooleanOptionalAction
    )
    return parser


def resolve_args(args):
    """Fill in defaults and align the per-run file lists, in place.

    After this call ``args.questions_file``, ``args.prompt_file`` and
    ``args.output_file`` are lists of equal length. When several output files
    are given, a single questions file or a single prompt file is repeated to
    match the longer of the two.

    Args:
        args: The parsed ``argparse.Namespace``; it is mutated.

    Returns:
        The same ``args`` object, for convenience.

    Raises:
        ValueError: If the three file lists cannot be brought to equal length.
    """
    # --questions_file is currently required, so this only triggers for
    # programmatic callers. The default must be a list: everything below
    # iterates over it and takes its len().
    if args.questions_file is None:
        args.questions_file = [f"data/questions_gen_{args.db_type}.csv"]

    # check that questions_file matches db_type
    for questions_file in args.questions_file:
        if args.db_type not in questions_file and questions_file != "data/idk.csv":
            print(
                f"WARNING: Check that questions_file {questions_file} is compatible with db_type {args.db_type}"
            )

    if args.upload_url is None:
        args.upload_url = os.environ.get("SQL_EVAL_UPLOAD_URL")

    # Broadcast the single-valued list (questions or prompts) against the
    # multi-valued one, but only when multiple output files were requested.
    num_questions = len(args.questions_file)
    num_prompts = len(args.prompt_file)
    if len(args.output_file) > 1:
        if num_questions > 1 and num_prompts == 1:
            args.prompt_file = args.prompt_file * num_questions
        elif num_questions == 1 and num_prompts > 1:
            args.questions_file = args.questions_file * num_prompts

    if not (len(args.questions_file) == len(args.prompt_file) == len(args.output_file)):
        raise ValueError(
            "If args.output_file > 1, then at least 1 of args.prompt_file or args.questions_file must be > 1 and match lengths. "
            f"Obtained lengths: args.questions_file={len(args.questions_file)}, args.prompt_file={len(args.prompt_file)}, args.output_file={len(args.output_file)}"
        )
    return args


def run_eval(args):
    """Validate model-specific options and invoke the runner for ``args.model_type``.

    All validation happens before the runner module is imported, so an invalid
    configuration fails without importing any backend.

    Args:
        args: The parsed and resolved ``argparse.Namespace``. ``args.model`` is
            set to a default for model types listed in ``DEFAULT_MODELS`` when
            it was not supplied.

    Raises:
        ValueError: If the model type is unknown, if vLLM is requested on
            macOS, or if the ``api`` model type is missing/has invalid
            ``api_url`` / ``api_type``.
    """
    # Imported here rather than at file level so this block is self-contained.
    import importlib
    import platform

    model_type = args.model_type
    if model_type not in MODEL_RUNNERS:
        # Built from the dispatch table so the list cannot drift out of date.
        supported = ", ".join(f"'{name}'" for name in MODEL_RUNNERS)
        raise ValueError(
            f"Invalid model type: {model_type}. Model type must be one of: {supported}"
        )

    if model_type == "vllm" and platform.system() == "Darwin":
        raise ValueError(
            "vLLM is not supported on macOS. Please run on another OS supporting CUDA."
        )

    if model_type == "api":
        # Explicit raises instead of assert: asserts are removed under `python -O`.
        if args.api_url is None:
            raise ValueError("api_url must be provided for api model")
        if args.api_type is None:
            raise ValueError("api_type must be provided for api model")
        if args.api_type not in API_TYPES:
            raise ValueError("api_type must be one of 'vllm', 'tgi'")

    if args.model is None and model_type in DEFAULT_MODELS:
        args.model = DEFAULT_MODELS[model_type]

    module_name, function_name = MODEL_RUNNERS[model_type]
    runner = getattr(importlib.import_module(module_name), function_name)
    runner(args)


def main(argv=None):
    """Parse ``argv`` (default: ``sys.argv[1:]``), resolve arguments and run the eval."""
    args = build_parser().parse_args(argv)
    resolve_args(args)
    run_eval(args)


if __name__ == "__main__":
    main()