diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index d98670ada..99cc75194 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,6 +12,8 @@ env: POSTGRES_USER: postgres-test-user POSTGRES_DB: postgres-test-db REDIS_HOST: redis://redis:6379 + ENABLE_VOICE_SEARCH: True + jobs: container-job: runs-on: ubuntu-20.04 @@ -55,6 +57,7 @@ jobs: env: PROMETHEUS_MULTIPROC_DIR: /tmp REDIS_HOST: ${{ env.REDIS_HOST }} + ENABLE_VOICE_SEARCH: ${{ env.ENABLE_VOICE_SEARCH }} run: | cd core_backend export POSTGRES_HOST=postgres POSTGRES_USER=$POSTGRES_USER \ diff --git a/Makefile b/Makefile index 6b900d013..9a2c94cdc 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ #!make +SHELL := /bin/bash PROJECT_NAME = aaq CONDA_ACTIVATE=source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate diff --git a/core_backend/Makefile b/core_backend/Makefile index 10153cb86..43d2b2acf 100644 --- a/core_backend/Makefile +++ b/core_backend/Makefile @@ -1,4 +1,5 @@ #!make +SHELL := /bin/bash .PHONY : tests @@ -49,4 +50,3 @@ teardown-redis-test: teardown-test-db: @docker stop testdb @docker rm testdb - diff --git a/core_backend/app/config.py b/core_backend/app/config.py index 2b53632ab..a858f8440 100644 --- a/core_backend/app/config.py +++ b/core_backend/app/config.py @@ -83,6 +83,8 @@ BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "") # Speech API +ENABLE_VOICE_SEARCH = os.getenv("ENABLE_VOICE_SEARCH", "False").lower() == "true" + CUSTOM_STT_ENDPOINT = os.environ.get("CUSTOM_STT_ENDPOINT", None) CUSTOM_TTS_ENDPOINT = os.environ.get("CUSTOM_TTS_ENDPOINT", None) diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 5777820d6..c1c5fb4a5 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -12,7 +12,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from ..auth.dependencies import authenticate_key, rate_limiter -from ..config import CUSTOM_STT_ENDPOINT, GCS_SPEECH_BUCKET, USE_CROSS_ENCODER +from ..config import ( + CUSTOM_STT_ENDPOINT, + ENABLE_VOICE_SEARCH, + GCS_SPEECH_BUCKET, + USE_CROSS_ENCODER, +) from ..contents.models import ( get_similar_content_async, increment_query_count, @@ -155,134 +160,137 @@ async def search( ) -@router.post( - "/voice-search", - response_model=QueryAudioResponse, - responses={ - status.HTTP_400_BAD_REQUEST: { - "model": QueryResponseError, - "description": "Bad Request", - }, - status.HTTP_500_INTERNAL_SERVER_ERROR: { - "model": QueryResponseError, - "description": "Internal Server Error", +if ENABLE_VOICE_SEARCH: + + @router.post( + "/voice-search", + response_model=QueryAudioResponse, + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": QueryResponseError, + "description": "Bad Request", + }, + status.HTTP_500_INTERNAL_SERVER_ERROR: { + "model": QueryResponseError, + "description": "Internal Server Error", + }, }, - }, -) -async def voice_search( - file_url: str, - request: Request, - asession: AsyncSession = Depends(get_async_session), - user_db: UserDB = Depends(authenticate_key), -) -> QueryAudioResponse | JSONResponse: - """ - Endpoint to transcribe audio from a provided URL, - generate an LLM response, by default generate_tts is - set to true and return a public random URL of an audio - file containing the spoken version of the generated response. - """ - try: - file_stream, content_type, file_extension = await download_file_from_url( - file_url - ) + ) + async def voice_search( + file_url: str, + request: Request, + asession: AsyncSession = Depends(get_async_session), + user_db: UserDB = Depends(authenticate_key), + ) -> QueryAudioResponse | JSONResponse: + """ + Endpoint to transcribe audio from a provided URL, + generate an LLM response, by default generate_tts is + set to true and return a public random URL of an audio + file containing the spoken version of the generated response. + """ + try: + file_stream, content_type, file_extension = await download_file_from_url( + file_url + ) - unique_filename = generate_random_filename(file_extension) - destination_blob_name = f"stt-voice-notes/{unique_filename}" + unique_filename = generate_random_filename(file_extension) + destination_blob_name = f"stt-voice-notes/{unique_filename}" - await upload_file_to_gcs( - GCS_SPEECH_BUCKET, file_stream, destination_blob_name, content_type - ) - file_path = f"temp/{unique_filename}" - with open(file_path, "wb") as f: + await upload_file_to_gcs( + GCS_SPEECH_BUCKET, file_stream, destination_blob_name, content_type + ) + file_path = f"temp/{unique_filename}" + with open(file_path, "wb") as f: + file_stream.seek(0) + f.write(file_stream.read()) file_stream.seek(0) - f.write(file_stream.read()) - file_stream.seek(0) - if CUSTOM_STT_ENDPOINT is not None: - transcription = await post_to_speech_stt(file_path, CUSTOM_STT_ENDPOINT) - transcription_result = transcription["text"] + if CUSTOM_STT_ENDPOINT is not None: + transcription = await post_to_speech_stt(file_path, CUSTOM_STT_ENDPOINT) + transcription_result = transcription["text"] - else: - transcription_result = await transcribe_audio(file_path) + else: + transcription_result = await transcribe_audio(file_path) - user_query = QueryBase( - generate_llm_response=True, - query_text=transcription_result, - query_metadata={}, - ) - - ( - user_query_db, - user_query_refined_template, - response_template, - ) = await get_user_query_and_response( - user_id=user_db.user_id, - user_query=user_query, - asession=asession, - generate_tts=True, - ) + user_query = QueryBase( + generate_llm_response=True, + query_text=transcription_result, + query_metadata={}, + ) - response = await get_search_response( - query_refined=user_query_refined_template, - response=response_template, - user_id=user_db.user_id, - n_similar=int(N_TOP_CONTENT), - n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER), - asession=asession, - exclude_archived=True, - request=request, - ) + ( + user_query_db, + user_query_refined_template, + response_template, + ) = await get_user_query_and_response( + user_id=user_db.user_id, + user_query=user_query, + asession=asession, + generate_tts=True, + ) - if user_query.generate_llm_response: - response = await get_generation_response( + response = await get_search_response( query_refined=user_query_refined_template, - response=response, + response=response_template, + user_id=user_db.user_id, + n_similar=int(N_TOP_CONTENT), + n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER), + asession=asession, + exclude_archived=True, + request=request, ) - await save_query_response_to_db(user_query_db, response, asession) - await increment_query_count( - user_id=user_db.user_id, - contents=response.search_results, - asession=asession, - ) - await save_content_for_query_to_db( - user_id=user_db.user_id, - query_id=response.query_id, - session_id=user_query.session_id, - contents=response.search_results, - asession=asession, - ) + if user_query.generate_llm_response: + response = await get_generation_response( + query_refined=user_query_refined_template, + response=response, + ) + + await save_query_response_to_db(user_query_db, response, asession) + await increment_query_count( + user_id=user_db.user_id, + contents=response.search_results, + asession=asession, + ) + await save_content_for_query_to_db( + user_id=user_db.user_id, + query_id=response.query_id, + session_id=user_query.session_id, + contents=response.search_results, + asession=asession, + ) + + if os.path.exists(file_path): + os.remove(file_path) + file_stream.close() - if os.path.exists(file_path): - os.remove(file_path) - file_stream.close() + if type(response) is QueryAudioResponse: + return response - if type(response) is QueryAudioResponse: - return response + if type(response) is QueryResponseError: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=response.model_dump(), + ) - if type(response) is QueryResponseError: return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump() + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"error": "Internal server error"}, ) - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"error": "Internal server error"}, - ) - - except ValueError as ve: - logger.error(f"ValueError: {str(ve)}") - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"error": f"Value error: {str(ve)}"}, - ) + except ValueError as ve: + logger.error(f"ValueError: {str(ve)}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"error": f"Value error: {str(ve)}"}, + ) - except Exception as e: - logger.error(f"Unexpected error: {str(e)}") - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"error": "Internal server error"}, - ) + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"error": "Internal server error"}, + ) @identify_language__before diff --git a/core_backend/tests/api/test.env b/core_backend/tests/api/test.env index 3079fb64f..26cdf986a 100644 --- a/core_backend/tests/api/test.env +++ b/core_backend/tests/api/test.env @@ -13,3 +13,4 @@ ALIGN_SCORE_API="http://localhost:5002/alignscore_base" # if u want to try the tests for the external TTS and STT apis then comment this out CUSTOM_STT_ENDPOINT="http://localhost:8001/transcribe" CUSTOM_TTS_ENDPOINT="http://localhost:8001/synthesize" +ENABLE_VOICE_SEARCH=True diff --git a/deployment/docker-compose/template.core_backend.env b/deployment/docker-compose/template.core_backend.env index 276262a76..a91f0d6b1 100644 --- a/deployment/docker-compose/template.core_backend.env +++ b/deployment/docker-compose/template.core_backend.env @@ -37,6 +37,11 @@ LITELLM_ENDPOINT="http://localhost:4000" #PGVECTOR_VECTOR_SIZE=1024 #### Speech APIs ############################################################### +# This variable controls whether the voice search endpoint is active (set to true) or inactive (set to false). Default is false. +# If enabled, we default to using external services unless `CUSTOM_SPEECH_ENDPOINT` is set, in which case the custom hosted APIs will be used. +# ENABLE_VOICE_SEARCH=True + +# if TOGGLE_VOICE is set to 'Custom' then make sure to also set the Environment variables mentioned below # CUSTOM_STT_ENDPOINT=http://speech_service:8001/transcribe # CUSTOM_TTS_ENDPOINT=http://speech_service:8001/synthesize