-
Notifications
You must be signed in to change notification settings - Fork 43
/
chat.py
133 lines (112 loc) · 4.65 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import openai
import streamlit as st
from audio_recorder_streamlit import audio_recorder
from elevenlabs import generate
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import DeepLake
from streamlit_chat import message
from dotenv import load_dotenv
# Load environment variables from the .env file (API keys, dataset path)
load_dotenv()
# Constants
# Path where browser-recorded audio is staged before transcription
TEMP_AUDIO_PATH = "temp_audio.wav"
# MIME type passed to st.audio for playback of the recording
AUDIO_FORMAT = "audio/wav"
# Read the required credentials/configuration from the environment.
# NOTE(review): os.environ.get returns None when a key is missing — the app
# will fail later at the first API call rather than here; verify .env is set.
openai.api_key = os.environ.get('OPENAI_API_KEY')
eleven_api_key = os.environ.get('ELEVEN_API_KEY')
active_loop_data_set_path = os.environ.get('DEEPLAKE_DATASET_PATH')
# Load embeddings and DeepLake database
def load_embeddings_and_database(active_loop_data_set_path):
    """Open the Deep Lake vector store at the given path, read-only.

    Args:
        active_loop_data_set_path: Deep Lake dataset URI/path to load.

    Returns:
        A read-only ``DeepLake`` store backed by OpenAI embeddings.
    """
    return DeepLake(
        dataset_path=active_loop_data_set_path,
        read_only=True,
        embedding_function=OpenAIEmbeddings(),
    )
# Transcribe audio using OpenAI Whisper API
def transcribe_audio(audio_file_path, openai_key):
    """Transcribe an audio file with OpenAI's Whisper API.

    Args:
        audio_file_path: Path to the audio file to transcribe.
        openai_key: OpenAI API key (assigned globally before the call).

    Returns:
        The transcription text, or ``None`` if the API call fails.
    """
    openai.api_key = openai_key
    try:
        with open(audio_file_path, "rb") as audio_file:
            result = openai.Audio.transcribe("whisper-1", audio_file)
    except Exception as e:
        # Best-effort boundary: report and signal failure instead of crashing.
        print(f"Error calling Whisper API: {str(e)}")
        return None
    return result["text"]
# Record audio using audio_recorder and transcribe using transcribe_audio
def record_and_transcribe_audio():
    """Record audio in the browser and transcribe it on demand.

    The recording is echoed back via ``st.audio`` and staged to
    ``TEMP_AUDIO_PATH``; transcription happens only when the user clicks
    the "Transcribe" button (possibly on a later Streamlit rerun).

    Returns:
        The transcription text, or ``None`` when nothing was transcribed.
    """
    transcription = None
    audio_bytes = audio_recorder()
    if audio_bytes:
        st.audio(audio_bytes, format=AUDIO_FORMAT)
        with open(TEMP_AUDIO_PATH, "wb") as f:
            f.write(audio_bytes)
        if st.button("Transcribe"):
            transcription = transcribe_audio(TEMP_AUDIO_PATH, openai.api_key)
            # The staged file is only needed for the transcription call.
            os.remove(TEMP_AUDIO_PATH)
            display_transcription(transcription)
    return transcription
# Display the transcription of the audio on the app
def display_transcription(transcription):
    """Show the transcription in the app and persist it to a text file.

    Args:
        transcription: Transcribed text, or a falsy value on failure.
    """
    if transcription:
        st.write(f"Transcription: {transcription}")
        # Explicit UTF-8 avoids UnicodeEncodeError for non-ASCII speech on
        # platforms whose default encoding is not UTF-8; plain "w" suffices
        # because the handle is never read back (the original "w+" read
        # capability was unused).
        with open("audio_transcription.txt", "w", encoding="utf-8") as f:
            f.write(transcription)
    else:
        st.write("Error transcribing audio.")
# Get user input from Streamlit text input field
def get_user_input(transcription):
    """Render the chat text box, pre-filled with the transcription if any.

    Args:
        transcription: Text to pre-fill, or a falsy value for an empty box.

    Returns:
        The current contents of the text input widget.
    """
    default_text = transcription or ""
    return st.text_input("", value=default_text, key="input")
# Search the database for a response based on the user's query
def search_db(user_input, db):
    """Answer a query with a RetrievalQA chain over the Deep Lake store.

    Args:
        user_input: The user's question.
        db: A DeepLake vector store to retrieve context from.

    Returns:
        The chain's output dict (includes ``result`` and
        ``source_documents``).
    """
    print(user_input)
    retriever = db.as_retriever()
    # Cosine distance + MMR over a wide candidate pool, keeping 10 docs.
    retriever.search_kwargs.update(
        distance_metric='cos',
        fetch_k=100,
        maximal_marginal_relevance=True,
        k=10,
    )
    llm = ChatOpenAI(model='gpt-3.5-turbo')
    qa_chain = RetrievalQA.from_llm(
        llm, retriever=retriever, return_source_documents=True
    )
    return qa_chain({'query': user_input})
# Display conversation history using Streamlit messages
def display_conversation(history):
    """Render the chat history and speak each bot reply aloud.

    Args:
        history: Mapping with parallel ``"past"`` (user) and
            ``"generated"`` (bot) message lists.
    """
    for turn, bot_text in enumerate(history["generated"]):
        message(history["past"][turn], is_user=True, key=str(turn) + "_user")
        message(bot_text, key=str(turn))
        # Synthesize the reply with the Eleven Labs "Bella" voice and
        # attach an audio player for it.
        audio = generate(text=bot_text, voice="Bella", api_key=eleven_api_key)
        st.audio(audio, format='audio/mp3')
# Main function to run the app
def main():
    """Run the JarvisBase voice-chat app: record, retrieve, answer, speak."""
    st.write("# JarvisBase 🧙")

    # Open the vector store and collect the user's question (typed or spoken).
    db = load_embeddings_and_database(active_loop_data_set_path)
    transcription = record_and_transcribe_audio()
    user_input = get_user_input(transcription)

    # Seed the conversation state on the first run of the session.
    if "generated" not in st.session_state:
        st.session_state["generated"] = ["I am ready to help you"]
    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey there!"]

    # Answer the query and append the exchange to the session history.
    if user_input:
        output = search_db(user_input, db)
        print(output['source_documents'])
        st.session_state.past.append(user_input)
        st.session_state.generated.append(str(output["result"]))

    # Replay the full conversation (history is always non-empty after seeding).
    if st.session_state["generated"]:
        display_conversation(st.session_state)
# Run the main function when the script is executed directly
# (e.g. `streamlit run chat.py`); no side effects on plain import.
if __name__ == "__main__":
    main()