Adding chat history to RAG app and refactor to better utilize LangChain #648

Open · wants to merge 61 commits into base: main

Changes from 14 commits (61 commits total)
dda40b9
Also introduced a basic session history mechanism in the browser to k…
alpha-amundson May 3, 2024
5cc85b9
tflint formatting fixes
alpha-amundson May 3, 2024
6898666
TPU Provisioner: JobSet related fixes (#645)
nstogner May 6, 2024
1d6c052
Updated image to use code in this branch
alpha-amundson May 6, 2024
981e777
making tflint happy
alpha-amundson May 6, 2024
d1d1211
Working on improvements for rag application (#731)
german-grandas Jul 12, 2024
5a16b54
Rag langchain chat history (#747)
german-grandas Jul 22, 2024
9e416c8
Rag langchain chat history (#755)
german-grandas Jul 29, 2024
e750d12
Fixing issues and updating chat history on frontend
german-grandas Jul 31, 2024
a000c46
Fixing files on working tree
german-grandas Jul 31, 2024
0d853ea
Ignoring test rag, to review how the rag application is working
german-grandas Aug 1, 2024
386c437
ignoring unit test to review cloud build process
german-grandas Aug 1, 2024
be1839d
refactoring cloud sql connection helper
german-grandas Aug 6, 2024
7f081ff
Merge branch 'main' into rag-langchain-chat-history
german-grandas Aug 6, 2024
35f67e4
Change TPU Metrics Source for Autoscaling (#770)
Bslabe123 Aug 8, 2024
0022053
Refactor: move workload identity service account out of kuberay-opera…
genlu2011 Aug 15, 2024
48f655b
updating branch
german-grandas Aug 20, 2024
a9895d6
fixing conflicts with remote branch
german-grandas Aug 20, 2024
cd95c98
fixing conflicts with remote branch
german-grandas Aug 20, 2024
bc8d745
fixing conflicts with remote branch
german-grandas Aug 20, 2024
e9beeef
fixing conflicts applying rebase
german-grandas Aug 20, 2024
eb9ab02
Updating files based on reviewer comments
german-grandas Aug 20, 2024
dff8d94
reverting change on cloudbuild.yaml file
german-grandas Aug 20, 2024
138920f
Reverting comment of line
german-grandas Aug 26, 2024
c8e5d35
Updating length of variable
german-grandas Aug 26, 2024
c437736
updating branch with main
german-grandas Aug 30, 2024
8b4f55d
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 2, 2024
4261818
Updating rag frontend image.
german-grandas Sep 4, 2024
a8258f1
updating rag frontend images with the latest changes
german-grandas Sep 9, 2024
4e1a4c5
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 9, 2024
863ee72
updating branch
german-grandas Sep 9, 2024
4f02546
Fixing issue with database connection
german-grandas Sep 9, 2024
324abf7
Merge branch 'rag-langchain-chat-history' of github.com:GoogleCloudPl…
german-grandas Sep 9, 2024
88ee300
Updating Rag application test.
german-grandas Sep 9, 2024
e209b46
Merge branch 'rag-langchain-chat-history' of https://github.com/Googl…
german-grandas Sep 9, 2024
cf0a447
Adding exceptions to test
german-grandas Sep 9, 2024
bf2f990
Fixing bug on unit test
german-grandas Sep 9, 2024
74b6e9d
fixing unit test
german-grandas Sep 9, 2024
e94cab0
updating notebook to use the PostgresVectorStore instead of the custo…
german-grandas Sep 10, 2024
329417b
fixing issue with notebook
german-grandas Sep 10, 2024
f1bf05a
Fixing issue with missing environment varibles on notebook
german-grandas Sep 11, 2024
88fe07d
Refactoring example notebooks to handle new cloudsql vector store
german-grandas Sep 11, 2024
e38a101
Adding missing package to notebook
german-grandas Sep 11, 2024
799c8db
Creating a notebook for testing rag with a sample of the data
german-grandas Sep 12, 2024
14ff203
updating notebook to test rag
german-grandas Sep 12, 2024
a73f987
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 12, 2024
c01cff6
Reverting changes on files, updating database model on notebook
german-grandas Sep 12, 2024
0a04782
Fixing name with column on notebook query
german-grandas Sep 12, 2024
68a11f7
Merge branch 'main' of github.com:GoogleCloudPlatform/ai-on-gke into …
german-grandas Sep 13, 2024
d872679
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 16, 2024
e9a79ce
Merge branch 'rag-langchain-chat-history' of github.com:GoogleCloudPl…
german-grandas Sep 16, 2024
7fe461c
resolving conflicts
german-grandas Sep 17, 2024
c489d73
Delete applications/rag/example_notebooks/ingest_database.ipynb
german-grandas Sep 17, 2024
aa44fcb
updating Embedding model with missing column
german-grandas Sep 25, 2024
5882e28
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 25, 2024
e1b8e50
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 30, 2024
0ef245f
Updating packages, improving chain prompt
german-grandas Oct 2, 2024
8558a46
updating rag frontend sha
german-grandas Oct 2, 2024
44b3e72
updating column name
german-grandas Oct 8, 2024
ab06c07
updating max tokens lenght for inference service
german-grandas Oct 24, 2024
4209af8
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Oct 24, 2024
7 changes: 6 additions & 1 deletion applications/rag/frontend/container/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,9 @@ WORKDIR /workspace/frontend

RUN pip install -r requirements.txt

CMD ["python", "main.py"]
EXPOSE 8080

ENV FLASK_APP=/workspace/frontend/main.py
ENV PYTHONPATH=.
# Run the application with Gunicorn
CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:8080", "main:app"]
Expand Up @@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is required to make Python treat the subfolder as a package
24 changes: 24 additions & 0 deletions applications/rag/frontend/container/application/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from flask import Flask

def create_app():
    app = Flask(__name__, static_folder='static', template_folder='templates')
    app.jinja_env.trim_blocks = True
    app.jinja_env.lstrip_blocks = True
    app.config['SECRET_KEY'] = os.environ.get("SECRET_KEY")

    return app
17 changes: 17 additions & 0 deletions applications/rag/frontend/container/application/models/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vector_embeddings import VectorEmbeddings

__all__ = ["VectorEmbeddings"]
@@ -0,0 +1,31 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from sqlalchemy import Column, String, Text
from sqlalchemy.orm import mapped_column, declarative_base
from pgvector.sqlalchemy import Vector

Base = declarative_base()

VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "")


class VectorEmbeddings(Base):
    __tablename__ = VECTOR_EMBEDDINGS_TABLE_NAME

    id = Column(String(255), primary_key=True)
    text = Column(Text)
    text_embedding = mapped_column(Vector(384))
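The `Vector(384)` column is what pgvector's distance operators (e.g. `<=>` for cosine distance) rank against at retrieval time. A dependency-free sketch of that ranking, with toy 2-d vectors standing in for the 384-dimensional embeddings (row names and data are hypothetical):

```python
import math

def cosine_distance(a, b):
    # 1 - cosine similarity, matching the semantics of pgvector's <=> operator.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return 1.0 - dot / (norm_a * norm_b)

# Toy stand-ins for (id, text_embedding) rows of the VectorEmbeddings table.
rows = [
    ("doc1", [1.0, 0.0]),
    ("doc2", [0.0, 1.0]),
    ("doc3", [1.0, 1.0]),
]

query_embedding = [1.0, 0.2]

# "ORDER BY text_embedding <=> :query LIMIT 1", in miniature.
nearest = min(rows, key=lambda row: cosine_distance(row[1], query_embedding))
```

In production the database performs this ranking with an index; the sketch only shows the metric being optimized.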
@@ -0,0 +1,13 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,131 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

from application.utils import post_request

INFERENCE_ENDPOINT = os.environ.get("INFERENCE_ENDPOINT", "127.0.0.1:8081")

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class HuggingFaceCustomChatModel(LLM):
    """A custom chat model that calls a Hugging Face TGI API and returns the
    generated content based on the given message.

    Example:

        .. code-block:: python

            model = HuggingFaceCustomChatModel()
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])
    """

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at
                the first occurrence of any of the stop substrings. If stop
                tokens are not supported, consider raising NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually
                passed to the model provider API call.

        Returns:
            The model output as a string.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        api_endpoint = f"http://{INFERENCE_ENDPOINT}/generate"
        body = {"inputs": prompt}
        headers = {"Content-Type": "application/json"}
        generated_output = post_request(api_endpoint, body, headers)
        generated_text = generated_output.get("generated_text", "")

        return generated_text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the LLM on the given prompt.

        This method should be overridden by subclasses that support streaming.
        If not implemented, calls to stream fall back to the non-streaming
        version of the model and return the output as a single chunk.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at
                the first occurrence of any of these substrings.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually
                passed to the model provider API call.

        Returns:
            An iterator of GenerationChunks.
        """
        api_endpoint = f"http://{INFERENCE_ENDPOINT}/generate_stream"
        body = {"inputs": prompt}
        headers = {"Content-Type": "application/json"}
        logging.info("Calling external model")
        generated_output = post_request(api_endpoint, body, headers)
        generated_text = generated_output.get("generated_text", "")

        for char in generated_text:
            chunk = GenerationChunk(text=char)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

            yield chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per-token pricing for their model and monitor costs
            # for the given LLM.)
            "model_name": "HuggingFaceTGI",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model. Used for logging purposes only."""
        return "custom"
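The wrapper above relies on TGI's `/generate` contract: the request body is `{"inputs": prompt}` and the response carries a `generated_text` field. A framework-free sketch of that contract, with a stub standing in for `application.utils.post_request` (the stub and `call_model` are hypothetical names, not part of the PR):

```python
def fake_post_request(endpoint, body, headers):
    # Stub standing in for application.utils.post_request: a real call would
    # POST the JSON body to the TGI endpoint and decode the JSON response.
    assert "inputs" in body
    return {"generated_text": f"echo: {body['inputs']}"}

def call_model(prompt, post=fake_post_request):
    # Mirrors the body of HuggingFaceCustomChatModel._call.
    body = {"inputs": prompt}
    headers = {"Content-Type": "application/json"}
    generated_output = post("http://127.0.0.1:8081/generate", body, headers)
    return generated_output.get("generated_text", "")
```

Swapping `post=` for a real HTTP helper is all that separates this sketch from the `_call` implementation in the diff.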
@@ -0,0 +1,119 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel, RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_google_cloud_sql_pg import PostgresChatMessageHistory


from application.utils import (
    create_sync_postgres_engine
)
from application.rag_langchain.huggingface_inference_model import (
    HuggingFaceCustomChatModel,
)
from application.vector_storages import CloudSQLVectorStore

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store")

QUESTION = "input"
HISTORY = "chat_history"
CONTEXT = "context"

# Transformer used to convert text chunks to vector embeddings
SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small"

template_str = """Answer the question given by the user in no more than 2 sentences.
Use the provided context to improve upon your previous answers. Stick to the facts and be brief. Avoid conversational format.
\n\n
Context: {context}
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template_str),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

engine = create_sync_postgres_engine()


def get_chat_history(session_id: str) -> PostgresChatMessageHistory:
    history = PostgresChatMessageHistory.create_sync(
        engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME
    )

    logging.info(
        f"Retrieving history for session {session_id} with {len(history.messages)} messages"
    )
    return history


def clear_chat_history(session_id: str):
    history = PostgresChatMessageHistory.create_sync(
        engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME
    )
    history.clear()


def create_chain() -> RunnableWithMessageHistory:
    try:
        model = HuggingFaceCustomChatModel()

        langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL)
        vector_store = CloudSQLVectorStore(langchain_embed, engine)

        retriever = vector_store.as_retriever()

        setup_and_retrieval = RunnableParallel(
            {
                "context": retriever,
                QUESTION: RunnableLambda(lambda d: d[QUESTION]),
                HISTORY: RunnableLambda(lambda d: d[HISTORY]),
            }
        )

        chain = setup_and_retrieval | prompt | model
        chain_with_history = RunnableWithMessageHistory(
            chain,
            get_chat_history,
            input_messages_key=QUESTION,
            history_messages_key=HISTORY,
            output_messages_key="output",
        )
        return chain_with_history
    except Exception as e:
        logging.error(e)
        raise e


def take_chat_turn(
    chain: RunnableWithMessageHistory, session_id: str, query_text: str
) -> str:
    try:
        config = {"configurable": {"session_id": session_id}}
        result = chain.invoke({"input": query_text}, config=config)
        return result
    except Exception as e:
        logging.error(e)
        raise e
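`RunnableWithMessageHistory`'s role here is to route each call's `session_id` to the right message store before and after invoking the chain, exactly what `get_chat_history` does against Cloud SQL. A framework-free sketch of that session-keyed pattern (all names are hypothetical, with an in-memory dict standing in for `PostgresChatMessageHistory`):

```python
# One message list per session_id, standing in for the message_store table.
histories = {}

def get_history(session_id):
    # Analogue of get_chat_history: create the store on first use.
    return histories.setdefault(session_id, [])

def take_turn(session_id, question, answer_fn):
    # Analogue of take_chat_turn: read history, answer, then append the turn.
    history = get_history(session_id)
    answer = answer_fn(question, history)
    history.append(("human", question))
    history.append(("ai", answer))
    return answer

# A trivial "model" that just counts prior turns in the history it is given.
reply = lambda question, history: f"answer #{len(history) // 2 + 1}"

take_turn("s1", "first?", reply)
take_turn("s1", "second?", reply)
```

Because the store is keyed by `session_id`, concurrent users never see each other's turns; the PR gets the same isolation from per-session rows in Postgres.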
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is required to make Python treat the subfolder as a package
# This file is required to make Python treat the subfolder as a package