diff --git a/README.md b/README.md
index a72864d..1d739fa 100644
--- a/README.md
+++ b/README.md
@@ -381,6 +381,11 @@ python -W ignore main.py \
 
 ### Bedrock
 
+Before running this, you will need to export the following environment variables for the boto3 client to work:
+- `AWS_ACCESS_KEY_ID`
+- `AWS_SECRET_ACCESS_KEY`
+- `AWS_DEFAULT_REGION`
+
 ```bash
 python3 main.py \
   -db postgres \
@@ -388,7 +393,6 @@ python3 main.py \
   -o results/bedrock_llama_70b_basic.csv results/bedrock_llama_70b_advanced.csv results/bedrock_llama_70b_v1.csv \
   -g bedrock \
   -f prompts/prompt_cot_postgres.md \
-  --cot_table_alias prealias \
   -m meta.llama3-70b-instruct-v1:0 \
   -c 0 \
   -p 10
@@ -405,7 +409,6 @@ python3 main.py \
   -o results/together_llama_70b_basic.csv results/together_llama_70b_advanced.csv results/together_llama_70b_v1.csv \
   -g together \
   -f prompts/prompt_together.json \
-  --cot_table_alias prealias \
   -m "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \
   -c 0 \
   -p 10
@@ -437,14 +440,14 @@ You can use the following flags in the command line to change the configurations
 
 ### Inference-technique-related parameters
 
-| CLI Flags              | Description |
-| ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- | --- |
-| -f, --prompt_file      | Markdown file with the prompt used for query generation. You can pass in a list of prompts to test sequentially without reloading the script. |
-| -b, --num_beams        | Indicates the number of beams you want to use for beam search at inference. Only available for `hf_runner`, `vllm_runner`, and `api_runner`. |
-| -c, --num_columns      | Number of columns, default 20. To not prune the columns, set it to 0. |
-| -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table but does not shift columns between tables (to preserve the structure of the database). |
-| -k, --k_shot           | Used when you want to include k-shot examples in your prompt. Make sure that the column 'k_shot_prompt' exists in your questions_file. |
-| --cot_table_alias      | Used when you want to include chain-of-thought instructions before the actual sql generation. Allowed values are `instruct`, `prealias` and `pregen`. If using `instruct` or `prealias`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. `instruct` will get your model generate the chain-of-thought table aliases, while `prealias` would already generate the aliases in the prompt. | |
+| CLI Flags              | Description |
+| ---------------------- | ----------- |
+| -f, --prompt_file      | Markdown file with the prompt used for query generation. You can pass in a list of prompts to test sequentially without reloading the script. |
+| -b, --num_beams        | Indicates the number of beams you want to use for beam search at inference. Only available for `hf_runner`, `vllm_runner`, and `api_runner`. |
+| -c, --num_columns      | Number of columns, default 20. To not prune the columns, set it to 0. |
+| -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table but does not shift columns between tables (to preserve the structure of the database). |
+| -k, --k_shot           | Used when you want to include k-shot examples in your prompt. Make sure that the column 'k_shot_prompt' exists in your questions_file. |
+| --cot_table_alias      | (Experimental) Used when you want to include chain-of-thought instructions before the actual SQL generation. The only allowed value is `instruct`. If using `instruct`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. `instruct` will get your model to generate the chain-of-thought table aliases. |
 
 ### Execution-related parameters
 
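As a quick sanity check before launching the Bedrock run documented above, one can verify the three variables up front. The check below is illustrative and not part of this patch; the variable names come from the README hunk, and boto3 itself reads them from the environment automatically:

```python
import os

# The three variables the README hunk above documents; boto3 picks them
# up from the environment on its own, so we only verify they are set.
REQUIRED = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")

missing = [name for name in REQUIRED if not os.environ.get(name)]
if missing:
    raise RuntimeError(f"Set these before running the Bedrock eval: {missing}")
```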
diff --git a/eval/api_runner.py b/eval/api_runner.py
index d92a8fe..656641e 100644
--- a/eval/api_runner.py
+++ b/eval/api_runner.py
@@ -268,6 +268,7 @@ def run_api_eval(args):
                 public_data,
                 args.num_columns,
                 args.shuffle_metadata,
+                row["table_aliases"],
             ),
             axis=1,
         )
diff --git a/main.py b/main.py
index 60fb3c9..1733659 100644
--- a/main.py
+++ b/main.py
@@ -27,7 +27,9 @@
     parser.add_argument("-c", "--num_columns", type=int, default=0)
     parser.add_argument("-s", "--shuffle_metadata", action="store_true")
     parser.add_argument("-k", "--k_shot", action="store_true")
-    parser.add_argument("--cot_table_alias", type=str)
+    parser.add_argument(
+        "--cot_table_alias", type=str, choices=["instruct", "pregen", ""], default=""
+    )
     # execution-related parameters
     parser.add_argument("-o", "--output_file", nargs="+", type=str, required=True)
     parser.add_argument("-p", "--parallel_threads", type=int, default=5)
diff --git a/prompts/prompt_cot.md b/prompts/prompt_cot.md
index d339ab6..c52c180 100644
--- a/prompts/prompt_cot.md
+++ b/prompts/prompt_cot.md
@@ -7,7 +7,7 @@ Generate a {db_type} query to answer this question: `{user_question}`
 
 DDL statements:
 {table_metadata_string}
 
-{cot_instructions}Generate a valid {db_type} query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+{table_aliases}Generate a valid {db_type} query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 I will reflect on the user's request before answering the question.
diff --git a/prompts/prompt_cot_postgres.md b/prompts/prompt_cot_postgres.md
index e7492be..c388700 100644
--- a/prompts/prompt_cot_postgres.md
+++ b/prompts/prompt_cot_postgres.md
@@ -6,9 +6,9 @@ Generate a SQL query to answer this question: `{user_question}`
 
 {instructions}
 DDL statements:
 {table_metadata_string}
-{join_hints}
+{join_str}
 
-{cot_instructions}Generate a valid SQL query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+{table_aliases}Generate a valid SQL query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 I will reflect on the user's request before answering the question.
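To see what the placeholder renames above amount to in practice ({cot_instructions} becomes {table_aliases}, {join_hints} becomes {join_str}): the gen_prompt.py hunk below fills these fields via keyword arguments, evidently through str.format. A cut-down sketch follows; the template string here is invented for illustration and much shorter than the real prompt files:

```python
# Illustrative only: a trimmed-down template using the renamed placeholders
# from the hunks above; the real prompt files contain more fields.
template = (
    "DDL statements:\n{table_metadata_string}\n{join_str}\n\n"
    "{table_aliases}Generate a valid SQL query that best answers the "
    "question `{user_question}`."
)

print(
    template.format(
        table_metadata_string="CREATE TABLE users (id int, name text);",
        join_str="",
        table_aliases="-- users AS u\n",
        user_question="How many users are there?",
    )
)
```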
diff --git a/prompts/prompt_cot_sqlite.md b/prompts/prompt_cot_sqlite.md
index 8edeba2..c37c89d 100644
--- a/prompts/prompt_cot_sqlite.md
+++ b/prompts/prompt_cot_sqlite.md
@@ -7,7 +7,7 @@ Generate a {db_type} query to answer this question: `{user_question}`
 
 DDL statements:
 {table_metadata_string}
 
-{cot_instructions}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+{table_aliases}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 I was asked to generate a SQL query for this question: `{user_question}`
diff --git a/utils/gen_prompt.py b/utils/gen_prompt.py
index a2cd02e..27d692c 100644
--- a/utils/gen_prompt.py
+++ b/utils/gen_prompt.py
@@ -287,10 +287,8 @@ def generate_prompt(
         query_1=query_1,
         cot_instructions=cot_instructions,
         instruction_reflections=instruction_reflections,
-        join_hints=join_str,
+        table_aliases=table_aliases,
+        join_str=join_str,
         pruned_join_hints=pruned_join_str,
     )
-    if cot_pregen:
-        table_aliases = generate_aliases(table_names)
-        prompt = prompt + table_aliases
     return prompt
diff --git a/utils/questions.py b/utils/questions.py
index 4f240ff..414a89d 100644
--- a/utils/questions.py
+++ b/utils/questions.py
@@ -93,6 +93,11 @@ def prepare_questions_df(
     else:
         question_query_df["table_metadata_string"] = ""
 
+    # get table_aliases
+    question_query_df["table_aliases"] = question_query_df["db_name"].apply(
+        get_table_aliases
+    )
+
     # get prev_invalid_sql if applicable
     if "prev_invalid_sql" in question_query_df.columns:
         question_query_df["prev_invalid_sql"] = question_query_df[
@@ -127,25 +132,14 @@
     else:
         question_query_df["query_1"] = ""
 
-    # add all cot instructions to the `cot_instructions` column
+    # add all cot instructions to the respective columns
+    question_query_df["cot_instructions"] = ""
+    question_query_df["cot_pregen"] = False
     if cot_table_alias == "instruct":
         question_query_df["cot_instructions"] = (
             "List the table aliases for each table as comments, starting with the most relevant tables to the question."
         )
-    elif cot_table_alias == "prealias":
-        question_query_df["cot_instructions"] = question_query_df["db_name"].apply(
-            get_table_aliases
-        )
-        question_query_df["table_aliases"] = question_query_df["db_name"].apply(
-            get_table_aliases
-        )
-    else:
-        question_query_df["cot_instructions"] = ""
-        question_query_df["table_aliases"] = ""
-
-    if cot_table_alias == "pregen":
+    elif cot_table_alias == "pregen":
         question_query_df["cot_pregen"] = True
-    else:
-        question_query_df["cot_pregen"] = False
 
     return question_query_df
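For context on what the new `table_aliases` column carries: `prepare_questions_df` now always derives it from `db_name` via `get_table_aliases`, rather than only under the removed `prealias` mode. Below is a hedged sketch of the kind of alias comment block such a helper plausibly returns; the function is a hypothetical stand-in, not the repo's implementation:

```python
# Hypothetical stand-in for get_table_aliases: map each table to a short
# alias and render the mapping as SQL comments for the prompt.
def make_alias_block(table_names: list[str]) -> str:
    aliases, seen = [], set()
    for name in table_names:
        alias, n = name[0].lower(), 1
        while alias in seen:  # disambiguate clashing first letters
            n += 1
            alias = name[0].lower() + str(n)
        seen.add(alias)
        aliases.append(f"-- {name} AS {alias}")
    return "\n".join(aliases) + "\n"

print(make_alias_block(["users", "orders", "order_items"]))
# -- users AS u
# -- orders AS o
# -- order_items AS o2
```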