diff --git a/pyproject.toml b/pyproject.toml index 63a50fac7..42adf2099 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "tzdata == 2023.3", "rapidocr-onnxruntime == 1.3.8", "stripe == 7.3.0", + "openai-whisper >= 20231117", ] dynamic = ["version"] diff --git a/src/interface/desktop/assets/icons/microphone-solid.svg b/src/interface/desktop/assets/icons/microphone-solid.svg new file mode 100644 index 000000000..3fc4b91d2 --- /dev/null +++ b/src/interface/desktop/assets/icons/microphone-solid.svg @@ -0,0 +1 @@ + diff --git a/src/interface/desktop/assets/icons/stop-solid.svg b/src/interface/desktop/assets/icons/stop-solid.svg new file mode 100644 index 000000000..a9aaba284 --- /dev/null +++ b/src/interface/desktop/assets/icons/stop-solid.svg @@ -0,0 +1,37 @@ + + + + + + + diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index a436b5934..ae17c6372 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -516,6 +516,18 @@ } } + function flashStatusInChatInput(message) { + // Get chat input element and original placeholder + let chatInput = document.getElementById("chat-input"); + let originalPlaceholder = chatInput.placeholder; + // Set placeholder to message + chatInput.placeholder = message; + // Reset placeholder after 2 seconds + setTimeout(() => { + chatInput.placeholder = originalPlaceholder; + }, 2000); + } + async function clearConversationHistory() { let chatInput = document.getElementById("chat-input"); let originalPlaceholder = chatInput.placeholder; @@ -530,17 +542,71 @@ .then(data => { chatBody.innerHTML = ""; loadChat(); - chatInput.placeholder = "Cleared conversation history"; + flashStatusInChatInput("🗑 Cleared conversation history"); }) .catch(err => { - chatInput.placeholder = "Failed to clear conversation history"; + flashStatusInChatInput("⛔️ Failed to clear conversation history"); }) - .finally(() => { - setTimeout(() => { - chatInput.placeholder = originalPlaceholder; - }, 2000); + } + + let mediaRecorder; + async function speechToText() { + const speakButtonImg = document.getElementById('speak-button-img'); + const chatInput = document.getElementById('chat-input'); + + const hostURL = await window.hostURLAPI.getURL(); + let url = `${hostURL}/api/transcribe?client=desktop`; + const khojToken = await window.tokenAPI.getToken(); + const headers = { 'Authorization': `Bearer ${khojToken}` }; + + const sendToServer = (audioBlob) => { + const formData = new FormData(); + formData.append('file', audioBlob); + + fetch(url, { method: 'POST', body: formData, headers}) + .then(response => response.ok ? response.json() : Promise.reject(response)) + .then(data => { chatInput.value += data.text; }) + .catch(err => { + err.status == 422 + ? 
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") + : flashStatusInChatInput("⛔️ Failed to transcribe audio") + }); + }; + + const handleRecording = (stream) => { + const audioChunks = []; + const recordingConfig = { mimeType: 'audio/webm' }; + mediaRecorder = new MediaRecorder(stream, recordingConfig); + + mediaRecorder.addEventListener("dataavailable", function(event) { + if (event.data.size > 0) audioChunks.push(event.data); + }); + + mediaRecorder.addEventListener("stop", function() { + const audioBlob = new Blob(audioChunks, { type: 'audio/webm' }); + sendToServer(audioBlob); }); + + mediaRecorder.start(); + speakButtonImg.src = './assets/icons/stop-solid.svg'; + speakButtonImg.alt = 'Stop Transcription'; + }; + + // Toggle recording + if (!mediaRecorder || mediaRecorder.state === 'inactive') { + navigator.mediaDevices + .getUserMedia({ audio: true }) + .then(handleRecording) + .catch((e) => { + flashStatusInChatInput("⛔️ Failed to access microphone"); + }); + } else if (mediaRecorder.state === 'recording') { + mediaRecorder.stop(); + speakButtonImg.src = './assets/icons/microphone-solid.svg'; + speakButtonImg.alt = 'Transcribe'; + } } +
@@ -569,8 +635,11 @@
- +
@@ -620,7 +689,6 @@ .chat-message.you { margin-right: auto; text-align: right; - white-space: pre-line; } /* basic style chat message text */ .chat-message-text { @@ -637,7 +705,6 @@ color: var(--primary-inverse); background: var(--primary); margin-left: auto; - white-space: pre-line; } /* Spinner symbol when the chat message is loading */ .spinner { @@ -694,7 +761,7 @@ } #input-row { display: grid; - grid-template-columns: auto 32px; + grid-template-columns: auto 32px 32px; grid-column-gap: 10px; grid-row-gap: 10px; background: #f9fafc diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index fc6d5a488..16c5614fa 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -1,4 +1,4 @@ -import { App, Modal, request, setIcon } from 'obsidian'; +import { App, Modal, RequestUrlParam, request, requestUrl, setIcon } from 'obsidian'; import { KhojSetting } from 'src/settings'; import fetch from "node-fetch"; @@ -51,6 +51,16 @@ export class KhojChatModal extends Modal { }) chatInput.addEventListener('change', (event) => { this.result = (event.target).value }); + let transcribe = inputRow.createEl("button", { + text: "Transcribe", + attr: { + id: "khoj-transcribe", + class: "khoj-transcribe khoj-input-row-button", + }, + }) + transcribe.addEventListener('click', async (_) => { await this.speechToText() }); + setIcon(transcribe, "mic"); + let clearChat = inputRow.createEl("button", { text: "Clear History", attr: { @@ -205,9 +215,19 @@ export class KhojChatModal extends Modal { } } - async clearConversationHistory() { + flashStatusInChatInput(message: string) { + // Get chat input element and original placeholder let chatInput = this.contentEl.getElementsByClassName("khoj-chat-input")[0]; let originalPlaceholder = chatInput.placeholder; + // Set placeholder to message + chatInput.placeholder = message; + // Reset placeholder after 2 seconds + setTimeout(() => { + chatInput.placeholder = originalPlaceholder; + }, 2000); + } + + async clearConversationHistory() { let chatBody = this.contentEl.getElementsByClassName("khoj-chat-body")[0]; let response = await request({ @@ -224,15 +244,84 @@ export class KhojChatModal extends Modal { // If conversation history is cleared successfully, clear chat logs from modal chatBody.innerHTML = ""; await this.getChatHistory(); - chatInput.placeholder = result.message; + this.flashStatusInChatInput(result.message); } } catch (err) { - chatInput.placeholder = "Failed to clear conversation history"; - } finally { - // Reset to original placeholder text after some time - setTimeout(() => { - chatInput.placeholder = originalPlaceholder; - }, 2000); + this.flashStatusInChatInput("Failed to clear conversation history"); + } + } + + mediaRecorder: MediaRecorder | undefined; + async speechToText() { + const transcribeButton = this.contentEl.getElementsByClassName("khoj-transcribe")[0]; + const chatInput = this.contentEl.getElementsByClassName("khoj-chat-input")[0]; + + const generateRequestBody = async (audioBlob: Blob, boundary_string: string) => { + const boundary = `------${boundary_string}`; + const chunks: ArrayBuffer[] = []; + + chunks.push(new TextEncoder().encode(`${boundary}\r\n`)); + chunks.push(new TextEncoder().encode(`Content-Disposition: form-data; name="file"; filename="blob"\r\nContent-Type: "application/octet-stream"\r\n\r\n`)); + chunks.push(await audioBlob.arrayBuffer()); + chunks.push(new TextEncoder().encode('\r\n')); + + await Promise.all(chunks); + chunks.push(new 
TextEncoder().encode(`${boundary}--\r\n`)); + return await new Blob(chunks).arrayBuffer(); + }; + + const sendToServer = async (audioBlob: Blob) => { + const boundary_string = `Boundary${Math.random().toString(36).slice(2)}`; + const requestBody = await generateRequestBody(audioBlob, boundary_string); + + const response = await requestUrl({ + url: `${this.setting.khojUrl}/api/transcribe?client=obsidian`, + method: 'POST', + headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` }, + contentType: `multipart/form-data; boundary=----${boundary_string}`, + body: requestBody, + }); + + // Parse response from Khoj backend + if (response.status === 200) { + console.log(response); + chatInput.value += response.json.text; + } else if (response.status === 422) { + throw new Error("⛔️ Configure speech-to-text model on server."); + } else { + throw new Error("⛔️ Failed to transcribe audio"); + } + }; + + const handleRecording = (stream: MediaStream) => { + const audioChunks: Blob[] = []; + const recordingConfig = { mimeType: 'audio/webm' }; + this.mediaRecorder = new MediaRecorder(stream, recordingConfig); + + this.mediaRecorder.addEventListener("dataavailable", function(event) { + if (event.data.size > 0) audioChunks.push(event.data); + }); + + this.mediaRecorder.addEventListener("stop", async function() { + const audioBlob = new Blob(audioChunks, { type: 'audio/webm' }); + await sendToServer(audioBlob); + }); + + this.mediaRecorder.start(); + setIcon(transcribeButton, "mic-off"); + }; + + // Toggle recording + if (!this.mediaRecorder || this.mediaRecorder.state === 'inactive') { + navigator.mediaDevices + .getUserMedia({ audio: true }) + .then(handleRecording) + .catch((e) => { + this.flashStatusInChatInput("⛔️ Failed to access microphone"); + }); + } else if (this.mediaRecorder.state === 'recording') { + this.mediaRecorder.stop(); + setIcon(transcribeButton, "mic"); } } } diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css index 95a304f1b..ff2dee8a8 100644 --- a/src/interface/obsidian/styles.css +++ b/src/interface/obsidian/styles.css @@ -112,7 +112,7 @@ If your plugin does not need CSS, delete this file.
} .khoj-input-row { display: grid; - grid-template-columns: auto 32px; + grid-template-columns: auto 32px 32px; grid-column-gap: 10px; grid-row-gap: 10px; background: var(--background-primary); diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index de5f1b5d4..eb143ab66 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -26,6 +26,7 @@ OfflineChatProcessorConversationConfig, OpenAIProcessorConversationConfig, SearchModelConfig, + SpeechToTextModelOptions, Subscription, UserConversationConfig, OpenAIProcessorConversationConfig, @@ -344,6 +345,10 @@ async def get_openai_chat(): async def get_openai_chat_config(): return await OpenAIProcessorConversationConfig.objects.filter().afirst() + @staticmethod + async def get_speech_to_text_config(): + return await SpeechToTextModelOptions.objects.filter().afirst() + @staticmethod async def aget_conversation_starters(user: KhojUser): all_questions = [] diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index e1095eced..2213fb6ef 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -9,6 +9,7 @@ OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, SearchModelConfig, + SpeechToTextModelOptions, Subscription, ReflectiveQuestion, ) @@ -16,6 +17,7 @@ admin.site.register(KhojUser, UserAdmin) admin.site.register(ChatModelOptions) +admin.site.register(SpeechToTextModelOptions) admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(SearchModelConfig) diff --git a/src/khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py b/src/khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py new file mode 100644 index 000000000..373377915 --- /dev/null +++ b/src/khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py @@ -0,0 +1,42 @@ +# Generated by Django 4.2.7 on 2023-11-26 13:54 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0020_reflectivequestion"), + ] + + operations = [ + migrations.CreateModel( + name="SpeechToTextModelOptions", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("model_name", models.CharField(default="base", max_length=200)), + ( + "model_type", + models.CharField( + choices=[("openai", "Openai"), ("offline", "Offline")], default="offline", max_length=200 + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.AlterField( + model_name="chatmodeloptions", + name="chat_model", + field=models.CharField(default="mistral-7b-instruct-v0.1.Q4_0.gguf", max_length=200), + ), + migrations.AlterField( + model_name="chatmodeloptions", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("offline", "Offline")], default="offline", max_length=200 + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index b0463df82..82348fbe6 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -120,6 +120,15 @@ class OfflineChatProcessorConversationConfig(BaseModel): enabled = models.BooleanField(default=False) +class SpeechToTextModelOptions(BaseModel): + class ModelType(models.TextChoices): + OPENAI = "openai" + OFFLINE = 
"offline" + + model_name = models.CharField(max_length=200, default="base") + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + + class ChatModelOptions(BaseModel): class ModelType(models.TextChoices): OPENAI = "openai" @@ -127,8 +136,8 @@ class ModelType(models.TextChoices): max_prompt_size = models.IntegerField(default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) - chat_model = models.CharField(max_length=200, default=None, null=True, blank=True) - model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf") + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) class UserConversationConfig(BaseModel): diff --git a/src/khoj/interface/web/assets/icons/microphone-solid.svg b/src/khoj/interface/web/assets/icons/microphone-solid.svg new file mode 100644 index 000000000..3fc4b91d2 --- /dev/null +++ b/src/khoj/interface/web/assets/icons/microphone-solid.svg @@ -0,0 +1 @@ + diff --git a/src/khoj/interface/web/assets/icons/stop-solid.svg b/src/khoj/interface/web/assets/icons/stop-solid.svg new file mode 100644 index 000000000..a9aaba284 --- /dev/null +++ b/src/khoj/interface/web/assets/icons/stop-solid.svg @@ -0,0 +1,37 @@ + + + + + + + diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 8d2accc41..256193a78 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -543,6 +543,18 @@ } } + function flashStatusInChatInput(message) { + // Get chat input element and original placeholder + let chatInput = document.getElementById("chat-input"); + let originalPlaceholder = chatInput.placeholder; + // Set placeholder to message + chatInput.placeholder = message; + // Reset placeholder after 2 seconds + setTimeout(() => { + chatInput.placeholder = originalPlaceholder; + }, 2000); + } + function clearConversationHistory() { let chatInput = document.getElementById("chat-input"); let originalPlaceholder = chatInput.placeholder; @@ -553,17 +565,65 @@ .then(data => { chatBody.innerHTML = ""; loadChat(); - chatInput.placeholder = "Cleared conversation history"; + flashStatusInChatInput("🗑 Cleared conversation history"); }) .catch(err => { - chatInput.placeholder = "Failed to clear conversation history"; - }) - .finally(() => { - setTimeout(() => { - chatInput.placeholder = originalPlaceholder; - }, 2000); + flashStatusInChatInput("⛔️ Failed to clear conversation history"); }); } + + let mediaRecorder; + function speechToText() { + const speakButtonImg = document.getElementById('speak-button-img'); + const chatInput = document.getElementById('chat-input'); + + const sendToServer = (audioBlob) => { + const formData = new FormData(); + formData.append('file', audioBlob); + + fetch('/api/transcribe?client=web', { method: 'POST', body: formData }) + .then(response => response.ok ? response.json() : Promise.reject(response)) + .then(data => { chatInput.value += data.text; }) + .catch(err => { + err.status == 422 + ? 
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") + : flashStatusInChatInput("⛔️ Failed to transcribe audio") + }); + }; + + const handleRecording = (stream) => { + const audioChunks = []; + const recordingConfig = { mimeType: 'audio/webm' }; + mediaRecorder = new MediaRecorder(stream, recordingConfig); + + mediaRecorder.addEventListener("dataavailable", function(event) { + if (event.data.size > 0) audioChunks.push(event.data); + }); + + mediaRecorder.addEventListener("stop", function() { + const audioBlob = new Blob(audioChunks, { type: 'audio/webm' }); + sendToServer(audioBlob); + }); + + mediaRecorder.start(); + speakButtonImg.src = '/static/assets/icons/stop-solid.svg'; + speakButtonImg.alt = 'Stop Transcription'; + }; + + // Toggle recording + if (!mediaRecorder || mediaRecorder.state === 'inactive') { + navigator.mediaDevices + .getUserMedia({ audio: true }) + .then(handleRecording) + .catch((e) => { + flashStatusInChatInput("⛔️ Failed to access microphone"); + }); + } else if (mediaRecorder.state === 'recording') { + mediaRecorder.stop(); + speakButtonImg.src = '/static/assets/icons/microphone-solid.svg'; + speakButtonImg.alt = 'Transcribe'; + } + }
@@ -584,8 +644,11 @@
+
@@ -749,7 +812,6 @@ .chat-message.you { margin-right: auto; text-align: right; - white-space: pre-line; } /* basic style chat message text */ .chat-message-text { @@ -766,7 +828,6 @@ color: var(--primary-inverse); background: var(--primary); margin-left: auto; - white-space: pre-line; } /* Spinner symbol when the chat message is loading */ .spinner { @@ -815,6 +876,7 @@ #chat-footer { padding: 0; + margin: 8px; display: grid; grid-template-columns: minmax(70px, 100%); grid-column-gap: 10px; @@ -822,7 +884,7 @@ } #input-row { display: grid; - grid-template-columns: auto 32px; + grid-template-columns: auto 32px 32px; grid-column-gap: 10px; grid-row-gap: 10px; background: #f9fafc diff --git a/src/khoj/processor/conversation/gpt4all/__init__.py b/src/khoj/processor/conversation/offline/__init__.py similarity index 100% rename from src/khoj/processor/conversation/gpt4all/__init__.py rename to src/khoj/processor/conversation/offline/__init__.py diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py similarity index 100% rename from src/khoj/processor/conversation/gpt4all/chat_model.py rename to src/khoj/processor/conversation/offline/chat_model.py diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/offline/utils.py similarity index 100% rename from src/khoj/processor/conversation/gpt4all/utils.py rename to src/khoj/processor/conversation/offline/utils.py diff --git a/src/khoj/processor/conversation/offline/whisper.py b/src/khoj/processor/conversation/offline/whisper.py new file mode 100644 index 000000000..56d2aaf5c --- /dev/null +++ b/src/khoj/processor/conversation/offline/whisper.py @@ -0,0 +1,17 @@ +# External Packages +from asgiref.sync import sync_to_async +import whisper + +# Internal Packages +from khoj.utils import state + + +async def transcribe_audio_offline(audio_filename: str, model: str) -> str: + """ + Transcribe audio file offline using Whisper + """ + # Send the audio data to the Whisper API + if not state.whisper_model: + state.whisper_model = whisper.load_model(model) + response = await sync_to_async(state.whisper_model.transcribe)(audio_filename) + return response["text"] diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py new file mode 100644 index 000000000..72834d921 --- /dev/null +++ b/src/khoj/processor/conversation/openai/whisper.py @@ -0,0 +1,15 @@ +# Standard Packages +from io import BufferedReader + +# External Packages +from asgiref.sync import sync_to_async +import openai + + +async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str: + """ + Transcribe audio file using Whisper model via OpenAI's API + """ + # Send the audio data to the Whisper API + response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key) + return response["text"] diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ef822687f..3fd2285d8 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -3,13 +3,14 @@ import json import logging import math +import os import time from typing import Any, Dict, List, Optional, Union - -from asgiref.sync import sync_to_async +import uuid # External Packages -from fastapi import APIRouter, Depends, Header, HTTPException, Request +from asgiref.sync import sync_to_async +from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile from fastapi.requests import Request from 
fastapi.responses import Response, StreamingResponse from starlette.authentication import requires @@ -29,8 +30,10 @@ LocalPlaintextConfig, NotionConfig, ) -from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline +from khoj.processor.conversation.offline.chat_model import extract_questions_offline +from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.gpt import extract_questions +from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.tools.online_search import search_with_google from khoj.routers.helpers import ( @@ -585,6 +588,59 @@ return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200) +@api.post("/transcribe") +@requires(["authenticated"]) +async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)): + user: KhojUser = request.user.object + audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm" + user_message: Optional[str] = None + status_code = 500 + + # If the file is too large, return an unprocessable entity error + if file.size > 10 * 1024 * 1024: + logger.warning(f"Audio file too large to transcribe. Audio file size: {file.size}. Exceeds 10Mb limit.") + return Response(content="Audio size larger than 10Mb limit", status_code=422) + + # Transcribe the audio from the request + try: + # Store the audio from the request in a temporary file + audio_data = await file.read() + with open(audio_filename, "wb") as audio_file_writer: + audio_file_writer.write(audio_data) + audio_file = open(audio_filename, "rb") + + # Transcribe the audio using the configured speech to text model + speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() + openai_chat_config = await ConversationAdapters.get_openai_chat_config() + if not speech_to_text_config: + # If the user has not configured a speech to text model, return an unprocessable entity error + status_code = 422 + elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: + api_key = openai_chat_config.api_key + speech2text_model = speech_to_text_config.model_name + user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key) + elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE: + speech2text_model = speech_to_text_config.model_name + user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model) + finally: + # Close and Delete the temporary audio file + audio_file.close() + os.remove(audio_filename) + + if user_message is None: + return Response(status_code=status_code) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="transcribe", + **common.__dict__, + ) + + # Return the spoken text + content = json.dumps({"text": user_message}) + return Response(content=content, media_type="application/json", status_code=200) + + @api.get("/chat", response_class=Response) @requires(["authenticated"]) async def chat( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c6fcb4364..e1ab05b51 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -15,7 +15,7 @@ from khoj.database.adapters import ConversationAdapters from khoj.database.models import KhojUser, Subscription from khoj.processor.conversation import prompts -from khoj.processor.conversation.gpt4all.chat_model import
converse_offline, send_message_to_model_offline +from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 7795d695d..abda12b6f 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -11,7 +11,7 @@ import torch # Internal Packages -from khoj.processor.conversation.gpt4all.utils import download_model +from khoj.processor.conversation.offline.utils import download_model logger = logging.getLogger(__name__) @@ -80,7 +80,7 @@ class GPT4AllProcessorConfig: class GPT4AllProcessorModel: def __init__( self, - chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin", + chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", ): self.chat_model = chat_model self.loaded_model = None diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index ffc4d47eb..313b18fcd 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -6,6 +6,7 @@ OfflineChatProcessorConversationConfig, OpenAIProcessorConversationConfig, ChatModelOptions, + SpeechToTextModelOptions, ) from khoj.utils.constants import default_offline_chat_model, default_online_chat_model @@ -73,10 +74,9 @@ def _create_chat_configuration(): except ModuleNotFoundError as e: logger.warning("Offline models are not supported on this device.") - use_openai_model = input("Use OpenAI chat model? (y/n): ") - + use_openai_model = input("Use OpenAI models? (y/n): ") if use_openai_model == "y": - logger.info("🗣️ Setting up OpenAI chat model") + logger.info("🗣️ Setting up your OpenAI configuration") api_key = input("Enter your OpenAI API key: ") OpenAIProcessorConversationConfig.objects.create(api_key=api_key) @@ -94,7 +94,34 @@ def _create_chat_configuration(): chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_prompt_size=max_tokens ) - logger.info("🗣️ Chat model configuration complete") + default_speech2text_model = "whisper-1" + openai_speech2text_model = input( + f"Enter the OpenAI speech to text model you want to use (default: {default_speech2text_model}): " + ) + openai_speech2text_model = openai_speech2text_model or default_speech2text_model + SpeechToTextModelOptions.objects.create( + model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI + ) + + if use_offline_model == "y" or use_openai_model == "y": + logger.info("🗣️ Chat model configuration complete") + + use_offline_speech2text_model = input("Use offline speech to text model? (y/n): ") + if use_offline_speech2text_model == "y": + logger.info("🗣️ Setting up offline speech to text model") + # Delete any existing speech to text model options. There can only be one. 
+ SpeechToTextModelOptions.objects.all().delete() + + default_offline_speech2text_model = "base" + offline_speech2text_model = input( + f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): " + ) + offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model + SpeechToTextModelOptions.objects.create( + model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE + ) + + logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}") admin_user = KhojUser.objects.filter(is_staff=True).first() if admin_user is None: diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 91f5f0cee..b54cf4b39 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -7,6 +7,7 @@ # External Packages from pathlib import Path from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel +from whisper import Whisper # Internal Packages from khoj.utils import config as utils_config @@ -21,6 +22,7 @@ cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() gpt4all_processor_config: GPT4AllProcessorModel = None +whisper_model: Whisper = None config_file: Path = None verbose: int = 0 host: str = None diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py index 782b54f20..7b59e1e3f 100644 --- a/tests/test_gpt4all_chat_actors.py +++ b/tests/test_gpt4all_chat_actors.py @@ -19,8 +19,8 @@ print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") # Internal Packages -from khoj.processor.conversation.gpt4all.chat_model import converse_offline, extract_questions_offline, filter_questions -from khoj.processor.conversation.gpt4all.utils import download_model +from khoj.processor.conversation.offline.chat_model import converse_offline, extract_questions_offline, filter_questions +from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import message_to_log
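The server chooses between the OpenAI and offline paths based on the single `SpeechToTextModelOptions` row created during initialization (or via the Django admin). For context on what the new offline path does with the `openai-whisper` dependency pinned in `pyproject.toml`, here is a hedged sketch — the audio file name is a placeholder, and `"base"` simply matches the model's default:

```python
# Sketch of the offline transcription path, using the openai-whisper package
# added to pyproject.toml. "base" matches the SpeechToTextModelOptions default;
# audio.webm is a placeholder for the temporary file written by /api/transcribe.
import whisper

model = whisper.load_model("base")        # the server caches this on state.whisper_model
result = model.transcribe("audio.webm")   # ffmpeg must be available to decode webm audio
print(result["text"])
```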