Skip to content

Commit

Permalink
feat: WIP, voice interact with bot
Browse files Browse the repository at this point in the history
  • Loading branch information
madawei2699 committed Mar 22, 2023
1 parent 1523c7b commit f0b1c00
Show file tree
Hide file tree
Showing 6 changed files with 472 additions and 15 deletions.
4 changes: 3 additions & 1 deletion .env_sample
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ export CF_ACCESS_CLIENT_SECRET='your_client_secret'
export PHANTOMJSCLOUD_API_KEY='your_api_key'
export OPENAI_API_KEY='your_api_key'
export SLACK_TOKEN='your_slack_token'
export SLACK_SIGNING_SECRET='your_slack_signing_secret'
export SLACK_SIGNING_SECRET='your_slack_signing_secret'
export SPEECH_KEY='your_azure_speech_key'
export SPEECH_REGION='your_azure_speech_region'
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ For now it is in development, but you can try it out by joining this [channel](http
- Use [Google Vision](https://cloud.google.com/vision/docs/pdf) to handle the PDF reading
- [ ] Image
- may use GPT4
- [ ] Support voice reading ~~with self-hosting [whisper](https://github.com/aarnphm/whispercpp)~~
- (whisper -> chatGPT -> azure text2speech) to play language speaking practices 💥 🚩
- [x] Support voice reading ~~with self-hosting [whisper](https://github.com/aarnphm/whispercpp)~~
- (whisper -> chatGPT -> azure text2speech) to play language speaking practices 💥
- [ ] Integrated with Azure OpenAI Service
- [ ] User access limit
- Limit the number of requests to bot per user per day to save the cost
Expand Down
81 changes: 76 additions & 5 deletions app/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,49 @@
import os
import logging
import hashlib
import uuid
import openai
from llama_index import GPTSimpleVectorIndex, LLMPredictor, RssReader, SimpleDirectoryReader
from llama_index.prompts.prompts import QuestionAnswerPrompt
from llama_index.readers.schema.base import Document
from langchain.chat_models import ChatOpenAI
from azure.cognitiveservices.speech import SpeechConfig, SpeechSynthesizer, ResultReason, CancellationReason, SpeechSynthesisOutputFormat, AudioDataStream
from azure.cognitiveservices.speech.audio import AudioOutputConfig

from app.fetch_web_post import get_urls, scrape_website, scrape_website_by_phantomjscloud

# Credentials are read from the environment so secrets stay out of the repo
# (see .env_sample for the expected variable names).
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
# Azure Cognitive Services speech credentials used for text-to-speech.
SPEECH_KEY = os.environ.get('SPEECH_KEY')
SPEECH_REGION = os.environ.get('SPEECH_REGION')
# Configure the shared openai client once at import time.
openai.api_key = OPENAI_API_KEY

# Shared LLM predictor for all llama_index queries; low temperature keeps
# answers stable. (The diff paste had left a duplicate of this assignment.)
llm_predictor = LLMPredictor(llm=ChatOpenAI(
    temperature=0.2, model_name="gpt-3.5-turbo"))

# Cache locations: web indexes and synthesized voice files are ephemeral
# (/tmp); uploaded files persist under /data.
index_cache_web_dir = '/tmp/myGPTReader/cache_web/'
index_cache_voice_dir = '/tmp/myGPTReader/voice/'
index_cache_file_dir = '/data/myGPTReader/file/'

# exist_ok=True replaces the check-then-create pattern, avoiding the
# race between os.path.exists and os.makedirs.
os.makedirs(index_cache_web_dir, exist_ok=True)
os.makedirs(index_cache_voice_dir, exist_ok=True)
os.makedirs(index_cache_file_dir, exist_ok=True)


def get_unique_md5(urls):
    """Return a deterministic MD5 hex digest for a collection of URLs.

    The URLs are sorted before hashing so the digest is independent of
    the order the caller supplies them in.
    """
    joined = ''.join(sorted(urls))
    return hashlib.md5(joined.encode('utf-8')).hexdigest()


def format_dialog_messages(messages):
    """Join dialog entries into a single newline-separated prompt string."""
    return "\n".join(messages)


def get_documents_from_urls(urls):
documents = []
for url in urls['page_urls']:
Expand All @@ -46,6 +59,7 @@ def get_documents_from_urls(urls):
documents.append(document)
return documents


def get_answer_from_chatGPT(messages):
dialog_messages = format_dialog_messages(messages)
logging.info('=====> Use chatGPT to answer!')
Expand All @@ -57,6 +71,7 @@ def get_answer_from_chatGPT(messages):
logging.info(completion.usage)
return completion.choices[0].message.content


QUESTION_ANSWER_PROMPT_TMPL = (
"Context information is below. \n"
"---------------------\n"
Expand All @@ -66,20 +81,25 @@ def get_answer_from_chatGPT(messages):
)
QUESTION_ANSWER_PROMPT = QuestionAnswerPrompt(QUESTION_ANSWER_PROMPT_TMPL)


def get_index_from_web_cache(name):
    """Load a previously persisted web index from the web cache directory.

    Returns the loaded GPTSimpleVectorIndex, or None when no cache entry
    exists for *name*. (The pasted diff had left a duplicated logging call
    here; only one is kept.)
    """
    path = index_cache_web_dir + name
    if not os.path.exists(path):
        return None
    index = GPTSimpleVectorIndex.load_from_disk(path)
    logging.info(f"=====> Get index from web cache: {path}")
    return index


def get_index_from_file_cache(name):
    """Load a previously persisted file index from the file cache directory.

    Returns the loaded GPTSimpleVectorIndex, or None when no cache entry
    exists for *name*. (The pasted diff had left a duplicated logging call
    here; only one is kept.)
    """
    path = index_cache_file_dir + name
    if not os.path.exists(path):
        return None
    index = GPTSimpleVectorIndex.load_from_disk(path)
    logging.info(f"=====> Get index from file cache: {path}")
    return index


def get_answer_from_llama_web(messages, urls):
dialog_messages = format_dialog_messages(messages)
logging.info('=====> Use llama web with chatGPT to answer!')
Expand All @@ -93,15 +113,18 @@ def get_answer_from_llama_web(messages, urls):
documents = get_documents_from_urls(combained_urls)
logging.info(documents)
index = GPTSimpleVectorIndex(documents)
logging.info(f"=====> Save index to disk path: {index_cache_web_dir + index_file_name}")
logging.info(
f"=====> Save index to disk path: {index_cache_web_dir + index_file_name}")
index.save_to_disk(index_cache_web_dir + index_file_name)
return index.query(dialog_messages, llm_predictor=llm_predictor, text_qa_template=QUESTION_ANSWER_PROMPT)


def get_index_name_from_file(file: str):
    """Map a cached upload path to its vector-index file name (<md5>.json).

    Uploaded files are stored as ``<md5>.<ext>`` under the file cache dir;
    the index for one is persisted next to it as ``<md5>.json``.
    """
    basename = file.replace(index_cache_file_dir, '')
    md5_part, _, _ = basename.partition('.')
    return md5_part + '.json'


def get_answer_from_llama_file(messages, file):
dialog_messages = format_dialog_messages(messages)
logging.info('=====> Use llama file with chatGPT to answer!')
Expand All @@ -112,6 +135,54 @@ def get_answer_from_llama_file(messages, file):
logging.info(f"=====> Build index from file!")
documents = SimpleDirectoryReader(input_files=[file]).load_data()
index = GPTSimpleVectorIndex(documents)
logging.info(f"=====> Save index to disk path: {index_cache_file_dir + index_name}")
logging.info(
f"=====> Save index to disk path: {index_cache_file_dir + index_name}")
index.save_to_disk(index_cache_file_dir + index_name)
return index.query(dialog_messages, llm_predictor=llm_predictor, text_qa_template=QUESTION_ANSWER_PROMPT)


def get_text_from_whisper(voice_file_path):
    """Transcribe the audio file at *voice_file_path* via OpenAI whisper-1."""
    with open(voice_file_path, "rb") as audio_file:
        result = openai.Audio.transcribe("whisper-1", audio_file)
    return result.text

def convert_to_ssml(text):
    """Build an SSML document routing Chinese text to a zh-CN voice and the
    rest to an en-US voice.

    Sentence detection is naive: the text is split on the ASCII full stop,
    and any segment containing an ideographic full stop ('。') is treated as
    Chinese. Fixes over the first draft: empty segments (e.g. the tail after
    a trailing '.') no longer inject a stray '. '; a Chinese segment already
    ending in '。' is not given a second one; and the text is XML-escaped so
    '&', '<' and '>' cannot break the SSML markup.
    """
    from xml.sax.saxutils import escape  # keep the helper dependency-local

    chinese_text = ''
    english_text = ''
    for sentence in text.split('.'):
        sentence = sentence.strip()
        if not sentence:
            continue
        if '。' in sentence:
            chinese_text += sentence if sentence.endswith('。') else sentence + '。'
        else:
            english_text += sentence + '. '

    ssml = '<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="zh-CN">'
    if chinese_text:
        ssml += f'<voice name="zh-CN-XiaoxiaoNeural">{escape(chinese_text)}</voice>'
    if english_text:
        ssml += f'<voice name="en-US-JennyNeural">{escape(english_text)}</voice>'
    ssml += '</speak>'

    return ssml

def get_voice_file_from_text(text):
    """Synthesize *text* to an MP3 via Azure TTS and return the file path.

    NOTE(review): the path is returned even when synthesis was canceled, so
    the file may be empty/missing on failure — confirm callers tolerate that.
    """
    speech_config = SpeechConfig(subscription=SPEECH_KEY, region=SPEECH_REGION)
    speech_config.set_speech_synthesis_output_format(
        SpeechSynthesisOutputFormat.Audio16Khz32KBitRateMonoMp3)
    speech_config.speech_synthesis_language = "zh-CN"

    # Unique file name per request under the voice cache directory.
    file_name = f"{index_cache_voice_dir}{uuid.uuid4()}.mp3"
    audio_config = AudioOutputConfig(filename=file_name)
    synthesizer = SpeechSynthesizer(
        speech_config=speech_config, audio_config=audio_config)

    outcome = synthesizer.speak_ssml_async(convert_to_ssml(text)).get()
    if outcome.reason == ResultReason.SynthesizingAudioCompleted:
        logging.info("Speech synthesized for text [{}], and the audio was saved to [{}]".format(
            text, file_name))
    elif outcome.reason == ResultReason.Canceled:
        details = outcome.cancellation_details
        logging.info("Speech synthesis canceled: {}".format(details.reason))
        if details.reason == CancellationReason.Error:
            logging.error("Error details: {}".format(details.error_details))
    return file_name
30 changes: 24 additions & 6 deletions app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from slack_bolt.adapter.flask import SlackRequestHandler
import concurrent.futures
from app.daily_hot_news import *
from app.gpt import get_answer_from_chatGPT, get_answer_from_llama_file, get_answer_from_llama_web, index_cache_file_dir
from app.gpt import get_answer_from_chatGPT, get_answer_from_llama_file, get_answer_from_llama_web, get_text_from_whisper, get_voice_file_from_text, index_cache_file_dir
from app.slash_command import register_slack_slash_commands
from app.util import md5

Expand Down Expand Up @@ -104,7 +104,8 @@ def extract_urls_from_event(event):

whitelist_file = "app/data//vip_whitelist.txt"

# Upload extensions the bot accepts; voice formats are transcribed with
# whisper before answering. (The diff paste had left a duplicate of the
# first assignment without 'm4a'; only the current one is kept.)
filetype_extension_allowed = ['epub', 'pdf', 'text', 'docx', 'markdown', 'm4a']
filetype_voice_extension_allowed = ['m4a']

def is_authorized(user_id: str) -> bool:
with open(whitelist_file, "r") as f:
def dialog_context_keep_latest(dialog_texts, max_length=1):
    """Trim a dialog history to its most recent *max_length* entries."""
    return dialog_texts[-max_length:]

def format_dialog_text(text, voicemessage=None):
    """Strip the bot's mention tag from *text* and append the voice
    transcript on a new line when one is present."""
    cleaned = insert_space(text.replace("<@U04TCNR9MNF>", ""))
    suffix = '\n' + voicemessage if voicemessage else ''
    return cleaned + suffix

@slack_app.event("app_mention")
def handle_mentions(event, say, logger):
logger.info(event)
Expand All @@ -123,6 +127,7 @@ def handle_mentions(event, say, logger):
thread_ts = event["ts"]

file_md5_name = None
voicemessage = None

if event.get('files'):
if not is_authorized(event['user']):
Expand All @@ -149,16 +154,19 @@ def handle_mentions(event, say, logger):
if not os.path.exists(file_md5_name):
logger.info(f'=====> Rename file to {file_md5_name}')
os.rename(temp_file_filename, file_md5_name)
if filetype in filetype_voice_extension_allowed:
voicemessage = get_text_from_whisper(file_md5_name)

parent_thread_ts = event["thread_ts"] if "thread_ts" in event else thread_ts
if parent_thread_ts not in thread_message_history:
thread_message_history[parent_thread_ts] = { 'dialog_texts': [], 'context_urls': set(), 'file': None}

if "text" in event:
update_thread_history(parent_thread_ts, 'User: %s' % insert_space(event["text"].replace('<@U04TCNR9MNF>', '')), extract_urls_from_event(event))
update_thread_history(parent_thread_ts, f'User: {format_dialog_text(event["text"], voicemessage)}', extract_urls_from_event(event))

if file_md5_name is not None:
update_thread_history(parent_thread_ts, None, None, file_md5_name)
if not voicemessage:
update_thread_history(parent_thread_ts, None, None, file_md5_name)

urls = thread_message_history[parent_thread_ts]['context_urls']
file = thread_message_history[parent_thread_ts]['file']
Expand All @@ -170,7 +178,7 @@ def handle_mentions(event, say, logger):
# if it can get the context_str, then put this prompt into the thread_message_history to provide more context to the chatGPT
if file is not None:
future = executor.submit(get_answer_from_llama_file, dialog_context_keep_latest(thread_message_history[parent_thread_ts]['dialog_texts']), file)
elif len(urls) > 0: # if this conversation has urls, use llama with all urls in this thread
elif len(urls) > 0:
future = executor.submit(get_answer_from_llama_web, thread_message_history[parent_thread_ts]['dialog_texts'], list(urls))
else:
future = executor.submit(get_answer_from_chatGPT, thread_message_history[parent_thread_ts]['dialog_texts'])
Expand All @@ -179,7 +187,17 @@ def handle_mentions(event, say, logger):
gpt_response = future.result(timeout=300)
update_thread_history(parent_thread_ts, 'AI: %s' % insert_space(f'{gpt_response}'))
logger.info(gpt_response)
say(f'<@{user}>, {gpt_response}', thread_ts=thread_ts)
if voicemessage is None:
say(f'<@{user}>, {gpt_response}', thread_ts=thread_ts)
else:
voice_file_path = get_voice_file_from_text(gpt_response)
with open(voice_file_path, 'rb') as f:
contents = f.read()
say(f'<@{user}>', thread_ts=thread_ts, files=[{
"filename": f"gpt_response_{thread_ts}.mp3",
"content": contents.decode('utf-8'),
"type": "audio/mpeg"
}])
except concurrent.futures.TimeoutError:
future.cancel()
err_msg = 'Task timedout(5m) and was canceled.'
Expand Down
Loading

0 comments on commit f0b1c00

Please sign in to comment.