# Clean up cot_instructions placeholder #216

Merged · 2 commits · Aug 30, 2024
23 changes: 13 additions & 10 deletions README.md
@@ -381,14 +381,18 @@ python -W ignore main.py \

### Bedrock

Before running this, you will need to export the following environment variables for the boto3 client to work:
- `AWS_ACCESS_KEY_ID`
- `AWS_SECRET_ACCESS_KEY`
- `AWS_DEFAULT_REGION`
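For example, in a POSIX shell (the key values below are placeholders, not real credentials):

```shell
# Placeholder credentials — substitute your own values
export AWS_ACCESS_KEY_ID="AKIAEXAMPLE"
export AWS_SECRET_ACCESS_KEY="example-secret-key"
export AWS_DEFAULT_REGION="us-west-2"
```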

```bash
python3 main.py \
-db postgres \
-q data/instruct_basic_postgres.csv data/instruct_advanced_postgres.csv data/questions_gen_postgres.csv \
-o results/bedrock_llama_70b_basic.csv results/bedrock_llama_70b_advanced.csv results/bedrock_llama_70b_v1.csv \
-g bedrock \
-f prompts/prompt_cot_postgres.md \
-m meta.llama3-70b-instruct-v1:0 \
-c 0 \
-p 10
```

@@ -405,7 +409,6 @@ python3 main.py \

```bash
-o results/together_llama_70b_basic.csv results/together_llama_70b_advanced.csv results/together_llama_70b_v1.csv \
-g together \
-f prompts/prompt_together.json \
-m "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \
-c 0 \
-p 10
```
@@ -437,14 +440,14 @@

You can use the following flags in the command line to change the configurations

### Inference-technique-related parameters

| CLI Flags | Description |
| ---------------------- | ----------- |
| -f, --prompt_file | Markdown file with the prompt used for query generation. You can pass in a list of prompts to test sequentially without reloading the script. |
| -b, --num_beams | Indicates the number of beams you want to use for beam search at inference. Only available for `hf_runner`, `vllm_runner`, and `api_runner`. |
| -c, --num_columns | Number of columns, default 20. To not prune the columns, set it to 0. |
| -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table, but does not shift columns between tables (to preserve the structure of the database). |
| -k, --k_shot | Used when you want to include k-shot examples in your prompt. Make sure that the column `k_shot_prompt` exists in your questions_file. |
| --cot_table_alias | (Experimental) Used when you want to include chain-of-thought instructions before the actual SQL generation. The only allowed value is `instruct`. If using `instruct`, make sure that the placeholder `{cot_instructions}` exists in your prompt file; `instruct` will have your model generate the chain-of-thought table aliases. |
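As a rough sketch of how the `--cot_table_alias instruct` flag interacts with the `{cot_instructions}` placeholder (the template below is a simplified stand-in for the real prompt files, not their actual contents):

```python
# Simplified stand-in for a prompt file containing the placeholder
prompt_template = (
    "DDL statements:\n{table_metadata_string}\n\n"
    "{cot_instructions}Generate a valid SQL query."
)

# Instruction injected when --cot_table_alias is set to `instruct`
cot_instructions = (
    "List the table aliases for each table as comments, "
    "starting with the most relevant tables to the question."
)

prompt = prompt_template.format(
    table_metadata_string="CREATE TABLE users (id INT);",
    cot_instructions=cot_instructions,
)
```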

### Execution-related parameters

1 change: 1 addition & 0 deletions eval/api_runner.py
@@ -268,6 +268,7 @@ def run_api_eval(args):

```python
                public_data,
                args.num_columns,
                args.shuffle_metadata,
                row["table_aliases"],
            ),
            axis=1,
        )
```
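The change above threads each row's `table_aliases` value into prompt generation via `df.apply(..., axis=1)`; a minimal sketch, with a hypothetical stand-in for the real `generate_prompt`:

```python
import pandas as pd

# Hypothetical stand-in for the real generate_prompt signature
def generate_prompt(prompt_file, question, table_aliases):
    return f"{table_aliases}{question}"

df = pd.DataFrame(
    {"question": ["How many users signed up?"], "table_aliases": ["-- users AS u\n"]}
)
# axis=1 applies the function row-wise, so each row's aliases reach its prompt
df["prompt"] = df.apply(
    lambda row: generate_prompt("prompt.md", row["question"], row["table_aliases"]),
    axis=1,
)
```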
4 changes: 3 additions & 1 deletion main.py
@@ -27,7 +27,9 @@

```python
parser.add_argument("-c", "--num_columns", type=int, default=0)
parser.add_argument("-s", "--shuffle_metadata", action="store_true")
parser.add_argument("-k", "--k_shot", action="store_true")
parser.add_argument(
    "--cot_table_alias", type=str, choices=["instruct", "pregen", ""], default=""
)
# execution-related parameters
parser.add_argument("-o", "--output_file", nargs="+", type=str, required=True)
parser.add_argument("-p", "--parallel_threads", type=int, default=5)
```
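Adding `choices` makes argparse reject the removed `prealias` value at parse time instead of failing later; a self-contained sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cot_table_alias", type=str, choices=["instruct", "pregen", ""], default=""
)

# Valid values parse normally
args = parser.parse_args(["--cot_table_alias", "instruct"])

# Removed values such as "prealias" now fail fast with SystemExit
try:
    parser.parse_args(["--cot_table_alias", "prealias"])
    rejected = False
except SystemExit:
    rejected = True
```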
2 changes: 1 addition & 1 deletion prompts/prompt_cot.md
@@ -7,7 +7,7 @@ Generate a {db_type} query to answer this question: `{user_question}`

```
DDL statements:
{table_metadata_string}

{table_aliases}Generate a valid {db_type} query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will reflect on the user's request before answering the question.
```
4 changes: 2 additions & 2 deletions prompts/prompt_cot_postgres.md
@@ -6,9 +6,9 @@ Generate a SQL query to answer this question: `{user_question}`

```
{instructions}
DDL statements:
{table_metadata_string}
{join_str}

{table_aliases}Generate a valid SQL query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will reflect on the user's request before answering the question.
```
2 changes: 1 addition & 1 deletion prompts/prompt_cot_sqlite.md
@@ -7,7 +7,7 @@ Generate a {db_type} query to answer this question: `{user_question}`

```
DDL statements:
{table_metadata_string}

{table_aliases}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I was asked to generate a SQL query for this question: `{user_question}`
```
6 changes: 2 additions & 4 deletions utils/gen_prompt.py
@@ -287,10 +287,8 @@ def generate_prompt(

```python
        query_1=query_1,
        cot_instructions=cot_instructions,
        instruction_reflections=instruction_reflections,
        table_aliases=table_aliases,
        join_str=join_str,
    )
    if cot_pregen:
        table_aliases = generate_aliases(table_names)
        prompt = prompt + table_aliases
    return prompt
```
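The kwarg cleanup in `gen_prompt.py` matters because `str.format` raises `KeyError` for any placeholder without a matching kwarg (while silently ignoring extra kwargs), so a renamed placeholder must be paired with a renamed kwarg. A small illustration, with a simplified stand-in template:

```python
# Simplified stand-in for the updated prompt templates
template = "{table_aliases}Generate a valid SQL query."

# The renamed kwarg fills the renamed placeholder; unused kwargs are ignored
prompt = template.format(table_aliases="-- users AS u\n", join_str="")

# Formatting with only the old kwarg name raises KeyError('table_aliases')
try:
    template.format(cot_instructions="...")
    missing_key = False
except KeyError:
    missing_key = True
```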
24 changes: 9 additions & 15 deletions utils/questions.py
@@ -93,6 +93,11 @@ def prepare_questions_df(

```python
    else:
        question_query_df["table_metadata_string"] = ""

    # get table_aliases
    question_query_df["table_aliases"] = question_query_df["db_name"].apply(
        get_table_aliases
    )

    # get prev_invalid_sql if applicable
    if "prev_invalid_sql" in question_query_df.columns:
        question_query_df["prev_invalid_sql"] = question_query_df[
```
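Computing `table_aliases` unconditionally means every runner can rely on the column existing, whatever the `--cot_table_alias` mode. A sketch of the `.apply` pattern, with a hypothetical stand-in for `get_table_aliases`:

```python
import pandas as pd

# Hypothetical stand-in: the real get_table_aliases derives alias comments
# from the database schema for the given db_name
def get_table_aliases(db_name: str) -> str:
    aliases = {"restaurants": "-- restaurant AS r\n-- geographic AS g\n"}
    return aliases.get(db_name, "")

df = pd.DataFrame({"db_name": ["restaurants", "unknown_db"]})
# Applied to every row, so the column always exists (empty string if no aliases)
df["table_aliases"] = df["db_name"].apply(get_table_aliases)
```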
@@ -127,25 +132,14 @@

```python
    else:
        question_query_df["query_1"] = ""

    # add all cot instructions to the respective columns
    question_query_df["cot_instructions"] = ""
    question_query_df["cot_pregen"] = False
    if cot_table_alias == "instruct":
        question_query_df["cot_instructions"] = (
            "List the table aliases for each table as comments, starting with the most relevant tables to the question."
        )
    elif cot_table_alias == "pregen":
        question_query_df["cot_pregen"] = True

    return question_query_df
```
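The simplified branching — defaults assigned once, then a single `if`/`elif` — can be sketched as a standalone function (the column names mirror the diff; the wrapper itself is hypothetical):

```python
import pandas as pd

def set_cot_columns(df: pd.DataFrame, cot_table_alias: str) -> pd.DataFrame:
    # Defaults first, so both columns always exist regardless of mode
    df["cot_instructions"] = ""
    df["cot_pregen"] = False
    if cot_table_alias == "instruct":
        df["cot_instructions"] = (
            "List the table aliases for each table as comments, "
            "starting with the most relevant tables to the question."
        )
    elif cot_table_alias == "pregen":
        df["cot_pregen"] = True
    return df
```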