diff --git a/app.py b/app.py index de7f070..310ee50 100644 --- a/app.py +++ b/app.py @@ -17,26 +17,38 @@ exit() df = pd.read_csv(uploaded_file) -st.subheader("Head data") -st.write(df.head()) +first_few_entries = df.head(2) +print(first_few_entries) +st.subheader("Data") +st.write(df) df.to_sql(TABLE_NAME, conn, if_exists='replace', index=False) + +table_info = run_query(conn, f"PRAGMA table_info({TABLE_NAME})") +filtered_info = [(col[1], col[2]) for col in table_info] +columns = ["Column Name", "Data Type"] +df_table_info = pd.DataFrame(filtered_info, columns=columns) +print("df info: ", df_table_info) +table_info_string = df_table_info.to_string(index=False) + user_query = st.text_input( "Enter your query", placeholder="Eg: Which merchant has the maximum transactions and by how many?", ) print(user_query) +if not len(user_query): + exit() -# sql_query = f"SELECT AVG(cc_num) FROM {TABLE_NAME}" -table_info = run_query(conn, f"PRAGMA table_info({TABLE_NAME})") -table_info_string = pd.DataFrame(table_info, columns=["Column Index","Column Name", "Data Type", "Nullable", "Default Value", "Primary Key"]).to_string(index=False) -sql_query = get_sql_for(user_query, table_info_string) + + +sql_query = get_sql_for(user_query, table_info_string, first_few_entries) db_result = run_query(conn, sql_query) +print("db res = ", db_result) nlp_result = get_nlp_result_for(user_query, db_result) print(type(nlp_result)) diff --git a/src/llm_utils.py b/src/llm_utils.py index a09af1c..7f7b6c1 100644 --- a/src/llm_utils.py +++ b/src/llm_utils.py @@ -1,14 +1,64 @@ -from src.constants import TABLE_NAME +# from src.constants import TABLE_NAME +TABLE_NAME = "tb" -def get_sql_for(user_query, table_info_string ): + +import requests +import json + +def execute_with_ollama(query): + print({"Query: ", query}) + payload = { + "model": "mistral", + "format": "json", + "stream": False, + "messages": [ + {"role": "user", "content": query + "\nOutput only SQL in JSON"} + ] + } + + payload_json = json.dumps(payload) + url = 'http://localhost:11434/api/chat' + + try: + response = requests.post(url, data=payload_json) + + if response.status_code == 200: + response_data = response.json() + print("Reponse data = ", response_data) + return response_data['message']['content'] + else: + print(f"LLM Request failed with status code {response.status_code}") + return None + except requests.RequestException as e: + print(f"Request exception: {e}") + return None + +# query = "why is the sky blue? Respond to the point in JSON" +# result = execute_with_ollama(query) +# print("Response:", result) + + +def get_sql_for(user_query, table_info_string, first_few_entries ): llm_query= f""" I have an SQL tabled called {TABLE_NAME}, - having the following structure:\n {table_info_string} + having the following structure: + {table_info_string} + --- + Here are the first few entries: + {first_few_entries} --- - Please generate only the SQL based on this table for the following query: + Please generate only the simplest SQL based on this table for the following query. Here is the query:: {user_query} + + --- + + eg: {{ + "sql": "SELECT * FROM {TABLE_NAME}" + "error": null + }} """ - print(llm_query) + llm_response = execute_with_ollama(llm_query) + print("LLM resopnse = ", llm_response) # connect with llm and get sql_query @@ -20,9 +70,15 @@ def get_sql_for(user_query, table_info_string ): LIMIT 1 """ - return dummy_sql_query + print(type(llm_response)) + print(type(json.loads(llm_response))) + llm_response = json.loads(llm_response) + print("JSON response= ", llm_response) + print(llm_response["sql"]) + # return dummy_sql_query + return llm_response["sql"] def get_nlp_result_for(user_query, db_result): print(user_query, db_result) - return db_result[0] \ No newline at end of file + return db_result[0][0] \ No newline at end of file diff --git a/src/sql_utils.py b/src/sql_utils.py index 9edeaea..79ed536 100644 --- a/src/sql_utils.py +++ b/src/sql_utils.py @@ -2,6 +2,7 @@ import sqlite3 def run_query(conn, query): + query = query.replace('\\', '') try: cursor = conn.execute(query) result = cursor.fetchall()