Skip to content

Commit

Permalink
Merge branch 'Josh-XT:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
CRCODE22 authored Jan 24, 2024
2 parents 7186c25 + 5ed0e4f commit 42a2b32
Showing 1 changed file with 35 additions and 122 deletions.
157 changes: 35 additions & 122 deletions agixt/extensions/voice_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,8 @@ def __init__(self, WHISPER_MODEL="base.en", **kwargs):
self.tts_command = "Speak with TTS with Alltalk Text to Speech"

self.commands = {
"Chat with Voice": self.chat_with_voice,
"Prompt with Voice": self.prompt_with_voice,
"Command with Voice": self.command_with_voice,
"Transcribe WAV Audio": self.transcribe_wav_audio,
"Transcribe M4A Audio": self.transcribe_m4a_audio,
"Transcribe WEBM Audio": self.transcribe_webm_audio,
"Translate Text to Speech": self.text_to_speech,
}
self.conversation_name = f"Voice Chat with {self.agent_name}"
Expand Down Expand Up @@ -104,32 +101,6 @@ def __init__(self, WHISPER_MODEL="base.en", **kwargs):
)
open(self.model_path, "wb").write(r.content)

async def convert_m4a_to_wav(
self, base64_audio: str, filename: str = "recording.wav"
):
# Convert the base64 audio to a 16k WAV format
audio_data = base64.b64decode(base64_audio)
audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="m4a")
audio_segment = audio_segment.set_frame_rate(16000)
file_path = os.path.join(os.getcwd(), "WORKSPACE", filename)
audio_segment.export(file_path, format="wav")
with open(file_path, "rb") as f:
audio = f.read()
return f"{base64.b64encode(audio).decode('utf-8')}"

async def convert_webm_to_wav(
self, base64_audio: str, filename: str = "recording.wav"
):
# Convert the base64 audio to a 16k WAV format
audio_data = base64.b64decode(base64_audio)
audio_segment = AudioSegment.from_file(io.BytesIO(audio_data), format="webm")
audio_segment = audio_segment.set_frame_rate(16000)
file_path = os.path.join(os.getcwd(), "WORKSPACE", filename)
audio_segment.export(file_path, format="wav")
with open(file_path, "rb") as f:
audio = f.read()
return f"{base64.b64encode(audio).decode('utf-8')}"

async def transcribe_audio_from_file(self, filename: str = "recording.wav"):
w = Whisper(model_path=self.model_path)
file_path = os.path.join(os.getcwd(), "WORKSPACE", filename)
Expand All @@ -138,35 +109,6 @@ async def transcribe_audio_from_file(self, filename: str = "recording.wav"):
w.transcribe(file_path)
return w.output()

async def transcribe_wav_audio(
self,
base64_audio: str,
):
filename = f"{uuid.uuid4().hex}.wav"
# Write the base64 audio to a file.
with open(os.path.join(os.getcwd(), "WORKSPACE", filename), "wb") as f:
f.write(base64.b64decode(base64_audio))
# Transcribe the audio to text.
user_input = await self.transcribe_audio_from_file(filename=filename)
user_input.replace("[BLANK_AUDIO]", "")
os.remove(os.path.join(os.getcwd(), "WORKSPACE", filename))
return user_input

async def transcribe_m4a_audio(
self,
base64_audio: str,
):
# Convert from M4A to WAV
filename = f"{uuid.uuid4().hex}.wav"
user_audio = await self.convert_m4a_to_wav(
base64_audio=base64_audio, filename=filename
)
# Transcribe the audio to text.
user_input = await self.transcribe_audio_from_file(filename=filename)
user_input.replace("[BLANK_AUDIO]", "")
os.remove(os.path.join(os.getcwd(), "WORKSPACE", filename))
return user_input

async def text_to_speech(self, text: str):
# Get the audio response from the TTS engine and return it.
audio_response = self.ApiClient.execute_command(
Expand All @@ -176,56 +118,23 @@ async def text_to_speech(self, text: str):
)
return f"{audio_response}"

async def transcribe_webm_audio(
self,
base64_audio: str,
):
# Convert from WEBM to WAV
filename = f"{uuid.uuid4().hex}.wav"
user_audio = await self.convert_webm_to_wav(
base64_audio=base64_audio, filename=filename
)
# Transcribe the audio to text.
user_input = await self.transcribe_audio_from_file(filename=filename)
user_input.replace("[BLANK_AUDIO]", "")
os.remove(os.path.join(os.getcwd(), "WORKSPACE", filename))
return user_input

async def get_wav_audio(
self,
base64_audio,
audio_format="m4a",
):
async def get_user_input(self, base64_audio, audio_format="m4a"):
filename = f"{uuid.uuid4().hex}.wav"
if audio_format.lower() == "webm":
user_audio = await self.convert_webm_to_wav(
base64_audio=base64_audio, filename=filename
)
elif audio_format.lower() == "m4a":
user_audio = await self.convert_m4a_to_wav(
base64_audio=base64_audio, filename=filename
if audio_format.lower() != "wav":
audio_data = base64.b64decode(base64_audio)
audio_segment = AudioSegment.from_file(
io.BytesIO(audio_data), format=audio_format.lower()
)
audio_segment = audio_segment.set_frame_rate(16000)
file_path = os.path.join(os.getcwd(), "WORKSPACE", filename)
audio_segment.export(file_path, format="wav")
with open(file_path, "rb") as f:
audio = f.read()
return f"{base64.b64encode(audio).decode('utf-8')}"
else:
user_audio = base64_audio
return user_audio

async def chat_with_voice(
self,
base64_audio,
audio_format="m4a",
prompt_name="Custom Input",
prompt_args={
"context_results": 6,
"inject_memories_from_collection_number": 0,
},
):
filename = f"{uuid.uuid4().hex}.wav"
user_audio = await self.get_wav_audio(
base64_audio=base64_audio, audio_format=audio_format, filename=filename
)
# Transcribe the audio to text.
user_input = await self.transcribe_audio_from_file(filename=filename)
prompt_args["user_input"] = user_input
user_message = f"{user_input}\n#GENERATED_AUDIO:{user_audio}"
log_interaction(
agent_name=self.agent_name,
Expand All @@ -235,14 +144,32 @@ async def chat_with_voice(
user=self.user,
)
logging.info(f"[Whisper]: Transcribed User Input: {user_input}")
# Send the transcribed text to the agent.
return user_input

async def prompt_with_voice(
self,
base64_audio,
audio_format="m4a",
audio_variable="user_input",
prompt_name="Custom Input",
prompt_args={
"context_results": 6,
"inject_memories_from_collection_number": 0,
},
tts=False,
):
user_input = await self.get_user_input(
base64_audio=base64_audio, audio_format=audio_format
)
prompt_args[audio_variable] = user_input
text_response = self.ApiClient.prompt_agent(
agent_name=self.agent_name,
prompt_name=prompt_name,
prompt_args=prompt_args,
)
logging.info(f"[Whisper]: Text Response from LLM: {text_response}")
return self.text_to_speech(text=text_response)
if str(tts).lower() == "true":
return self.text_to_speech(text=text_response)
return f"{text_response}"

async def command_with_voice(
self,
Expand All @@ -253,30 +180,16 @@ async def command_with_voice(
command_args={"input": "Voice transcription from user"},
tts=False,
):
filename = f"{uuid.uuid4().hex}.wav"
user_audio = await self.get_wav_audio(
base64_audio=base64_audio, audio_format=audio_format, filename=filename
user_input = await self.get_user_input(
base64_audio=base64_audio, audio_format=audio_format
)
# Transcribe the audio to text.
user_input = await self.transcribe_audio_from_file(filename=filename)
command_args[audio_variable] = user_input
user_message = f"{user_input}\n#GENERATED_AUDIO:{user_audio}"
log_interaction(
agent_name=self.agent_name,
conversation_name=self.conversation_name,
role="USER",
message=user_message,
user=self.user,
)
logging.info(f"[Whisper]: Transcribed User Input: {user_input}")
# Send the transcribed text to the agent.
text_response = self.ApiClient.execute_command(
agent_name=self.agent_name,
command_name=command_name,
command_args=command_args,
conversation_name="AGiXT Terminal",
)
logging.info(f"[Whisper]: Text Response from LLM: {text_response}")
if str(tts).lower() == "true":
return self.text_to_speech(text=text_response)
return f"{text_response}"

0 comments on commit 42a2b32

Please sign in to comment.