diff --git a/pyproject.toml b/pyproject.toml
index 63a50fac7..42adf2099 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,6 +75,7 @@ dependencies = [
"tzdata == 2023.3",
"rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0",
+ "openai-whisper >= 20231117",
]
dynamic = ["version"]
diff --git a/src/interface/desktop/assets/icons/microphone-solid.svg b/src/interface/desktop/assets/icons/microphone-solid.svg
new file mode 100644
index 000000000..3fc4b91d2
--- /dev/null
+++ b/src/interface/desktop/assets/icons/microphone-solid.svg
@@ -0,0 +1 @@
+
diff --git a/src/interface/desktop/assets/icons/stop-solid.svg b/src/interface/desktop/assets/icons/stop-solid.svg
new file mode 100644
index 000000000..a9aaba284
--- /dev/null
+++ b/src/interface/desktop/assets/icons/stop-solid.svg
@@ -0,0 +1,37 @@
+
+
diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index a436b5934..ae17c6372 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -516,6 +516,18 @@
}
}
+ function flashStatusInChatInput(message) {
+ // Get chat input element and original placeholder
+ let chatInput = document.getElementById("chat-input");
+ let originalPlaceholder = chatInput.placeholder;
+ // Set placeholder to message
+ chatInput.placeholder = message;
+ // Reset placeholder after 2 seconds
+ setTimeout(() => {
+ chatInput.placeholder = originalPlaceholder;
+ }, 2000);
+ }
+
async function clearConversationHistory() {
let chatInput = document.getElementById("chat-input");
let originalPlaceholder = chatInput.placeholder;
@@ -530,17 +542,71 @@
.then(data => {
chatBody.innerHTML = "";
loadChat();
- chatInput.placeholder = "Cleared conversation history";
+ flashStatusInChatInput("🗑 Cleared conversation history");
})
.catch(err => {
- chatInput.placeholder = "Failed to clear conversation history";
+ flashStatusInChatInput("⛔️ Failed to clear conversation history");
})
- .finally(() => {
- setTimeout(() => {
- chatInput.placeholder = originalPlaceholder;
- }, 2000);
+ }
+
+ let mediaRecorder;
+ async function speechToText() {
+ const speakButtonImg = document.getElementById('speak-button-img');
+ const chatInput = document.getElementById('chat-input');
+
+ const hostURL = await window.hostURLAPI.getURL();
+ let url = `${hostURL}/api/transcribe?client=desktop`;
+ const khojToken = await window.tokenAPI.getToken();
+ const headers = { 'Authorization': `Bearer ${khojToken}` };
+
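+        // POST the recorded audio to the Khoj transcribe API and append the returned text to the chat input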
+ const sendToServer = (audioBlob) => {
+ const formData = new FormData();
+ formData.append('file', audioBlob);
+
+ fetch(url, { method: 'POST', body: formData, headers})
+ .then(response => response.ok ? response.json() : Promise.reject(response))
+ .then(data => { chatInput.value += data.text; })
+ .catch(err => {
+ err.status == 422
+ ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+ : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+ });
+ };
+
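+        // Record microphone audio as webm chunks; the blob is sent for transcription once recording stops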
+ const handleRecording = (stream) => {
+ const audioChunks = [];
+ const recordingConfig = { mimeType: 'audio/webm' };
+ mediaRecorder = new MediaRecorder(stream, recordingConfig);
+
+ mediaRecorder.addEventListener("dataavailable", function(event) {
+ if (event.data.size > 0) audioChunks.push(event.data);
+ });
+
+ mediaRecorder.addEventListener("stop", function() {
+ const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
+ sendToServer(audioBlob);
});
+
+ mediaRecorder.start();
+ speakButtonImg.src = './assets/icons/stop-solid.svg';
+ speakButtonImg.alt = 'Stop Transcription';
+ };
+
+ // Toggle recording
+ if (!mediaRecorder || mediaRecorder.state === 'inactive') {
+ navigator.mediaDevices
+ .getUserMedia({ audio: true })
+ .then(handleRecording)
+ .catch((e) => {
+ flashStatusInChatInput("⛔️ Failed to access microphone");
+ });
+ } else if (mediaRecorder.state === 'recording') {
+ mediaRecorder.stop();
+ speakButtonImg.src = './assets/icons/microphone-solid.svg';
+ speakButtonImg.alt = 'Transcribe';
+ }
}
+
@@ -620,7 +689,6 @@
.chat-message.you {
margin-right: auto;
text-align: right;
- white-space: pre-line;
}
/* basic style chat message text */
.chat-message-text {
@@ -637,7 +705,6 @@
color: var(--primary-inverse);
background: var(--primary);
margin-left: auto;
- white-space: pre-line;
}
/* Spinner symbol when the chat message is loading */
.spinner {
@@ -694,7 +761,7 @@
}
#input-row {
display: grid;
- grid-template-columns: auto 32px;
+ grid-template-columns: auto 32px 32px;
grid-column-gap: 10px;
grid-row-gap: 10px;
background: #f9fafc
diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts
index fc6d5a488..16c5614fa 100644
--- a/src/interface/obsidian/src/chat_modal.ts
+++ b/src/interface/obsidian/src/chat_modal.ts
@@ -1,4 +1,4 @@
-import { App, Modal, request, setIcon } from 'obsidian';
+import { App, Modal, RequestUrlParam, request, requestUrl, setIcon } from 'obsidian';
import { KhojSetting } from 'src/settings';
import fetch from "node-fetch";
@@ -51,6 +51,16 @@ export class KhojChatModal extends Modal {
})
chatInput.addEventListener('change', (event) => { this.result = (event.target).value });
+ let transcribe = inputRow.createEl("button", {
+ text: "Transcribe",
+ attr: {
+ id: "khoj-transcribe",
+ class: "khoj-transcribe khoj-input-row-button",
+ },
+ })
+ transcribe.addEventListener('click', async (_) => { await this.speechToText() });
+ setIcon(transcribe, "mic");
+
let clearChat = inputRow.createEl("button", {
text: "Clear History",
attr: {
@@ -205,9 +215,19 @@ export class KhojChatModal extends Modal {
}
}
- async clearConversationHistory() {
+ flashStatusInChatInput(message: string) {
+ // Get chat input element and original placeholder
let chatInput = this.contentEl.getElementsByClassName("khoj-chat-input")[0];
let originalPlaceholder = chatInput.placeholder;
+ // Set placeholder to message
+ chatInput.placeholder = message;
+ // Reset placeholder after 2 seconds
+ setTimeout(() => {
+ chatInput.placeholder = originalPlaceholder;
+ }, 2000);
+ }
+
+ async clearConversationHistory() {
let chatBody = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
let response = await request({
@@ -224,15 +244,84 @@ export class KhojChatModal extends Modal {
// If conversation history is cleared successfully, clear chat logs from modal
chatBody.innerHTML = "";
await this.getChatHistory();
- chatInput.placeholder = result.message;
+ this.flashStatusInChatInput(result.message);
}
} catch (err) {
- chatInput.placeholder = "Failed to clear conversation history";
- } finally {
- // Reset to original placeholder text after some time
- setTimeout(() => {
- chatInput.placeholder = originalPlaceholder;
- }, 2000);
+ this.flashStatusInChatInput("Failed to clear conversation history");
+ }
+ }
+
+ mediaRecorder: MediaRecorder | undefined;
+ async speechToText() {
+        const transcribeButton = <HTMLButtonElement>this.contentEl.getElementsByClassName("khoj-transcribe")[0];
+        const chatInput = <HTMLInputElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
+
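+        // Construct the multipart/form-data request body by hand, since Obsidian's requestUrl takes a raw string/ArrayBuffer body rather than FormData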
+ const generateRequestBody = async (audioBlob: Blob, boundary_string: string) => {
+ const boundary = `------${boundary_string}`;
+ const chunks: ArrayBuffer[] = [];
+
+ chunks.push(new TextEncoder().encode(`${boundary}\r\n`));
+ chunks.push(new TextEncoder().encode(`Content-Disposition: form-data; name="file"; filename="blob"\r\nContent-Type: "application/octet-stream"\r\n\r\n`));
+ chunks.push(await audioBlob.arrayBuffer());
+ chunks.push(new TextEncoder().encode('\r\n'));
+
+ await Promise.all(chunks);
+ chunks.push(new TextEncoder().encode(`${boundary}--\r\n`));
+ return await new Blob(chunks).arrayBuffer();
+ };
+
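+        // Send the recorded audio to the Khoj transcribe API and insert the transcribed text into the chat input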
+ const sendToServer = async (audioBlob: Blob) => {
+ const boundary_string = `Boundary${Math.random().toString(36).slice(2)}`;
+ const requestBody = await generateRequestBody(audioBlob, boundary_string);
+
+ const response = await requestUrl({
+ url: `${this.setting.khojUrl}/api/transcribe?client=obsidian`,
+ method: 'POST',
+ headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` },
+ contentType: `multipart/form-data; boundary=----${boundary_string}`,
+ body: requestBody,
+ });
+
+ // Parse response from Khoj backend
+            if (response.status === 200) {
+                chatInput.value += response.json.text;
+            } else if (response.status === 422) {
+                throw new Error("⛔️ Configure speech-to-text model on server.");
+            } else {
+                throw new Error("⛔️ Failed to transcribe audio");
+            }
+ };
+
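+        // Record microphone audio as webm chunks and transcribe them once recording stops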
+ const handleRecording = (stream: MediaStream) => {
+ const audioChunks: Blob[] = [];
+ const recordingConfig = { mimeType: 'audio/webm' };
+ this.mediaRecorder = new MediaRecorder(stream, recordingConfig);
+
+ this.mediaRecorder.addEventListener("dataavailable", function(event) {
+ if (event.data.size > 0) audioChunks.push(event.data);
+ });
+
+ this.mediaRecorder.addEventListener("stop", async function() {
+ const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
+ await sendToServer(audioBlob);
+ });
+
+ this.mediaRecorder.start();
+ setIcon(transcribeButton, "mic-off");
+ };
+
+ // Toggle recording
+ if (!this.mediaRecorder || this.mediaRecorder.state === 'inactive') {
+ navigator.mediaDevices
+ .getUserMedia({ audio: true })
+ .then(handleRecording)
+ .catch((e) => {
+ this.flashStatusInChatInput("⛔️ Failed to access microphone");
+ });
+ } else if (this.mediaRecorder.state === 'recording') {
+ this.mediaRecorder.stop();
+ setIcon(transcribeButton, "mic");
}
}
}
diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css
index 95a304f1b..ff2dee8a8 100644
--- a/src/interface/obsidian/styles.css
+++ b/src/interface/obsidian/styles.css
@@ -112,7 +112,7 @@ If your plugin does not need CSS, delete this file.
}
.khoj-input-row {
display: grid;
- grid-template-columns: auto 32px;
+ grid-template-columns: auto 32px 32px;
grid-column-gap: 10px;
grid-row-gap: 10px;
background: var(--background-primary);
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index de5f1b5d4..eb143ab66 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -26,6 +26,7 @@
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
SearchModelConfig,
+ SpeechToTextModelOptions,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
@@ -344,6 +345,10 @@ async def get_openai_chat():
async def get_openai_chat_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
+ @staticmethod
+ async def get_speech_to_text_config():
+ return await SpeechToTextModelOptions.objects.filter().afirst()
+
@staticmethod
async def aget_conversation_starters(user: KhojUser):
all_questions = []
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index e1095eced..2213fb6ef 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -9,6 +9,7 @@
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
SearchModelConfig,
+ SpeechToTextModelOptions,
Subscription,
ReflectiveQuestion,
)
@@ -16,6 +17,7 @@
admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions)
+admin.site.register(SpeechToTextModelOptions)
admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModelConfig)
diff --git a/src/khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py b/src/khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py
new file mode 100644
index 000000000..373377915
--- /dev/null
+++ b/src/khoj/database/migrations/0021_speechtotextmodeloptions_and_more.py
@@ -0,0 +1,42 @@
+# Generated by Django 4.2.7 on 2023-11-26 13:54
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0020_reflectivequestion"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="SpeechToTextModelOptions",
+ fields=[
+ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("model_name", models.CharField(default="base", max_length=200)),
+ (
+ "model_type",
+ models.CharField(
+ choices=[("openai", "Openai"), ("offline", "Offline")], default="offline", max_length=200
+ ),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ migrations.AlterField(
+ model_name="chatmodeloptions",
+ name="chat_model",
+ field=models.CharField(default="mistral-7b-instruct-v0.1.Q4_0.gguf", max_length=200),
+ ),
+ migrations.AlterField(
+ model_name="chatmodeloptions",
+ name="model_type",
+ field=models.CharField(
+ choices=[("openai", "Openai"), ("offline", "Offline")], default="offline", max_length=200
+ ),
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index b0463df82..82348fbe6 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -120,6 +120,15 @@ class OfflineChatProcessorConversationConfig(BaseModel):
enabled = models.BooleanField(default=False)
+class SpeechToTextModelOptions(BaseModel):
+ class ModelType(models.TextChoices):
+ OPENAI = "openai"
+ OFFLINE = "offline"
+
+ model_name = models.CharField(max_length=200, default="base")
+ model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
+
+
class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
@@ -127,8 +136,8 @@ class ModelType(models.TextChoices):
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
- chat_model = models.CharField(max_length=200, default=None, null=True, blank=True)
- model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
+ chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
+ model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class UserConversationConfig(BaseModel):
diff --git a/src/khoj/interface/web/assets/icons/microphone-solid.svg b/src/khoj/interface/web/assets/icons/microphone-solid.svg
new file mode 100644
index 000000000..3fc4b91d2
--- /dev/null
+++ b/src/khoj/interface/web/assets/icons/microphone-solid.svg
@@ -0,0 +1 @@
+
diff --git a/src/khoj/interface/web/assets/icons/stop-solid.svg b/src/khoj/interface/web/assets/icons/stop-solid.svg
new file mode 100644
index 000000000..a9aaba284
--- /dev/null
+++ b/src/khoj/interface/web/assets/icons/stop-solid.svg
@@ -0,0 +1,37 @@
+
+
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index 8d2accc41..256193a78 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -543,6 +543,18 @@
}
}
+ function flashStatusInChatInput(message) {
+ // Get chat input element and original placeholder
+ let chatInput = document.getElementById("chat-input");
+ let originalPlaceholder = chatInput.placeholder;
+ // Set placeholder to message
+ chatInput.placeholder = message;
+ // Reset placeholder after 2 seconds
+ setTimeout(() => {
+ chatInput.placeholder = originalPlaceholder;
+ }, 2000);
+ }
+
function clearConversationHistory() {
let chatInput = document.getElementById("chat-input");
let originalPlaceholder = chatInput.placeholder;
@@ -553,17 +565,65 @@
.then(data => {
chatBody.innerHTML = "";
loadChat();
- chatInput.placeholder = "Cleared conversation history";
+ flashStatusInChatInput("🗑 Cleared conversation history");
})
.catch(err => {
- chatInput.placeholder = "Failed to clear conversation history";
- })
- .finally(() => {
- setTimeout(() => {
- chatInput.placeholder = originalPlaceholder;
- }, 2000);
+ flashStatusInChatInput("⛔️ Failed to clear conversation history");
});
}
+
+ let mediaRecorder;
+ function speechToText() {
+ const speakButtonImg = document.getElementById('speak-button-img');
+ const chatInput = document.getElementById('chat-input');
+
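+        // POST the recorded audio to the Khoj transcribe API and append the returned text to the chat input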
+ const sendToServer = (audioBlob) => {
+ const formData = new FormData();
+ formData.append('file', audioBlob);
+
+ fetch('/api/transcribe?client=web', { method: 'POST', body: formData })
+ .then(response => response.ok ? response.json() : Promise.reject(response))
+ .then(data => { chatInput.value += data.text; })
+ .catch(err => {
+ err.status == 422
+ ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+ : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+ });
+ };
+
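+        // Record microphone audio as webm chunks; the blob is sent for transcription once recording stops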
+ const handleRecording = (stream) => {
+ const audioChunks = [];
+ const recordingConfig = { mimeType: 'audio/webm' };
+ mediaRecorder = new MediaRecorder(stream, recordingConfig);
+
+ mediaRecorder.addEventListener("dataavailable", function(event) {
+ if (event.data.size > 0) audioChunks.push(event.data);
+ });
+
+ mediaRecorder.addEventListener("stop", function() {
+ const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
+ sendToServer(audioBlob);
+ });
+
+ mediaRecorder.start();
+ speakButtonImg.src = '/static/assets/icons/stop-solid.svg';
+ speakButtonImg.alt = 'Stop Transcription';
+ };
+
+ // Toggle recording
+ if (!mediaRecorder || mediaRecorder.state === 'inactive') {
+ navigator.mediaDevices
+ .getUserMedia({ audio: true })
+ .then(handleRecording)
+ .catch((e) => {
+ flashStatusInChatInput("⛔️ Failed to access microphone");
+ });
+ } else if (mediaRecorder.state === 'recording') {
+ mediaRecorder.stop();
+ speakButtonImg.src = '/static/assets/icons/microphone-solid.svg';
+ speakButtonImg.alt = 'Transcribe';
+ }
+ }
@@ -749,7 +812,6 @@
.chat-message.you {
margin-right: auto;
text-align: right;
- white-space: pre-line;
}
/* basic style chat message text */
.chat-message-text {
@@ -766,7 +828,6 @@
color: var(--primary-inverse);
background: var(--primary);
margin-left: auto;
- white-space: pre-line;
}
/* Spinner symbol when the chat message is loading */
.spinner {
@@ -815,6 +876,7 @@
#chat-footer {
padding: 0;
+ margin: 8px;
display: grid;
grid-template-columns: minmax(70px, 100%);
grid-column-gap: 10px;
@@ -822,7 +884,7 @@
}
#input-row {
display: grid;
- grid-template-columns: auto 32px;
+ grid-template-columns: auto 32px 32px;
grid-column-gap: 10px;
grid-row-gap: 10px;
background: #f9fafc
diff --git a/src/khoj/processor/conversation/gpt4all/__init__.py b/src/khoj/processor/conversation/offline/__init__.py
similarity index 100%
rename from src/khoj/processor/conversation/gpt4all/__init__.py
rename to src/khoj/processor/conversation/offline/__init__.py
diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
similarity index 100%
rename from src/khoj/processor/conversation/gpt4all/chat_model.py
rename to src/khoj/processor/conversation/offline/chat_model.py
diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/offline/utils.py
similarity index 100%
rename from src/khoj/processor/conversation/gpt4all/utils.py
rename to src/khoj/processor/conversation/offline/utils.py
diff --git a/src/khoj/processor/conversation/offline/whisper.py b/src/khoj/processor/conversation/offline/whisper.py
new file mode 100644
index 000000000..56d2aaf5c
--- /dev/null
+++ b/src/khoj/processor/conversation/offline/whisper.py
@@ -0,0 +1,17 @@
+# External Packages
+from asgiref.sync import sync_to_async
+import whisper
+
+# Internal Packages
+from khoj.utils import state
+
+
+async def transcribe_audio_offline(audio_filename: str, model: str) -> str:
+ """
+ Transcribe audio file offline using Whisper
+ """
+    # Lazily load the Whisper model, then run transcription locally
+ if not state.whisper_model:
+ state.whisper_model = whisper.load_model(model)
+ response = await sync_to_async(state.whisper_model.transcribe)(audio_filename)
+ return response["text"]
diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py
new file mode 100644
index 000000000..72834d921
--- /dev/null
+++ b/src/khoj/processor/conversation/openai/whisper.py
@@ -0,0 +1,15 @@
+# Standard Packages
+from io import BufferedReader
+
+# External Packages
+from asgiref.sync import sync_to_async
+import openai
+
+
+async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
+ """
+ Transcribe audio file using Whisper model via OpenAI's API
+ """
+    # Send the audio data to OpenAI's Whisper API (Audio.translate returns the speech as English text)
+ response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
+ return response["text"]
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ef822687f..3fd2285d8 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -3,13 +3,14 @@
import json
import logging
import math
+import os
import time
from typing import Any, Dict, List, Optional, Union
-
-from asgiref.sync import sync_to_async
+import uuid
# External Packages
-from fastapi import APIRouter, Depends, Header, HTTPException, Request
+from asgiref.sync import sync_to_async
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
@@ -29,8 +30,10 @@
LocalPlaintextConfig,
NotionConfig,
)
-from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
+from khoj.processor.conversation.offline.chat_model import extract_questions_offline
+from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
+from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
@@ -585,6 +588,59 @@ async def chat_options(
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
+@api.post("/transcribe")
+@requires(["authenticated"])
+async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
+ user: KhojUser = request.user.object
+ audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
+    user_message: Optional[str] = None
+    status_code = None
+
+ # If the file is too large, return an unprocessable entity error
+ if file.size > 10 * 1024 * 1024:
+        logger.warning(f"Audio file too large to transcribe. Audio file size: {file.size}. Exceeds 10MB limit.")
+        return Response(content="Audio size larger than 10MB limit", status_code=422)
+
+ # Transcribe the audio from the request
+ try:
+ # Store the audio from the request in a temporary file
+ audio_data = await file.read()
+ with open(audio_filename, "wb") as audio_file_writer:
+ audio_file_writer.write(audio_data)
+ audio_file = open(audio_filename, "rb")
+
+    # Transcribe the audio using the configured speech-to-text model
+ speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
+ openai_chat_config = await ConversationAdapters.get_openai_chat_config()
+ if not speech_to_text_config:
+ # If the user has not configured a speech to text model, return an unprocessable entity error
+ status_code = 422
+ elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
+ api_key = openai_chat_config.api_key
+ speech2text_model = speech_to_text_config.model_name
+ user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
+ elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
+ speech2text_model = speech_to_text_config.model_name
+ user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
+ finally:
+ # Close and Delete the temporary audio file
+ audio_file.close()
+ os.remove(audio_filename)
+
+ if user_message is None:
+ return Response(status_code=status_code or 500)
+
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="transcribe",
+ **common.__dict__,
+ )
+
+ # Return the spoken text
+ content = json.dumps({"text": user_message})
+ return Response(content=content, media_type="application/json", status_code=200)
+
+
@api.get("/chat", response_class=Response)
@requires(["authenticated"])
async def chat(
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index c6fcb4364..e1ab05b51 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -15,7 +15,7 @@
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, Subscription
from khoj.processor.conversation import prompts
-from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
+from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 7795d695d..abda12b6f 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -11,7 +11,7 @@
import torch
# Internal Packages
-from khoj.processor.conversation.gpt4all.utils import download_model
+from khoj.processor.conversation.offline.utils import download_model
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ class GPT4AllProcessorConfig:
class GPT4AllProcessorModel:
def __init__(
self,
- chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
+ chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
):
self.chat_model = chat_model
self.loaded_model = None
diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py
index ffc4d47eb..313b18fcd 100644
--- a/src/khoj/utils/initialization.py
+++ b/src/khoj/utils/initialization.py
@@ -6,6 +6,7 @@
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
ChatModelOptions,
+ SpeechToTextModelOptions,
)
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@@ -73,10 +74,9 @@ def _create_chat_configuration():
except ModuleNotFoundError as e:
logger.warning("Offline models are not supported on this device.")
- use_openai_model = input("Use OpenAI chat model? (y/n): ")
-
+ use_openai_model = input("Use OpenAI models? (y/n): ")
if use_openai_model == "y":
- logger.info("🗣️ Setting up OpenAI chat model")
+ logger.info("🗣️ Setting up your OpenAI configuration")
api_key = input("Enter your OpenAI API key: ")
OpenAIProcessorConversationConfig.objects.create(api_key=api_key)
@@ -94,7 +94,34 @@ def _create_chat_configuration():
chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_prompt_size=max_tokens
)
- logger.info("🗣️ Chat model configuration complete")
+ default_speech2text_model = "whisper-1"
+ openai_speech2text_model = input(
+ f"Enter the OpenAI speech to text model you want to use (default: {default_speech2text_model}): "
+ )
+ openai_speech2text_model = openai_speech2text_model or default_speech2text_model
+ SpeechToTextModelOptions.objects.create(
+ model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
+ )
+
+ if use_offline_model == "y" or use_openai_model == "y":
+ logger.info("🗣️ Chat model configuration complete")
+
+ use_offline_speech2text_model = input("Use offline speech to text model? (y/n): ")
+ if use_offline_speech2text_model == "y":
+ logger.info("🗣️ Setting up offline speech to text model")
+ # Delete any existing speech to text model options. There can only be one.
+ SpeechToTextModelOptions.objects.all().delete()
+
+ default_offline_speech2text_model = "base"
+ offline_speech2text_model = input(
+ f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): "
+ )
+ offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model
+ SpeechToTextModelOptions.objects.create(
+ model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE
+ )
+
+ logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
admin_user = KhojUser.objects.filter(is_staff=True).first()
if admin_user is None:
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 91f5f0cee..b54cf4b39 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -7,6 +7,7 @@
# External Packages
from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
+from whisper import Whisper
# Internal Packages
from khoj.utils import config as utils_config
@@ -21,6 +22,7 @@
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
gpt4all_processor_config: GPT4AllProcessorModel = None
+whisper_model: Whisper = None
config_file: Path = None
verbose: int = 0
host: str = None
diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py
index 782b54f20..7b59e1e3f 100644
--- a/tests/test_gpt4all_chat_actors.py
+++ b/tests/test_gpt4all_chat_actors.py
@@ -19,8 +19,8 @@
print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
# Internal Packages
-from khoj.processor.conversation.gpt4all.chat_model import converse_offline, extract_questions_offline, filter_questions
-from khoj.processor.conversation.gpt4all.utils import download_model
+from khoj.processor.conversation.offline.chat_model import converse_offline, extract_questions_offline, filter_questions
+from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log