Skip to content

Commit

Permalink
refactor: Convert PostgreSQL DDL to dialect-specific DDL in query gen…
Browse files Browse the repository at this point in the history
…erators
  • Loading branch information
rishsriv committed Jul 9, 2024
1 parent 0f940e8 commit 54634a7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 45 deletions.
30 changes: 7 additions & 23 deletions query_generators/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,11 @@
from query_generators.query_generator import QueryGenerator
from utils.pruning import prune_metadata_str
from utils.gen_prompt import to_prompt_schema
from utils.dialects import (
ddl_to_bigquery,
ddl_to_mysql,
ddl_to_sqlite,
ddl_to_tsql,
)
from utils.dialects import convert_postgres_ddl_to_dialect

anthropic = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))


def convert_ddl(postgres_ddl: str, to_dialect: str, db_name: str):
if to_dialect == "postgres":
return postgres_ddl
elif to_dialect == "bigquery":
new_ddl, _ = ddl_to_bigquery(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "mysql":
new_ddl, _ = ddl_to_mysql(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "sqlite":
new_ddl, _ = ddl_to_sqlite(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "tsql":
new_ddl, _ = ddl_to_tsql(postgres_ddl, "postgres", db_name, 42)
else:
raise ValueError(f"Unsupported dialect {to_dialect}")
return new_ddl


class AnthropicQueryGenerator(QueryGenerator):
"""
Query generator that uses Anthropic's models
Expand Down Expand Up @@ -135,11 +114,16 @@ def generate_query(
columns_to_keep,
shuffle,
)
pruned_metadata_ddl = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_ddl,
to_dialect=self.db_type,
db_name=self.db_name,
)
pruned_metadata_str = pruned_metadata_ddl + join_str
elif columns_to_keep == 0:
md = dbs[self.db_name]["table_metadata"]
pruned_metadata_str = to_prompt_schema(md, shuffle)
pruned_metadata_str = convert_ddl(
pruned_metadata_str = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_str,
to_dialect=self.db_type,
db_name=self.db_name,
Expand Down
25 changes: 3 additions & 22 deletions query_generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,12 @@
from utils.pruning import prune_metadata_str
from utils.gen_prompt import to_prompt_schema
from utils.dialects import (
ddl_to_bigquery,
ddl_to_mysql,
ddl_to_sqlite,
ddl_to_tsql,
convert_postgres_ddl_to_dialect,
)

openai = OpenAI()


def convert_ddl(postgres_ddl: str, to_dialect: str, db_name: str):
if to_dialect == "postgres":
return postgres_ddl
elif to_dialect == "bigquery":
new_ddl, _ = ddl_to_bigquery(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "mysql":
new_ddl, _ = ddl_to_mysql(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "sqlite":
new_ddl, _ = ddl_to_sqlite(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "tsql":
new_ddl, _ = ddl_to_tsql(postgres_ddl, "postgres", db_name, 42)
else:
raise ValueError(f"Unsupported dialect {to_dialect}")
return new_ddl


class OpenAIQueryGenerator(QueryGenerator):
"""
Query generator that uses OpenAI's models
Expand Down Expand Up @@ -168,7 +149,7 @@ def generate_query(
columns_to_keep,
shuffle,
)
table_metadata_ddl = convert_ddl(
table_metadata_ddl = convert_postgres_ddl_to_dialect(
postgres_ddl=table_metadata_ddl,
to_dialect=self.db_type,
db_name=self.db_name,
Expand All @@ -177,7 +158,7 @@ def generate_query(
elif columns_to_keep == 0:
md = dbs[self.db_name]["table_metadata"]
table_metadata_ddl = to_prompt_schema(md, shuffle)
table_metadata_ddl = convert_ddl(
table_metadata_ddl = convert_postgres_ddl_to_dialect(
postgres_ddl=table_metadata_ddl,
to_dialect=self.db_type,
db_name=self.db_name,
Expand Down
22 changes: 22 additions & 0 deletions utils/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,3 +1131,25 @@ def test_valid_md_tsql_concurr(df, sql_list_col, table_metadata_col):
results[index] = future.result()

return results


### General conversion function
def convert_postgres_ddl_to_dialect(postgres_ddl: str, to_dialect: str, db_name: str):
"""
This function converts a ddl from postgres to another dialect.
We have a separate function for this since the default defog_data DDLS
are for Postgres, and using this means less code when converting.
"""
if to_dialect == "postgres":
return postgres_ddl
elif to_dialect == "bigquery":
new_ddl, _ = ddl_to_bigquery(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "mysql":
new_ddl, _ = ddl_to_mysql(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "sqlite":
new_ddl, _ = ddl_to_sqlite(postgres_ddl, "postgres", db_name, 42)
elif to_dialect == "tsql":
new_ddl, _ = ddl_to_tsql(postgres_ddl, "postgres", db_name, 42)
else:
raise ValueError(f"Unsupported dialect {to_dialect}")
return new_ddl

0 comments on commit 54634a7

Please sign in to comment.