diff --git a/requirements.txt b/requirements.txt
index a7e8ff1..d654258 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,8 +3,7 @@ Flask==2.3.2
 Flask-Cors==4.0.0
 Flask-RESTful==0.3.10
 requests==2.31.0
-tiktoken==0.4.0
-psycopg2-binary==2.9.7 # you can also install from source if it works
+tiktoken>=0.4.0
 pglast==5.3
-litellm==1.34.34
-platformdirs>=4.0.0
\ No newline at end of file
+litellm>=1.34.34
+platformdirs>=4.0.0
diff --git a/setup.py b/setup.py
index e943fd0..22cf231 100644
--- a/setup.py
+++ b/setup.py
@@ -13,22 +13,21 @@
 
 # Define your dependencies
 install_requires = [
-    'Jinja2==3.1.2',
-    'Flask==2.3.2',
-    'Flask-Cors==4.0.0',
-    'Flask-RESTful==0.3.10',
-    'requests==2.31.0',
-    'tiktoken==0.4.0',
-    'psycopg2-binary==2.9.7',
-    'pglast==5.3',
-    'litellm==1.34.34',
-    'platformdirs>=4.0.0',
-    'sqlparse~=0.5.0'
+    "Jinja2==3.1.2",
+    "Flask==2.3.2",
+    "Flask-Cors==4.0.0",
+    "Flask-RESTful==0.3.10",
+    "requests==2.31.0",
+    "tiktoken>=0.4.0",
+    "pglast>=6.10",
+    "litellm>=1.34.34",
+    "platformdirs>=4.0.0",
+    "sqlparse~=0.5.0",
 ]
 
 install_dev_requires = [
-    'spacy==3.6.0',
-    'FlagEmbedding~=1.2.5',
+    "spacy==3.6.0",
+    "FlagEmbedding~=1.2.5",
 ]
 
 # Additional package information
@@ -46,19 +45,15 @@
     name=name,
     version=version,
     description=description,
-    long_description=open('README.md').read(),
-    long_description_content_type='text/markdown',
+    long_description=open("README.md").read(),
+    long_description_content_type="text/markdown",
     author=author,
     author_email=author_email,
     packages=packages,
     package_dir={"": "src"},
    install_requires=install_requires,
-    extra_requires={
-        "dev": install_dev_requires
-    },
+    extras_require={"dev": install_dev_requires},
     url=url,
     classifiers=classifiers,
-    package_data={
-        "": ["*.prompt"]
-    }
-)
\ No newline at end of file
+    package_data={"": ["*.prompt"]},
+)
diff --git a/src/suql/free_text_fcns_server.py b/src/suql/free_text_fcns_server.py
index 71b59b0..7782d0f 100644
--- a/src/suql/free_text_fcns_server.py
+++ b/src/suql/free_text_fcns_server.py
@@ -22,20 +22,21 @@
 def _answer(
     source,
     query,
-    type_prompt = None,
+    type_prompt=None,
     k=5,
     max_input_token=10000,
-    engine="gpt-3.5-turbo-0125"
+    engine="gpt-3.5-turbo-0125",
+    api_base=None,
+    api_version=None,
 ):
     from suql.prompt_continuation import llm_generate
+
     if not source:
         return {"result": "no information"}
     text_res = []
     if isinstance(source, list):
-        documents = compute_top_similarity_documents(
-            source, query, top=k
-        )
+        documents = compute_top_similarity_documents(source, query, top=k)
         for i in documents:
             if num_tokens_from_string("\n".join(text_res + [i])) < max_input_token:
                 text_res.append(i)
@@ -63,11 +64,20 @@ def _answer(
             temperature=0.0,
             stop_tokens=[],
             postprocess=False,
+            api_base=api_base,
+            api_version=api_version,
         )
         return {"result": continuation}
 
+
 def start_free_text_fncs_server(
-    host="127.0.0.1", port=8500, k=5, max_input_token=3800, engine="gpt-4o-mini"
+    host="127.0.0.1",
+    port=8500,
+    k=5,
+    max_input_token=3800,
+    engine="gpt-4o-mini",
+    api_base=None,
+    api_version=None,
 ):
     """
     Set up a free text functions server for the free text
@@ -115,11 +125,12 @@ def answer():
             data["text"],
             data["question"],
             type_prompt=data["type_prompt"] if "type_prompt" in data else None,
-            k = k,
-            max_input_token = max_input_token,
-            engine = engine
+            k=k,
+            max_input_token=max_input_token,
+            engine=engine,
+            api_base=api_base,
+            api_version=api_version,
         )
-
     @app.route("/summary", methods=["POST"])
     def summary():
@@ -166,6 +177,8 @@ def summary():
             temperature=0.0,
             stop_tokens=["\n"],
             postprocess=False,
+            api_base=api_base,
+            api_version=api_version,
         )
         res = {"result": continuation}
diff --git a/src/suql/postgresql_connection.py b/src/suql/postgresql_connection.py
index b080c17..8a77b4d 100644
--- a/src/suql/postgresql_connection.py
+++ b/src/suql/postgresql_connection.py
@@ -12,7 +12,9 @@ def execute_sql(
     data=None,
     commit_in_lieu_fetch=False,
     no_print=False,
-    unprotected=False
+    unprotected=False,
+    host="127.0.0.1",
+    port="5432",
 ):
     start_time = time.time()
 
@@ -21,7 +23,7 @@ def execute_sql(
             dbname=database,
             user=user,
             host="/var/run/postgresql",
-            port="5432",
+            port=port,
             options="-c statement_timeout=30000 -c client_encoding=UTF8",
         )
     else:
@@ -29,8 +31,8 @@ def execute_sql(
             database=database,
             user=user,
             password=password,
-            host="127.0.0.1",
-            port="5432",
+            host=host,
+            port=port,
             options="-c statement_timeout=30000 -c client_encoding=UTF8",
         )
 
@@ -57,7 +59,7 @@ def sql_unprotected():
         else:
             results = cursor.fetchall()
             column_names = [desc[0] for desc in cursor.description]
-        
+
         return results, column_names
 
     try:
@@ -85,14 +87,16 @@ def execute_sql_with_column_info(
     user="select_user",
     password="select_user",
     unprotected=False,
+    host="127.0.0.1",
+    port="5432",
 ):
     # Establish a connection to the PostgreSQL database
     conn = psycopg2.connect(
         database=database,
         user=user,
         password=password,
-        host="127.0.0.1",
-        port="5432",
+        host=host,
+        port=port,
         options="-c statement_timeout=30000 -c client_encoding=UTF8",
     )
 
@@ -125,7 +129,7 @@ def sql_unprotected():
         column_types = [type_map[oid] for oid in column_type_oids]
 
         column_info = list(zip(column_names, column_types))
-        
+
         return results, column_info
 
     try:
@@ -141,12 +145,15 @@ def sql_unprotected():
     conn.close()
     return list(results), column_info
 
+
 def split_sql_statements(query):
     def strip_trailing_comments(stmt):
         idx = len(stmt.tokens) - 1
         while idx >= 0:
             tok = stmt.tokens[idx]
-            if tok.is_whitespace or sqlparse.utils.imt(tok, i=sqlparse.sql.Comment, t=sqlparse.tokens.Comment):
+            if tok.is_whitespace or sqlparse.utils.imt(
+                tok, i=sqlparse.sql.Comment, t=sqlparse.tokens.Comment
+            ):
                 stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, " ")
             else:
                 break
@@ -159,8 +166,13 @@ def strip_trailing_semicolon(stmt):
             tok = stmt.tokens[idx]
             # we expect that trailing comments already are removed
             if not tok.is_whitespace:
-                if sqlparse.utils.imt(tok, t=sqlparse.tokens.Punctuation) and tok.value == ";":
-                    stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, " ")
+                if (
+                    sqlparse.utils.imt(tok, t=sqlparse.tokens.Punctuation)
+                    and tok.value == ";"
+                ):
+                    stmt.tokens[idx] = sqlparse.sql.Token(
+                        sqlparse.tokens.Whitespace, " "
+                    )
                 break
             idx -= 1
         return stmt
@@ -187,15 +199,16 @@ def is_empty_statement(stmt):
 
     return [""]  # if all statements were empty - return a single empty statement
 
+
 def query_is_select_no_limit(query):
     limit_keywords = ["LIMIT", "OFFSET"]
-    
+
     def find_last_keyword_idx(parsed_query):
         for i in reversed(range(len(parsed_query.tokens))):
             if parsed_query.tokens[i].ttype in sqlparse.tokens.Keyword:
                 return i
         return -1
-    
+
     parsed_query = sqlparse.parse(query)[0]
     last_keyword_idx = find_last_keyword_idx(parsed_query)
     # Either invalid query or query that is not select
@@ -206,10 +219,8 @@ def find_last_keyword_idx(parsed_query):
 
     return no_limit
 
-def add_limit_to_query(
-    query,
-    limit_query = " LIMIT 1000"
-):
+
+def add_limit_to_query(query, limit_query=" LIMIT 1000"):
     parsed_query = sqlparse.parse(query)[0]
     limit_tokens = sqlparse.parse(limit_query)[0].tokens
     length = len(parsed_query.tokens)
@@ -220,22 +231,21 @@ def add_limit_to_query(
 
     return str(parsed_query)
 
-def apply_auto_limit(
-    query_text,
-    limit_query = " LIMIT 1000"
-):
+
+def apply_auto_limit(query_text, limit_query=" LIMIT 1000"):
     def combine_sql_statements(queries):
         return ";\n".join(queries)
-    
+
     queries = split_sql_statements(query_text)
     res = []
     for query in queries:
         if query_is_select_no_limit(query):
             query = add_limit_to_query(query, limit_query=limit_query)
         res.append(query)
-    
+
     return combine_sql_statements(res)
 
+
 if __name__ == "__main__":
     print(apply_auto_limit("SELECT * FROM restaurants LIMIT 1;"))
-    print(apply_auto_limit("SELECT * FROM restaurants;"))
\ No newline at end of file
+    print(apply_auto_limit("SELECT * FROM restaurants;"))
diff --git a/src/suql/prompt_continuation.py b/src/suql/prompt_continuation.py
index 634a1b5..dcaf580 100644
--- a/src/suql/prompt_continuation.py
+++ b/src/suql/prompt_continuation.py
@@ -3,21 +3,18 @@
 """
 
 import logging
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from typing import List
-
 import os
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from threading import Thread
+from typing import List
 
 from jinja2 import Environment, FileSystemLoader, select_autoescape
-
-from suql.utils import num_tokens_from_string
 from litellm import completion, completion_cost
+from suql.utils import num_tokens_from_string
 
 logger = logging.getLogger(__name__)
 # create file handler which logs even debug messages
@@ -36,11 +33,14 @@
 ENABLE_CACHING = False
 if ENABLE_CACHING:
     import pymongo
+
     mongo_client = pymongo.MongoClient("localhost", 27017)
     prompt_cache_db = mongo_client["open_ai_prompts"]["caches"]
 
 total_cost = 0  # in USD
+
+
 def get_total_cost():
     global total_cost
     return total_cost
@@ -75,6 +75,8 @@ def _generate(
     postprocess,
     max_tries,
     ban_line_break_start,
+    api_base=None,
+    api_version=None,
 ):
     # don't try multiple times if the temperature is 0, because the results will be the same
     if max_tries > 1 and temperature == 0:
@@ -96,6 +98,8 @@ def _generate(
         "frequency_penalty": frequency_penalty,
         "presence_penalty": presence_penalty,
         "stop": stop_tokens,
+        "api_base": api_base,
+        "api_version": api_version,
     }
 
     generation_output = chat_completion_with_backoff(**kwargs)
@@ -198,6 +202,8 @@ def llm_generate(
     filled_prompt=None,
     attempts=2,
     max_wait_time=None,
+    api_base=None,
+    api_version=None,
 ):
     """
     filled_prompt gives direct access to the underlying model, without having to load a prompt template from a .prompt file. Used for testing.
@@ -247,6 +253,8 @@ def llm_generate(
                     postprocess,
                     max_tries,
                     ban_line_break_start,
+                    api_base,
+                    api_version,
                 )
                 if success:
                     final_result = result
@@ -265,6 +273,8 @@ def llm_generate(
                     postprocess,
                     max_tries,
                     ban_line_break_start,
+                    api_base,
+                    api_version,
                 )
     end_time = time.time()
diff --git a/src/suql/sql_free_text_support/execute_free_text_sql.py b/src/suql/sql_free_text_support/execute_free_text_sql.py
index b787792..a1e22e7 100644
--- a/src/suql/sql_free_text_support/execute_free_text_sql.py
+++ b/src/suql/sql_free_text_support/execute_free_text_sql.py
@@ -1,15 +1,15 @@
 import concurrent.futures
 import json
+import logging
 import random
+import re
 import string
 import time
 import traceback
-import logging
-import re
 from collections import defaultdict
 from copy import deepcopy
-from typing import List, Union
 from functools import lru_cache
+from typing import List, Union
 
 import pglast
 import requests
@@ -23,15 +23,16 @@
 from sympy import Symbol, symbols
 from sympy.logic.boolalg import And, Not, Or, to_dnf
 
+from suql.free_text_fcns_server import _answer
 from suql.postgresql_connection import execute_sql, execute_sql_with_column_info
 from suql.prompt_continuation import llm_generate
 from suql.utils import num_tokens_from_string
-from suql.free_text_fcns_server import _answer
 
 # System parameters, do not modify
 _SET_FREE_TEXT_FCNS = ["answer"]
 _verified_res = {}
 
+
 def _generate_random_string(length=12):
     characters = string.ascii_lowercase + string.digits
     random_string = "".join(random.choice(characters) for _ in range(length))
@@ -80,22 +81,16 @@ def __call__(self, node):
     def visit_FuncCall(self, ancestors, node: pglast.ast.FuncCall):
         for i in node.funcname:
             if i.sval in self._SET_FREE_TEXT_FCNS:
-                query_lst = list(
-                    filter(lambda x: isinstance(x, A_Const), node.args)
-                )
+                query_lst = list(filter(lambda x: isinstance(x, A_Const), node.args))
                 assert len(query_lst) == 1
                 query = query_lst[0].val.sval
 
-                field_lst = list(
-                    filter(lambda x: isinstance(x, ColumnRef), node.args)
-                )
+                field_lst = list(filter(lambda x: isinstance(x, ColumnRef), node.args))
                 assert len(field_lst) == 1
                 field = tuple(map(lambda x: x.sval, field_lst[0].fields))
-                
-                self.res.append(
-                    (field, query)
-                )
+
+                self.res.append((field, query))
 
 
 class _TypeCastAnswer(Visitor):
@@ -136,7 +131,8 @@ def visit_A_Expr(self, ancestors, node: A_Expr):
     def is_structural(expr):
         if (
             isinstance(expr, FuncCall)
-            and ".".join(map(lambda x: x.sval, expr.funcname)) in _SET_FREE_TEXT_FCNS
+            and ".".join(map(lambda x: x.sval, expr.funcname))
+            in _SET_FREE_TEXT_FCNS
         ):
             return False
         return True
@@ -205,7 +201,11 @@ def __init__(
         create_userpswd,
         table_w_ids,
         llm_model_name,
-        max_verify
+        max_verify,
+        api_base=None,
+        api_version=None,
+        host="127.0.0.1",
+        port="5432",
     ) -> None:
         super().__init__()
         self.tmp_tables = []
@@ -219,18 +219,22 @@ def __init__(
         self.select_userpswd = select_userpswd
         self.create_username = create_username
         self.create_userpswd = create_userpswd
-        
+
         # store table_w_ids
         self.table_w_ids = table_w_ids
-        
+
         # store default LLM
         self.llm_model_name = llm_model_name
-        
+        self.api_base = api_base
+        self.api_version = api_version
+
         # store max verify param
         self.max_verify = max_verify
-        
+
         # store database
         self.database = database
+        self.host = host
+        self.port = port
 
     def __call__(self, node):
         super().__call__(node)
@@ -268,7 +272,11 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
                 self.select_userpswd,
                 self.table_w_ids,
                 self.llm_model_name,
-                self.max_verify
+                self.max_verify,
+                self.api_base,
+                self.api_version,
+                host=self.host,
+                port=self.port,
             )
 
             # based on results and column_info, insert a temporary table
@@ -284,6 +290,8 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
                 password=self.create_userpswd,
                 commit_in_lieu_fetch=True,
                 no_print=True,
+                host=self.host,
+                port=self.port,
             )
 
             if results:
@@ -308,7 +316,9 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
                     user=self.create_username,
                    password=self.create_userpswd,
                     commit_in_lieu_fetch=True,
-                    no_print=True
+                    no_print=True,
+                    host=self.host,
+                    port=self.port,
                 )
 
             # finally, modify the existing sql with tmp_table_name
@@ -324,7 +334,11 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
                 self.fts_fields,
                 self.select_username,
                 self.select_userpswd,
-                self.llm_model_name
+                self.llm_model_name,
+                self.api_base,
+                self.api_version,
+                host=self.host,
+                port=self.port,
             )
 
     def serialize_cache(self):
@@ -360,7 +370,9 @@ def drop_tmp_tables(self):
                 user=self.create_username,
                 password=self.create_userpswd,
                 commit_in_lieu_fetch=True,
-                no_print=True
+                no_print=True,
+                host=self.host,
+                port=self.port,
             )
 
 
@@ -426,7 +438,16 @@ def symbol2predicate(symbol_predicate):
     return sql_expr
 
 
-def _verify(document, field, query, operator, value, llm_model_name):
+def _verify(
+    document,
+    field,
+    query,
+    operator,
+    value,
+    llm_model_name,
+    api_base=None,
+    api_version=None,
+):
     if (document, field, query, operator, value) in _verified_res:
         return _verified_res[(document, field, query, operator, value)]
 
@@ -449,6 +470,8 @@ def _verify(document, field, query, operator, value, llm_model_name):
         stop_tokens=["\n"],
         max_tokens=30,
         postprocess=False,
+        api_base=api_base,
+        api_version=api_version,
     )[0]
 
     if "the answer is correct" in res.lower():
@@ -459,7 +482,9 @@ def _verify(document, field, query, operator, value, llm_model_name):
     return res
 
 
-def _verify_single_res(doc, field_query_list, llm_model_name):
+def _verify_single_res(
+    doc, field_query_list, llm_model_name, api_base=None, api_version=None
+):
     # verify for each stmt, if any stmt fails to verify, exclude it
     all_found = True
     found_stmt = []
@@ -478,7 +503,9 @@ def verify_single_value(single_value, single_column_name):
                     query,
                     operator,
                     value,
-                    llm_model_name
+                    llm_model_name,
+                    api_base,
+                    api_version,
                 )
             # otherwise it is a list. Go over the list until if one verifies
             else:
@@ -490,7 +517,9 @@ def verify_single_value(single_value, single_column_name):
                         query,
                         operator,
                         value,
-                        llm_model_name
+                        llm_model_name,
+                        api_base,
+                        api_version,
                     ):
                         res = True
                         break
@@ -527,7 +556,16 @@ def verify_single_value(single_value, single_column_name):
                         break
             else:
-                if not _verify(doc[1][i], field, query, operator, value, llm_model_name):
+                if not _verify(
+                    doc[1][i],
+                    field,
+                    query,
+                    operator,
+                    value,
+                    llm_model_name,
+                    api_base,
+                    api_version,
+                ):
                     all_found = False
                     break
         else:
@@ -600,6 +638,8 @@ def _retrieve_and_verify(
     table_w_ids,
     llm_model_name,
     max_verify,
+    api_base=None,
+    api_version=None,
     parallel=True,
     fetch_all=False,
 ):
@@ -718,7 +758,9 @@ def _retrieve_and_verify(
     if parallel:
         # parallelize verification calls
         id_res = _parallel_filtering(
-            lambda x: _verify_single_res(x, field_query_list, llm_model_name),
+            lambda x: _verify_single_res(
+                x, field_query_list, llm_model_name, api_base, api_version
+            ),
             parsed_result,
             limit,
             enforce_ordering=True if node.sortClause is not None else False,
         )
@@ -726,7 +768,9 @@ def _retrieve_and_verify(
     else:
         id_res = set()
         for each_res in parsed_result:
-            if _verify_single_res(each_res, field_query_list, llm_model_name):
+            if _verify_single_res(
+                each_res, field_query_list, llm_model_name, api_base, api_version
+            ):
                 if isinstance(each_res[0], list):
                     id_res.update(each_res[0])
                 else:
@@ -860,7 +904,11 @@ def __init__(
         fts_fields,
         select_username,
         select_userpswd,
-        llm_model_name
+        llm_model_name,
+        api_base=None,
+        api_version=None,
+        host="127.0.0.1",
+        port="5432",
     ) -> None:
         super().__init__()
         self.node = node
@@ -870,6 +916,10 @@ def __init__(
         self.select_username = select_username
         self.select_userpswd = select_userpswd
         self.llm_model_name = llm_model_name
+        self.api_base = api_base
+        self.api_version = api_version
+        self.host = host
+        self.port = port
 
     def __call__(self, node):
         super().__call__(node)
@@ -917,7 +965,9 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
             and isinstance(self.node.fromClause[0], RangeVar)
             and node.name[0].sval in ["~~", "~~*", "="]
         ):
-            n_field_name, n_value_name = _get_a_expr_field_value(node, no_check=True)
+            n_field_name, n_value_name = _get_a_expr_field_value(
+                node, no_check=True
+            )
             if (
                 table_name == self.node.fromClause[0].relname
                 and field_name == n_field_name
@@ -959,6 +1009,8 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
                 unprotected=True,
                 user=self.select_username,
                 password=self.select_userpswd,
+                host=self.host,
+                port=self.port,
             )
             # it is possible if there is a type error
             # e.g. "Passengers ( 2017 )" = '490,000', but "Passengers ( 2017 )" is actually of type int
@@ -971,6 +1023,8 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
                     self.database,
                     user=self.select_username,
                     password=self.select_userpswd,
+                    host=self.host,
+                    port=self.port,
                 )
             except psyconpg2Error:
                 logging.info(
@@ -1053,6 +1107,8 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
                 self.database,
                 user=self.select_username,
                 password=self.select_userpswd,
+                host=self.host,
+                port=self.port,
             )
             # TODO deal with list problems?
             field_value_choices = list(map(lambda x: x[0], field_value_choices))
@@ -1074,6 +1130,8 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
                 stop_tokens=["\n"],
                 max_tokens=100,
                 postprocess=False,
+                api_base=self.api_base,
+                api_version=self.api_version,
             )[0]
             if res in field_value_choices:
                 _replace_a_expr_field(node, ancestors, String(sval=(res)))
@@ -1096,7 +1154,11 @@ def _classify_db_fields(
     fts_fields: List,
     select_username: str,
     select_userpswd: str,
-    llm_model_name: str
+    llm_model_name: str,
+    api_base=None,
+    api_version=None,
+    host="127.0.0.1",
+    port="5432",
 ):
     # we expect all atomic predicates under `predicate` to only involve stru fields
     # (no `answer` function)
@@ -1109,7 +1169,11 @@ def _classify_db_fields(
         fts_fields,
         select_username,
         select_userpswd,
-        llm_model_name
+        llm_model_name,
+        api_base,
+        api_version,
+        host,
+        port,
     )
     visitor(node)
@@ -1137,14 +1199,13 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef):
                 # the same field appears twice, this means that the original syntax is problematic
                 break
             res = (String(sval=f"{table_name}^{node.fields[0].sval}"),)
-        
+
         # do not replace if None, b/c this should be an aliased field
         if res is not None:
             node.fields = res
 
-def _extract_recursive_joins(
-    fromClause: JoinExpr
-):
+
+def _extract_recursive_joins(fromClause: JoinExpr):
     """
     A FROM clause of a SelectStmt could have multiple joins. This functions
     searilizes the joins and returns them as a list.
@@ -1154,14 +1215,14 @@ def _extract_recursive_joins(fromClause: JoinExpr):
     res = []
     if isinstance(fromClause.larg, RangeVar):
         res.append(fromClause.larg)
     if isinstance(fromClause.rarg, RangeVar):
         res.append(fromClause.rarg)
-    
+
     if isinstance(fromClause.larg, JoinExpr):
         res.extend(_extract_recursive_joins(fromClause.larg))
     if isinstance(fromClause.rarg, JoinExpr):
         res.extend(_extract_recursive_joins(fromClause.rarg))
-    
+
     return res
-    
+
 
 def _execute_structural_sql(
     original_node: SelectStmt,
@@ -1171,9 +1232,13 @@ def _execute_structural_sql(
     fts_fields: List,
     select_username: str,
     select_userpswd: str,
-    llm_model_name: str
+    llm_model_name: str,
+    api_base=None,
+    api_version=None,
+    host="127.0.0.1",
+    port="5432",
 ):
-    _ = RawStream()(original_node) # RawStream takes care of some issue, to investigate
+    _ = RawStream()(original_node)  # RawStream takes care of some issue, to investigate
     node = deepcopy(original_node)
     # change projection to include everything
     # there are a couple of cases here
@@ -1187,10 +1252,17 @@ def _execute_structural_sql(
         for table in _extract_recursive_joins(node.fromClause[0]):
             # find out what columns this table has
             _, columns = execute_sql_with_column_info(
-                RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))),
+                RawStream()(
+                    SelectStmt(
+                        fromClause=(table,),
+                        targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),),
+                    )
+                ),
                 database,
                 select_username,
                 select_userpswd,
+                host=host,
+                port=port,
             )
             # give the projection fields new names
             projection_table_name = (
@@ -1226,10 +1298,17 @@ def _execute_structural_sql(
         for table in node.fromClause:
             # find out what columns this table has
             _, columns = execute_sql_with_column_info(
-                RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))),
+                RawStream()(
+                    SelectStmt(
+                        fromClause=(table,),
+                        targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),),
+                    )
+                ),
                 database,
                 select_username,
                 select_userpswd,
+                host=host,
+                port=port,
             )
             # give the projection fields new names
             projection_table_name = (
@@ -1279,15 +1358,21 @@ def _execute_structural_sql(
         fts_fields,
         select_username,
         select_userpswd,
-        llm_model_name
+        llm_model_name,
+        api_base,
+        api_version,
+        host=host,
+        port=port,
     )
 
     sql = RawStream()(node)
     return execute_sql_with_column_info(
-        sql, 
+        sql,
         database,
         user=select_username,
-        password=select_userpswd
+        password=select_userpswd,
+        host=host,
+        port=port,
     )
@@ -1300,7 +1383,9 @@ def _execute_free_text_queries(
     embedding_server_address,
     table_w_ids,
     llm_model_name,
-    max_verify
+    max_verify,
+    api_base,
+    api_version,
 ):
     # the predicate should only contain an atomic unstructural query
     # or an AND of multiple unstructural query (NOT of an unstructural query is considered to be atmoic)
@@ -1329,9 +1414,9 @@ def extract_tuple_value(v: List[A_Const]):
         return tuple(res)
 
     def breakdown_unstructural_query(predicate: A_Expr):
-        assert _if_contains_free_text_fcn(predicate.lexpr) or _if_contains_free_text_fcn(
-            predicate.rexpr
-        )
+        assert _if_contains_free_text_fcn(
+            predicate.lexpr
+        ) or _if_contains_free_text_fcn(predicate.rexpr)
         if _if_contains_free_text_fcn(predicate.lexpr) and _if_contains_free_text_fcn(
             predicate.rexpr
         ):
@@ -1411,7 +1496,9 @@ def breakdown_unstructural_query(predicate: A_Expr):
                 embedding_server_address,
                 table_w_ids,
                 llm_model_name,
-                max_verify
+                max_verify,
+                api_base,
+                api_version,
             ),
             column_info,
         )
@@ -1431,7 +1518,9 @@ def breakdown_unstructural_query(predicate: A_Expr):
                 embedding_server_address,
                 table_w_ids,
                 llm_model_name,
-                max_verify
+                max_verify,
+                api_base,
+                api_version,
             ),
             column_info,
         )
@@ -1456,7 +1545,11 @@ def _execute_and(
     select_userpswd,
     table_w_ids,
     llm_model_name,
-    max_verify
+    max_verify,
+    api_base=None,
+    api_version=None,
+    host="127.0.0.1",
+    port="5432",
 ):
 
     # there should not exist any OR expression inside sql_dnf_predicates
@@ -1486,7 +1579,11 @@ def _execute_and(
             fts_fields,
             select_username,
             select_userpswd,
-            llm_model_name
+            llm_model_name,
+            api_base,
+            api_version,
+            host=host,
+            port=port,
         )
 
         free_text_predicates = tuple(
@@ -1508,7 +1604,9 @@ def _execute_and(
             embedding_server_address,
             table_w_ids,
             llm_model_name,
-            max_verify
+            max_verify,
+            api_base,
+            api_version,
         )
 
     elif isinstance(sql_dnf_predicates, A_Expr) or (
@@ -1524,7 +1622,11 @@ def _execute_and(
                 fts_fields,
                 select_username,
                 select_userpswd,
-                llm_model_name
+                llm_model_name,
+                api_base,
+                api_version,
+                host=host,
+                port=port,
             )
         else:
             all_results, column_info = _execute_structural_sql(
@@ -1535,7 +1635,11 @@ def _execute_and(
                 fts_fields,
                 select_username,
                 select_userpswd,
-                llm_model_name
+                llm_model_name,
+                api_base,
+                api_version,
+                host=host,
+                port=port,
             )
             return _execute_free_text_queries(
                 node,
@@ -1546,7 +1650,9 @@ def _execute_and(
                 embedding_server_address,
                 table_w_ids,
                 llm_model_name,
-                max_verify
+                max_verify,
+                api_base,
+                api_version,
             )
@@ -1560,7 +1666,11 @@ def _analyze_SelectStmt(
     select_userpswd: str,
     table_w_ids: dict,
     llm_model_name: str,
-    max_verify: str
+    max_verify: str,
+    api_base=None,
+    api_version=None,
+    host="127.0.0.1",
+    port="5432",
 ):
     limit = node.limitCount.val.ival if node.limitCount else -1
     sql_dnf_predicates = _convert2dnf(node.whereClause)
@@ -1588,7 +1696,11 @@ def _analyze_SelectStmt(
                 select_userpswd,
                 table_w_ids,
                 llm_model_name,
-                max_verify
+                max_verify,
+                api_base,
+                api_version,
+                host=host,
+                port=port,
             )
             res.extend(choice_res)
@@ -1614,7 +1724,11 @@ def _analyze_SelectStmt(
             select_userpswd,
             table_w_ids,
             llm_model_name,
-            max_verify
+            max_verify,
+            api_base,
+            api_version,
+            host=host,
+            port=port,
         )
 
     elif isinstance(sql_dnf_predicates, A_Expr) or (
@@ -1633,7 +1745,11 @@ def _analyze_SelectStmt(
             select_userpswd,
             table_w_ids,
             llm_model_name,
-            max_verify
+            max_verify,
+            api_base,
+            api_version,
+            host=host,
+            port=port,
         )
     else:
         raise ValueError(
@@ -1647,10 +1761,10 @@ def _parse_standalone_answer(suql):
     # Define a regular expression pattern to match the required format
     # \s* allows for any number of whitespaces around the parentheses
     pattern = r"\s*answer\s*\(\s*([a-zA-Z_0-9]+)\s*,\s*['\"](.+?)['\"]\s*\)\s*"
-    
+
     # Use the re.match function to check if the entire string matches the pattern
     match = re.match(pattern, suql)
-    
+
     # If a match is found, return the captured groups: source and query
     if match:
         return match.group(1), match.group(2)
@@ -1673,36 +1787,34 @@ def _execute_standalone_answer(suql, source_file_mapping):
     source, query = _parse_standalone_answer(suql)
     if source not in source_file_mapping:
         return None
-    
+
     source_content = _read_source_file(source_file_mapping[source])
-    
+
     return _answer(source_content, query)
 
+
 def _check_predicate_exist(a_expr: A_Expr, field_name: str):
     if isinstance(a_expr.lexpr, ColumnRef):
         for i in a_expr.lexpr.fields:
             if isinstance(i, String) and i.sval == field_name:
                 return True
-    
+
     if isinstance(a_expr.rexpr, ColumnRef):
         for i in a_expr.rexpr.fields:
             if isinstance(i, String) and i.sval == field_name:
                 return True
-    
+
     return False
 
 
 class _RequiredParamMappingVisitor(Visitor):
-    def __init__(
-        self,
-        required_params_mapping
-    ) -> None:
+    def __init__(self, required_params_mapping) -> None:
         super().__init__()
         self.required_params_mapping = required_params_mapping
-        self.missing_params = defaultdict(set) 
-        
-    def visit_SelectStmt(self, ancestors, node: SelectStmt):
+        self.missing_params = defaultdict(set)
+
+    def visit_SelectStmt(self, ancestors, node: SelectStmt):
         def check_a_expr_or_and_expr(_dnf_predicate, _field):
             if isinstance(_dnf_predicate, A_Expr):
                 return _check_predicate_exist(_dnf_predicate, _field)
@@ -1717,20 +1828,24 @@ def check_a_expr_or_and_expr(_dnf_predicate, _field):
                     if _check_predicate_exist(i, _field):
                         found = True
                         break
-                
+
                 return found
-            
+
             return False
-        
-        
-        for table in node.fromClause:
-            if isinstance(table, RangeVar) and table.relname in self.required_params_mapping:
+
+        for table in node.fromClause:
+            if (
+                isinstance(table, RangeVar)
+                and table.relname in self.required_params_mapping
+            ):
                 assert type(self.required_params_mapping[table.relname]) == list
-                
+
                 if not node.whereClause:
-                    self.missing_params[table.relname].update(self.required_params_mapping[table.relname])
+                    self.missing_params[table.relname].update(
+                        self.required_params_mapping[table.relname]
+                    )
                     continue
-                
+
                 dnf_predicate = _convert2dnf(node.whereClause)
 
                 if (
@@ -1738,7 +1853,10 @@ def check_a_expr_or_and_expr(_dnf_predicate, _field):
                     and dnf_predicate.boolop == BoolExprType.OR_EXPR
                 ):
                     for field in self.required_params_mapping[table.relname]:
-                        if not all(check_a_expr_or_and_expr(i, field) for i in dnf_predicate.args):
+                        if not all(
+                            check_a_expr_or_and_expr(i, field)
+                            for i in dnf_predicate.args
+                        ):
                             self.missing_params[table.relname].add(field)
                 else:
                     # target condition:
@@ -1750,23 +1868,23 @@ def check_a_expr_or_and_expr(_dnf_predicate, _field):
                 for field in self.required_params_mapping[table.relname]:
                     if not check_a_expr_or_and_expr(dnf_predicate, field):
                         self.missing_params[table.relname].add(field)
-    
+
 
 def _check_required_params(suql, required_params_mapping):
     """
     Check whether all required parameters exist in the `suql`.
-    
+
     # Parameters:
        `suql` (str): The to-be-executed suql query.
-        
+
        `required_params_mapping` (Dict(str -> List[str]), optional): *Experimental feature*: a dictionary mapping
            from table names to a list of "required" parameters for the tables.
            The SUQL compiler will check whether the SUQL query contains all required parameters
            (i.e., whether for each such table there exists a `WHERE` clause with the required parameter).
-    
+
    # Returns:
        `if_all_exist` (bool): whether all required parameters exist.
-        
+
        `missing_params` (Dict(str -> List[str]): a mapping from table names to a list of required missing parameters.
    """
    # try except handles stand alone answer functions and other parsing exceptions
    try:
@@ -1774,15 +1892,17 @@ def _check_required_params(suql, required_params_mapping):
        root = parse_sql(suql)
    except Exception:
        return False, required_params_mapping
-    
+
    visitor = _RequiredParamMappingVisitor(required_params_mapping)
    visitor(root)
-    
+
    if visitor.missing_params:
-        return False, {key: list(value) for key, value in visitor.missing_params.items()}
+        return False, {
+            key: list(value) for key, value in visitor.missing_params.items()
+        }
    else:
        return True, {}
-    
+

def suql_execute(
    suql,
@@ -1801,44 +1921,49 @@ def suql_execute(
    create_username="creator_role",
    create_userpswd="creator_role",
    source_file_mapping={},
+    host="127.0.0.1",
+    port="5432",
+    # used for Azure OpenAI
+    api_base=None,
+    api_version=None,
):
    """
    Main entry point to the SUQL Python-based compiler.

    # Parameters:
        `suql` (str): The to-be-executed suql query.
-        
+
        `table_w_ids` (dict): A dictionary where each key is a table name, and each value is the corresponding
            unique ID column name in this table, e.g., `table_w_ids = {"restaurants": "_id"}`, meaning that the
            relevant tables to the SUQL compiler include only the `restaurants` table, which has unique ID column
            `_id`.
-        
+
        `database` (str): The name of the PostgreSQL database to execute the query.
-        
+
        `fts_fields` (List[str], optional): Fields that should use PostgreSQL's Full Text Search (FTS) operators;
            The SUQL compiler would change certain string operators like "=" to use PostgreSQL's FTS operators.
            It uses `websearch_to_tsquery` and the `@@` operator to match against these fields.
-        
+
        `llm_model_name` (str, optional): The LLM to be used by the SUQL compiler.
            Defaults to `gpt-3.5-turbo-0125`.
-        
+
        `max_verify` (str): For each LIMIT x clause, `max_verify * x` results will be retrieved together
            from the embedding model for LLM to verify. Defaults to 20.
-        
+
        `loggings` (str, optional): Prefix for error case loggings. Errors are written to a
            "_suql_error_log.txt" file by default.

        `log_filename` (str, optional): Logging file name for the SUQL compiler. If not provided,
            logging is disabled.
-        
+
        `disable_try_catch` (bool, optional): whether to disable try-catch (errors would directly propagate to caller).
-        
+
        `embedding_server_address` (str, optional): the embedding server address. Defaults to 'http://127.0.0.1:8501'.
-        
+
        `select_username` (str, optional): user name with select privilege in db. Defaults to "select_user".
-        
+
        `select_userpswd` (str, optional): above user's password with select privilege in db. Defaults to "select_user".
-        
+
        `create_username` (str, optional): user name with create privilege in db. Defaults to "creator_role".
-        
+
        `create_userpswd` (str, optional): above user's password with create privilege in db. Defaults to "creator_role".

        `source_file_mapping` (Dict(str -> str), optional): *Experimental feature*: a dictionary mapping from variable
@@ -1849,9 +1974,9 @@ def suql_execute(

    # Returns:
        `results` (List[[*]]): A list of returned database results. Each inner list stores a row of returned result.
-        
+
        `column_names` (List[str]): A list of database column names in the same order as `results`.
-        
+
        `cache` (Dict()): Debugging information from the SUQL compiler.

    # Example:
@@ -1871,13 +1996,12 @@ def suql_execute(
    FTS helps with such cases.
    """
    if log_filename:
-        logging.basicConfig(level=logging.INFO,
-                            format='%(asctime)s - %(levelname)s - %(message)s',
-                            datefmt='%Y-%m-%d %H:%M:%S',
-                            handlers=[
-                                logging.FileHandler(log_filename),
-                                logging.StreamHandler()
-                            ])
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s - %(levelname)s - %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+            handlers=[logging.FileHandler(log_filename), logging.StreamHandler()],
+        )
    else:
        logging.basicConfig(level=logging.CRITICAL + 1)

@@ -1900,6 +2024,10 @@ def suql_execute(
        select_userpswd,
        create_username,
        create_userpswd,
+        host=host,
+        port=port,
+        api_base=api_base,
+        api_version=api_version,
    )
    if results == []:
        return results, column_names, cache
@@ -1934,6 +2062,10 @@ def _suql_execute_single(
    select_userpswd,
    create_username,
    create_userpswd,
+    host="127.0.0.1",
+    port="5432",
+    api_base=None,
+    api_version=None,
):
    results = []
    column_names = []
@@ -1950,7 +2082,11 @@ def _suql_execute_single(
        create_userpswd,
        table_w_ids,
        llm_model_name,
-        max_verify
+        max_verify,
+        api_base=api_base,
+        api_version=api_version,
+        host=host,
+        port=port,
    )
    root = parse_sql(suql)
    visitor(root)
@@ -1963,13 +2099,15 @@ def _suql_execute_single(
            user=select_username,
            password=select_userpswd,
            no_print=True,
-            unprotected=disable_try_catch_sql
+            unprotected=disable_try_catch_sql,
+            host=host,
+            port=port,
        )
    except Exception as err:
        if disable_try_catch:
            raise err
        with open("_suql_error_log.txt", "a") as file:
-            file.write(f"==============\n")
+            file.write("==============\n")
            file.write(f"{loggings}\n")
            file.write(f"{suql}\n")
            file.write(f"{str(err)}\n")
@@ -1994,9 +2132,9 @@
        #     "yelp_general_info": "/home/harshit/DialogueForms/src/genie/domains/yelpbot/yelp_general_info.txt"
        # },
        disable_try_catch=True,
-        disable_try_catch_all_sql=True
+        disable_try_catch_all_sql=True,
    )
-    
+
    print(results)
    exit(0)
    # print(suql_execute(sql, disable_try_catch=True, fts_fields=[("restaurants", "name")] )[0])
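
With this patch, `suql_execute` can target a non-default PostgreSQL instance through `host`/`port` and route all LLM calls through Azure OpenAI through `api_base`/`api_version`, which are forwarded down to litellm's `completion`. A minimal usage sketch; the Postgres address, database name, and Azure endpoint/deployment below are placeholders, the `azure/<deployment>` model-name convention is litellm's, and the Azure key is assumed to be supplied via the environment (litellm reads `AZURE_API_KEY`):

    from suql.sql_free_text_support.execute_free_text_sql import suql_execute

    # "restaurants" table keyed by its unique ID column "_id"
    results, column_names, cache = suql_execute(
        "SELECT * FROM restaurants WHERE answer(reviews, 'is this a romantic restaurant?') = 'Yes' LIMIT 3;",
        table_w_ids={"restaurants": "_id"},
        database="restaurants",
        # non-default Postgres instance; both parameters default to 127.0.0.1:5432
        host="10.0.0.5",
        port="5433",
        # Azure OpenAI routing; leaving api_base/api_version as None keeps
        # litellm on the standard OpenAI endpoint
        llm_model_name="azure/my-gpt-4o-mini-deployment",
        api_base="https://my-resource.openai.azure.com",
        api_version="2024-02-15-preview",
    )
    print(results, column_names)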
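
The free text functions server accepts the same two Azure parameters and threads them into every `_answer` and `llm_generate` call behind its `answer` and `summary` endpoints. A sketch of standing it up, under the same placeholder assumptions as above:

    from suql.free_text_fcns_server import start_free_text_fncs_server

    # Flask server the compiler consults for free-text `answer` / `summary` calls;
    # the endpoint and deployment name are hypothetical values
    start_free_text_fncs_server(
        host="127.0.0.1",
        port=8500,
        engine="azure/my-gpt-4o-mini-deployment",
        api_base="https://my-resource.openai.azure.com",
        api_version="2024-02-15-preview",
    )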