
Commit

Make it work with both openai and ollama
aldrinjenson committed Dec 16, 2023
1 parent d974875 commit 24b9980
Showing 5 changed files with 84 additions and 34 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,4 +3,5 @@ data/*
 .DS_STORE
 *.db
 *.csv
-*__pycache__
+*__pycache__
+.env
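Adding .env here keeps the soon-to-be-added OpenAI key out of version control. If in doubt, git's check-ignore command prints the rule that matches:

git check-ignore -v .env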
4 changes: 2 additions & 2 deletions app.py
@@ -1,7 +1,7 @@
-from src.constants import TABLE_NAME, DATABASE_NAME
 import streamlit as st
 import pandas as pd
 import sqlite3
+from src.constants import TABLE_NAME, DATABASE_NAME
 from src.sql_utils import run_query
 from src.llm_utils import get_sql_for, get_nlp_result_for
 from src.streamlit_utils import cleanup
@@ -17,7 +17,7 @@
     exit()
 
 df = pd.read_csv(uploaded_file)
-first_few_entries = df.head(4).to_string()
+first_few_entries = df.head(2).to_string()
 st.subheader("Data")
 st.write(df)

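For context, app.py is a Streamlit script, so the app is launched with Streamlit's runner from the project root (so the src imports resolve) rather than with plain python:

streamlit run app.py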
8 changes: 7 additions & 1 deletion src/constants.py
@@ -1,2 +1,8 @@
+from dotenv import load_dotenv
+load_dotenv()
+import os
+
 TABLE_NAME = "tb"
-DATABASE_NAME = "db.db"
+DATABASE_NAME = "db.db"
+
+OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
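Since constants.py now pulls the key through python-dotenv, the gitignored .env file at the project root is expected to hold it. A placeholder example (the value shown is not a real credential):

# .env (kept out of version control by the .gitignore change above)
OPENAI_API_KEY=sk-your-key-here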
57 changes: 57 additions & 0 deletions src/llm_models.py
@@ -0,0 +1,57 @@
import requests
import json
from openai import OpenAI
from src.constants import OPENAI_API_KEY

print(OPENAI_API_KEY)
client = OpenAI()
client.api_key = OPENAI_API_KEY

# Send a chat request to a local Ollama server and return the model's JSON reply as a dict.
def execute_with_ollama(query):
    payload = {
        "model": "mistral",
        "format": "json",
        "stream": False,
        "messages": [
            {"role": "user", "content": query}
        ]
    }

    payload_json = json.dumps(payload)
    url = 'http://localhost:11434/api/chat'

    try:
        response = requests.post(url, data=payload_json)

        if response.status_code == 200:
            response_data = response.json()
            # The model's reply is itself a JSON string; parse it into a dict.
            response_data = json.loads(response_data['message']['content'])
            print(response_data)
            return response_data
        else:
            print(f"LLM Request failed with status code {response.status_code}")
            return None
    except requests.RequestException as e:
        print(f"Request exception: {e}")
        return None


# Send the same query to the OpenAI chat completions API in JSON mode.
def execute_with_openai(query):
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=[
            {
                "role": "system",
                "content": query
            },
        ],
        temperature=0.7,
        max_tokens=64,
        top_p=1,
        response_format={"type": "json_object"},
    )
    response = completion.choices[0].message.content
    print(response)
    response = json.loads(response)
    return response
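
As a quick smoke test of the two backends, something like the sketch below should work, assuming a local Ollama server with the mistral model pulled and a valid OPENAI_API_KEY in .env; note that the prompt must mention JSON, since OpenAI's json_object response format requires it:

# smoke_test.py: illustrative sketch, not part of the commit
from src.llm_models import execute_with_ollama, execute_with_openai

prompt = 'Reply in JSON of the form {"answer": "<text>"}. What is 2 + 2?'

# Needs `ollama serve` running locally and `ollama pull mistral` done once.
print(execute_with_ollama(prompt))

# Needs OPENAI_API_KEY set (loaded from .env via src/constants.py).
print(execute_with_openai(prompt))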

46 changes: 16 additions & 30 deletions src/llm_utils.py
@@ -1,35 +1,22 @@
 from src.constants import TABLE_NAME
+from src.llm_models import execute_with_openai, execute_with_ollama
+import json
 
 
-import requests
-import json
+MODEL = "mistral"
+# MODEL = "openai"
 
-def execute_with_ollama(query):
+def execute_with_llm(query):
+    print({"Query: ", query})
-    payload = {
-        "model": "mistral",
-        "format": "json",
-        "stream": False,
-        "messages": [
-            {"role": "user", "content": query}
-        ]
-    }
-
-    payload_json = json.dumps(payload)
-    url = 'http://localhost:11434/api/chat'
-
-    try:
-        response = requests.post(url, data=payload_json)
+    if MODEL == 'openai':
+        response = execute_with_openai(query)
+    elif MODEL == 'mistral':
+        response = execute_with_ollama(query)
+    else:
+        raise Exception("Invalid model specified")
+    return response
 
-        if response.status_code == 200:
-            response_data = response.json()
-            return response_data['message']['content']
-        else:
-            print(f"LLM Request failed with status code {response.status_code}")
-            return None
-    except requests.RequestException as e:
-        print(f"Request exception: {e}")
-        return None
 
 def get_sql_for(user_query, table_info_string, first_few_entries):
     llm_query = f"""
Expand All @@ -53,9 +40,8 @@ def get_sql_for(user_query, table_info_string, first_few_entries ):
"error": null
}}
"""
llm_response = execute_with_ollama(llm_query)
llm_response = execute_with_llm(llm_query)

llm_response = json.loads(llm_response)
print(llm_response["sql"])
return llm_response["sql"]

Expand All @@ -66,9 +52,9 @@ def get_nlp_result_for(user_query, sql_query, db_result):
After running an SQL query of {sql_query}, the result of {db_result} was obtained.
Based on this, in natural language answer the question: {user_query}.
If you cannot answer based on this, directly say so. Don't mention anything about SQL.
Directly answer the user to the point.
Directly answer the user to the point Output in the following json format
{{response: <your response>}}.
"""
nlp_result = execute_with_ollama(llm_query)
nlp_result = execute_with_llm(llm_query)
print(nlp_result)
nlp_result = json.loads(nlp_result)
return nlp_result
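With MODEL left at its "mistral" default, a hedged sketch of how the two helpers chain together end to end; the schema string, the sample rows, and the single-argument run_query call are illustrative stand-ins, not code from this commit:

# Illustrative wiring, mirroring the imports used in app.py.
from src.llm_utils import get_sql_for, get_nlp_result_for
from src.sql_utils import run_query

user_query = "How many rows are in the table?"

# get_sql_for returns just the "sql" field of the model's JSON reply.
sql = get_sql_for(user_query,
                  table_info_string="id INTEGER, name TEXT",    # assumed schema
                  first_few_entries="id name\n1 alice\n2 bob")  # assumed rows
db_result = run_query(sql)  # assumed to take just the SQL string

# get_nlp_result_for returns the parsed {"response": ...} object.
print(get_nlp_result_for(user_query, sql, db_result))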
