Skip to content

Commit

Permalink
works with ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
aldrinjenson committed Dec 16, 2023
1 parent 8c3c8e3 commit 02d9611
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 13 deletions.
24 changes: 18 additions & 6 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,38 @@
exit()

df = pd.read_csv(uploaded_file)
st.subheader("Head data")
st.write(df.head())
first_few_entries = df.head(2)
print(first_few_entries)
st.subheader("Data")
st.write(df)

df.to_sql(TABLE_NAME, conn, if_exists='replace', index=False)


table_info = run_query(conn, f"PRAGMA table_info({TABLE_NAME})")
filtered_info = [(col[1], col[2]) for col in table_info]
columns = ["Column Name", "Data Type"]
df_table_info = pd.DataFrame(filtered_info, columns=columns)
print("df info: ", df_table_info)
table_info_string = df_table_info.to_string(index=False)

user_query = st.text_input(
"Enter your query",
placeholder="Eg: Which merchant has the maximum transactions and by how many?",
)

print(user_query)
if not len(user_query):
exit()


# sql_query = f"SELECT AVG(cc_num) FROM {TABLE_NAME}"

table_info = run_query(conn, f"PRAGMA table_info({TABLE_NAME})")
table_info_string = pd.DataFrame(table_info, columns=["Column Index","Column Name", "Data Type", "Nullable", "Default Value", "Primary Key"]).to_string(index=False)
sql_query = get_sql_for(user_query, table_info_string)


sql_query = get_sql_for(user_query, table_info_string, first_few_entries)

db_result = run_query(conn, sql_query)
print("db res = ", db_result)
nlp_result = get_nlp_result_for(user_query, db_result)
print(type(nlp_result))

Expand Down
70 changes: 63 additions & 7 deletions src/llm_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,64 @@
from src.constants import TABLE_NAME
# from src.constants import TABLE_NAME
TABLE_NAME = "tb"

def get_sql_for(user_query, table_info_string ):

import requests
import json

def execute_with_ollama(query):
print({"Query: ", query})
payload = {
"model": "mistral",
"format": "json",
"stream": False,
"messages": [
{"role": "user", "content": query + "\nOutput only SQL in JSON"}
]
}

payload_json = json.dumps(payload)
url = 'http://localhost:11434/api/chat'

try:
response = requests.post(url, data=payload_json)

if response.status_code == 200:
response_data = response.json()
print("Reponse data = ", response_data)
return response_data['message']['content']
else:
print(f"LLM Request failed with status code {response.status_code}")
return None
except requests.RequestException as e:
print(f"Request exception: {e}")
return None

# query = "why is the sky blue? Respond to the point in JSON"
# result = execute_with_ollama(query)
# print("Response:", result)


def get_sql_for(user_query, table_info_string, first_few_entries ):
llm_query= f"""
I have an SQL tabled called {TABLE_NAME},
having the following structure:\n {table_info_string}
having the following structure:
{table_info_string}
---
Here are the first few entries:
{first_few_entries}
---
Please generate only the SQL based on this table for the following query:
Please generate only the simplest SQL based on this table for the following query. Here is the query::
{user_query}
---
eg: {{
"sql": "SELECT * FROM {TABLE_NAME}"
"error": null
}}
"""
print(llm_query)
llm_response = execute_with_ollama(llm_query)
print("LLM resopnse = ", llm_response)

# connect with llm and get sql_query

Expand All @@ -20,9 +70,15 @@ def get_sql_for(user_query, table_info_string ):
LIMIT 1
"""

return dummy_sql_query
print(type(llm_response))
print(type(json.loads(llm_response)))
llm_response = json.loads(llm_response)
print("JSON response= ", llm_response)
print(llm_response["sql"])
# return dummy_sql_query
return llm_response["sql"]


def get_nlp_result_for(user_query, db_result):
print(user_query, db_result)
return db_result[0]
return db_result[0][0]
1 change: 1 addition & 0 deletions src/sql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sqlite3

def run_query(conn, query):
query = query.replace('\\', '')
try:
cursor = conn.execute(query)
result = cursor.fetchall()
Expand Down

0 comments on commit 02d9611

Please sign in to comment.