[PARKED] Create /chat endpoint #412

Draft · wants to merge 10 commits into base: main
13 changes: 2 additions & 11 deletions .secrets.baseline
@@ -231,15 +231,6 @@
"line_number": 63
}
],
"core_backend/app/question_answer/schemas.py": [
{
"type": "Secret Keyword",
"filename": "core_backend/app/question_answer/schemas.py",
"hashed_secret": "5b8b7a620e54e681c584f5b5c89152773c10c253",
"is_verified": false,
"line_number": 67
}
],
"core_backend/migrations/versions/2023_09_16_c5a948963236_create_query_table.py": [
{
"type": "Hex High Entropy String",
@@ -430,7 +421,7 @@
"filename": "core_backend/tests/api/test_dashboard_overview.py",
"hashed_secret": "6367c48dd193d56ea7b0baad25b19455e529f5ee",
"is_verified": false,
"line_number": 291
"line_number": 290
}
],
"core_backend/tests/api/test_dashboard_performance.py": [
@@ -590,5 +581,5 @@
}
]
},
"generated_at": "2024-08-23T09:41:17Z"
"generated_at": "2024-08-26T14:03:01Z"
}
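This file has the shape of a detect-secrets baseline (hashed_secret, is_verified, and line_number entries plus a generated_at timestamp). Assuming the repo manages it with detect-secrets, the updated entries above would typically come from re-scanning rather than hand-editing; a sketch of the usual commands (the exact invocation wired into this repo's hooks is an assumption):

    # Refresh the existing baseline after code changes remove or move findings
    detect-secrets scan --baseline .secrets.baseline

    # Review remaining findings and mark them verified/unverified
    detect-secrets audit .secrets.baseline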
10 changes: 5 additions & 5 deletions core_backend/add_dummy_data_to_db.py
@@ -185,7 +185,7 @@ def create_query_record(dt: datetime, session: Session) -> QueryDB:

query_db = QueryDB(
user_id=_USER_ID,
- session_id=1,
+ session_id="1",
feedback_secret_key="abc123", # pragma: allowlist secret
query_text="test query",
query_generate_llm_response=False,
@@ -198,7 +198,7 @@ def create_query_record(dt: datetime, session: Session) -> QueryDB:


def create_response_feedback_record(
- dt: datetime, query_id: int, session_id: int, is_negative: bool, session: Session
+ dt: datetime, query_id: int, session_id: str, is_negative: bool, session: Session
) -> None:
"""Create a feedback record for a given datetime.

@@ -209,7 +209,7 @@ def create_response_feedback_record(
query_id
The ID of the query record.
session_id
- The ID of the session record.
+ The UUID of the session record.
is_negative
Specifies whether the feedback is negative.
session
@@ -235,7 +235,7 @@ def create_response_feedback_record(
def create_content_feedback_record(
dt: datetime,
query_id: int,
- session_id: int,
+ session_id: str,
is_negative: bool,
session: Session,
) -> None:
@@ -248,7 +248,7 @@ def create_content_feedback_record(
query_id
The ID of the query record.
session_id
- The ID of the session record.
+ The UUID of the session record.
is_negative
Specifies whether the content feedback is negative.
session
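Because session_id is now a string (a UUID) rather than an integer, callers of these helpers must pass string IDs. A minimal sketch of an updated call, assuming an active SQLAlchemy Session and an existing query record (both names are illustrative):

    import uuid
    from datetime import datetime, timezone

    # A UUID string now satisfies the session_id: str parameter
    session_id = str(uuid.uuid4())

    create_response_feedback_record(
        dt=datetime.now(timezone.utc),
        query_id=1,              # assumed: ID of an existing query record
        session_id=session_id,   # str, not int, after this change
        is_negative=False,
        session=session,         # assumed: an active SQLAlchemy Session
    )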
16 changes: 9 additions & 7 deletions core_backend/app/contents/schemas.py
@@ -1,22 +1,24 @@
from datetime import datetime
- from typing import Annotated, List
+ from typing import List

- from pydantic import BaseModel, ConfigDict, Field, StringConstraints
+ from pydantic import BaseModel, ConfigDict, Field


class ContentCreate(BaseModel):
"""
Pydantic model for content creation request
"""

- content_title: Annotated[str, StringConstraints(max_length=150)] = Field(
+ content_title: str = Field(
+     max_length=150,
examples=["Example Content Title"],
)
- content_text: Annotated[str, StringConstraints(max_length=2000)] = Field(
-     examples=["This is an example content."]
+ content_text: str = Field(
+     max_length=2000,
+     examples=["This is an example content."],
)
- content_tags: list = Field(default=[], examples=[[1, 4]])
- content_metadata: dict = Field(default={}, examples=[{"key": "optional_value"}])
+ content_tags: list = Field(default=[])
+ content_metadata: dict = Field(default={})
is_archived: bool = False

model_config = ConfigDict(
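Field(max_length=...) enforces the same length constraint that Annotated[str, StringConstraints(max_length=...)] did, so validation behavior should be unchanged. A quick self-contained check of the new style (Pydantic v2, per the ConfigDict import):

    from pydantic import BaseModel, Field, ValidationError

    class ContentCreate(BaseModel):
        # max_length on Field behaves like StringConstraints(max_length=150)
        content_title: str = Field(max_length=150)

    ContentCreate(content_title="ok")  # validates

    try:
        ContentCreate(content_title="x" * 151)  # 151 chars exceeds the limit
    except ValidationError as exc:
        print(exc.errors()[0]["type"])  # string_too_long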
13 changes: 8 additions & 5 deletions core_backend/app/llm_call/llm_rag.py
@@ -2,8 +2,6 @@
Augmented Generation (RAG).
"""

- from typing import Optional

from pydantic import ValidationError

from ..config import LITELLM_MODEL_GENERATION
@@ -15,23 +13,27 @@


async def get_llm_rag_answer(
- question: str,
+ question: str | list[dict[str, str]],
[Inline comment from the Collaborator Author on this line: "To undo"]
context: str,
original_language: IdentifiedLanguage,
- metadata: Optional[dict] = None,
+ metadata: dict | None = None,
+ chat_history: list[dict[str, str]] | None = None,
) -> RAG:
"""Get an answer from the LLM model using RAG.

Parameters
----------
question
- The question to ask the LLM model.
+ The question to ask the LLM model, or a list of chat history messages of
+ the form {"content": str, "role": str}.
[Inline comment from the Collaborator Author on lines +27 to +28: "To undo"]
context
The context to provide to the LLM model.
response_language
The language of the response.
metadata
Additional metadata to provide to the LLM model.
+ chat_history
+ The previous chat history to provide to the LLM model, if it exists.

Returns
-------
@@ -45,6 +47,7 @@ async def get_llm_rag_answer(
result = await _ask_llm_async(
user_message=question,
system_message=prompt,
+ chat_history=chat_history,
litellm_model=LITELLM_MODEL_GENERATION,
metadata=metadata,
json=True,
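With the widened signature, callers can thread prior turns into the RAG call while leaving single-shot behavior intact (chat_history=None). A hedged usage sketch; the history contents, the context string, and the IdentifiedLanguage member are illustrative assumptions:

    # Prior turns in the {"content": ..., "role": ...} shape the docstring describes
    chat_history = [
        {"role": "user", "content": "What vaccines does my newborn need?"},
        {"role": "assistant", "content": "At birth: BCG and the first polio dose."},
    ]

    rag = await get_llm_rag_answer(
        question="When is the next dose due?",
        context=context,                               # assumed: retrieved content string
        original_language=IdentifiedLanguage.ENGLISH,  # assumed enum member
        metadata=metadata,
        chat_history=chat_history,                     # None keeps the old behavior
    )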
4 changes: 4 additions & 0 deletions core_backend/app/llm_call/process_input.py
@@ -121,6 +121,7 @@ def _process_identified_language_response(

error_response = QueryResponseError(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=response.llm_response,
search_results=response.search_results,
@@ -206,6 +207,7 @@ async def _translate_question(
else:
error_response = QueryResponseError(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=response.llm_response,
search_results=response.search_results,
@@ -275,6 +277,7 @@ async def _classify_safety(
else:
error_response = QueryResponseError(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=response.llm_response,
search_results=response.search_results,
@@ -352,6 +355,7 @@ async def _paraphrase_question(
else:
error_response = QueryResponseError(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=response.llm_response,
search_results=response.search_results,
41 changes: 7 additions & 34 deletions core_backend/app/llm_call/process_output.py
@@ -45,39 +45,7 @@ class AlignScoreData(TypedDict):
claim: str


- def generate_llm_response__after(func: Callable) -> Callable:
-     """
-     Decorator to generate the LLM response.
-
-     Only runs if the generate_llm_response flag is set to True.
-     Requires "search_results" and "original_language" in the response.
-     """
-
-     @wraps(func)
-     async def wrapper(
-         query_refined: QueryRefined,
-         response: QueryResponse | QueryResponseError,
-         *args: Any,
-         **kwargs: Any,
-     ) -> QueryResponse | QueryResponseError:
-         """
-         Generate the LLM response
-         """
-         response = await func(query_refined, response, *args, **kwargs)
-
-         if not query_refined.generate_llm_response:
-             return response
-
-         metadata = create_langfuse_metadata(
-             query_id=response.query_id, user_id=query_refined.user_id
-         )
-         response = await _generate_llm_response(query_refined, response, metadata)
-         return response
-
-     return wrapper
-
-
- async def _generate_llm_response(
+ async def generate_llm_query_response(
query_refined: QueryRefined,
response: QueryResponse,
metadata: Optional[dict] = None,
@@ -99,12 +67,13 @@ async def _generate_llm_response(
return response

context = get_context_string_from_search_results(response.search_results)

rag_response = await get_llm_rag_answer(
- # use the original query text
question=query_refined.query_text_original,
context=context,
original_language=query_refined.original_language,
metadata=metadata,
+ chat_history=response.chat_history,
)

if rag_response.answer != RAG_FAILURE_MESSAGE:
@@ -114,6 +83,7 @@ async def _generate_llm_response(
else:
response = QueryResponseError(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=None,
search_results=response.search_results,
@@ -219,6 +189,7 @@ async def _check_align_score(
)
response = QueryResponseError(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=None,
search_results=response.search_results,
@@ -311,6 +282,7 @@ async def wrapper(
)
response = QueryAudioResponse(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=response.llm_response,
search_results=response.search_results,
@@ -361,6 +333,7 @@ async def _generate_tts_response(
logger.error(f"Error generating TTS for query_id {response.query_id}: {e}")
return QueryResponseError(
query_id=response.query_id,
+ session_id=response.session_id,
feedback_secret_key=response.feedback_secret_key,
llm_response=response.llm_response,
search_results=response.search_results,
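With the generate_llm_response__after decorator removed, response generation becomes an explicit call at the call site. A sketch of what that call might look like, mirroring the guard and metadata logic from the deleted wrapper (the surrounding function is an assumption):

    # Explicit replacement for the removed decorator's post-processing step
    if query_refined.generate_llm_response:
        metadata = create_langfuse_metadata(
            query_id=response.query_id, user_id=query_refined.user_id
        )
        response = await generate_llm_query_response(query_refined, response, metadata)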
15 changes: 9 additions & 6 deletions core_backend/app/llm_call/utils.py
@@ -1,5 +1,3 @@
- from typing import Optional

from litellm import acompletion

from ..config import LITELLM_API_KEY, LITELLM_ENDPOINT, LITELLM_MODEL_DEFAULT
@@ -9,11 +7,12 @@


async def _ask_llm_async(
- user_message: str,
+ user_message: str | list[dict[str, str]],
system_message: str,
[Inline comment from the Collaborator Author on this line: "to undo"]

- litellm_model: Optional[str] = LITELLM_MODEL_DEFAULT,
- litellm_endpoint: Optional[str] = LITELLM_ENDPOINT,
- metadata: Optional[dict] = None,
+ chat_history: list[dict[str, str]] | None = None,
+ litellm_model: str | None = LITELLM_MODEL_DEFAULT,
+ litellm_endpoint: str | None = LITELLM_ENDPOINT,
+ metadata: dict | None = None,
json: bool = False,
) -> str:
"""
@@ -36,6 +35,10 @@ async def _ask_llm_async(
"role": "user",
},
]

+ if chat_history is not None:
+     messages = messages[:1] + chat_history + messages[1:]

logger.info(f"LLM input: 'model': {litellm_model}, 'endpoint': {litellm_endpoint}")

llm_response_raw = await acompletion(
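The splice keeps the system prompt first and the current user message last, with the history in between. A tiny self-contained illustration of the ordering (all message contents are made up):

    system = {"content": "You are a helpful assistant.", "role": "system"}
    user = {"content": "When is the next dose due?", "role": "user"}
    messages = [system, user]

    chat_history = [
        {"content": "What vaccines does my newborn need?", "role": "user"},
        {"content": "At birth: BCG and the first polio dose.", "role": "assistant"},
    ]

    # Same splice as in _ask_llm_async: system prompt, then history, then new message
    messages = messages[:1] + chat_history + messages[1:]
    print([m["role"] for m in messages])  # ['system', 'user', 'assistant', 'user']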