Adding chat history to RAG app and refactor to better utilize LangChain #648

Open · wants to merge 61 commits into base: main

Commits (61):
dda40b9
Also introduced a basic session history mechanism in the browser to k…
alpha-amundson May 3, 2024
5cc85b9
tflint formatting fixes
alpha-amundson May 3, 2024
6898666
TPU Provisioner: JobSet related fixes (#645)
nstogner May 6, 2024
1d6c052
Updated image to use code in this branch
alpha-amundson May 6, 2024
981e777
making tflint happy
alpha-amundson May 6, 2024
d1d1211
Working on improvements for rag application (#731)
german-grandas Jul 12, 2024
5a16b54
Rag langchain chat history (#747)
german-grandas Jul 22, 2024
9e416c8
Rag langchain chat history (#755)
german-grandas Jul 29, 2024
e750d12
Fixing issues and updating chat history on frontend
german-grandas Jul 31, 2024
a000c46
Fixing files on working tree
german-grandas Jul 31, 2024
0d853ea
Ignoring test rag, to review how the rag application is working
german-grandas Aug 1, 2024
386c437
ignoring unit test to review cloud build process
german-grandas Aug 1, 2024
be1839d
refactoring cloud sql connection helper
german-grandas Aug 6, 2024
7f081ff
Merge branch 'main' into rag-langchain-chat-history
german-grandas Aug 6, 2024
35f67e4
Change TPU Metrics Source for Autoscaling (#770)
Bslabe123 Aug 8, 2024
0022053
Refactor: move workload identity service account out of kuberay-opera…
genlu2011 Aug 15, 2024
48f655b
updating branch
german-grandas Aug 20, 2024
a9895d6
fixing conflicts with remote branch
german-grandas Aug 20, 2024
cd95c98
fixing conflicts with remote branch
german-grandas Aug 20, 2024
bc8d745
fixing conflicts with remote branch
german-grandas Aug 20, 2024
e9beeef
fixing conflicts applying rebase
german-grandas Aug 20, 2024
eb9ab02
Updating files based on reviewer comments
german-grandas Aug 20, 2024
dff8d94
reverting change on cloudbuild.yaml file
german-grandas Aug 20, 2024
138920f
Reverting comment of line
german-grandas Aug 26, 2024
c8e5d35
Updating length of variable
german-grandas Aug 26, 2024
c437736
updating branch with main
german-grandas Aug 30, 2024
8b4f55d
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 2, 2024
4261818
Updating rag frontend image.
german-grandas Sep 4, 2024
a8258f1
updating rag frontend images with the latest changes
german-grandas Sep 9, 2024
4e1a4c5
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 9, 2024
863ee72
updating branch
german-grandas Sep 9, 2024
4f02546
Fixing issue with database connection
german-grandas Sep 9, 2024
324abf7
Merge branch 'rag-langchain-chat-history' of github.com:GoogleCloudPl…
german-grandas Sep 9, 2024
88ee300
Updating Rag application test.
german-grandas Sep 9, 2024
e209b46
Merge branch 'rag-langchain-chat-history' of https://github.com/Googl…
german-grandas Sep 9, 2024
cf0a447
Adding exceptions to test
german-grandas Sep 9, 2024
bf2f990
Fixing bug on unit test
german-grandas Sep 9, 2024
74b6e9d
fixing unit test
german-grandas Sep 9, 2024
e94cab0
updating notebook to use the PostgresVectorStore instead of the custo…
german-grandas Sep 10, 2024
329417b
fixing issue with notebook
german-grandas Sep 10, 2024
f1bf05a
Fixing issue with missing environment variables on notebook
german-grandas Sep 11, 2024
88fe07d
Refactoring example notebooks to handle new cloudsql vector store
german-grandas Sep 11, 2024
e38a101
Adding missing package to notebook
german-grandas Sep 11, 2024
799c8db
Creating a notebook for testing rag with a sample of the data
german-grandas Sep 12, 2024
14ff203
updating notebook to test rag
german-grandas Sep 12, 2024
a73f987
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 12, 2024
c01cff6
Reverting changes on files, updating database model on notebook
german-grandas Sep 12, 2024
0a04782
Fixing name with column on notebook query
german-grandas Sep 12, 2024
68a11f7
Merge branch 'main' of github.com:GoogleCloudPlatform/ai-on-gke into …
german-grandas Sep 13, 2024
d872679
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 16, 2024
e9a79ce
Merge branch 'rag-langchain-chat-history' of github.com:GoogleCloudPl…
german-grandas Sep 16, 2024
7fe461c
resolving conflicts
german-grandas Sep 17, 2024
c489d73
Delete applications/rag/example_notebooks/ingest_database.ipynb
german-grandas Sep 17, 2024
aa44fcb
updating Embedding model with missing column
german-grandas Sep 25, 2024
5882e28
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 25, 2024
e1b8e50
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Sep 30, 2024
0ef245f
Updating packages, improving chain prompt
german-grandas Oct 2, 2024
8558a46
updating rag frontend sha
german-grandas Oct 2, 2024
44b3e72
updating column name
german-grandas Oct 8, 2024
ab06c07
updating max tokens length for inference service
german-grandas Oct 24, 2024
4209af8
Merge branch 'main' of https://github.com/GoogleCloudPlatform/ai-on-g…
german-grandas Oct 24, 2024
136 changes: 97 additions & 39 deletions applications/rag/frontend/container/cloud_sql/cloud_sql.py
@@ -1,18 +1,43 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import (List, Optional, Iterable, Any)

from google.cloud.sql.connector import Connector, IPTypes
import pymysql
import sqlalchemy
from sentence_transformers import SentenceTransformer
from langchain_core.vectorstores import VectorStore
import pg8000
from langchain_core.embeddings import Embeddings
from langchain_core.documents import Document
from sqlalchemy.engine import Engine
from langchain_google_cloud_sql_pg import PostgresEngine

db = None
VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('TABLE_NAME', '') # CloudSQL table name for vector embeddings
# TODO make this configurable from tf
CHAT_HISTORY_TABLE_NAME = "message_store" # CloudSQL table name where chat history is stored

TABLE_NAME = os.environ.get('TABLE_NAME', '') # CloudSQL table name
INSTANCE_CONNECTION_NAME = os.environ.get('INSTANCE_CONNECTION_NAME', '')
SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings
DB_NAME = "pgvector-database"

PROJECT_ID = os.environ.get('PROJECT_ID', '')
REGION = os.environ.get('REGION', '')
INSTANCE = os.environ.get('INSTANCE', '')

db_username_file = open("/etc/secret-volume/username", "r")
DB_USER = db_username_file.read()
db_username_file.close()
@@ -23,14 +48,6 @@

transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)

def init_db() -> sqlalchemy.engine.base.Engine:
"""Initiates connection to database and its structure."""
global db
connector = Connector()
if db is None:
db = init_connection_pool(connector)


# helper function to return SQLAlchemy connection pool
def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine:
# function used to generate database connection
@@ -52,32 +69,73 @@ def getconn() -> pymysql.connections.Connection:
)
return pool

def fetchContext(query_text):
with db.connect() as conn:
try:
results = conn.execute(sqlalchemy.text("SELECT * FROM " + TABLE_NAME)).fetchall()
print(f"query database results:")
for row in results:
print(row)

# chunkify query & fetch matches
query_emb = transformer.encode(query_text).tolist()
query_request = "SELECT id, text, text_embedding, 1 - ('[" + ",".join(map(str, query_emb)) + "]' <=> text_embedding) AS cosine_similarity FROM " + TABLE_NAME + " ORDER BY cosine_similarity DESC LIMIT 5;"
query_results = conn.execute(sqlalchemy.text(query_request)).fetchall()
conn.commit()

if not query_results:
message = f"Table {TABLE_NAME} returned empty result"
raise ValueError(message)
for row in query_results:
print(row)
except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err:
message = f"Table {TABLE_NAME} does not exist: {err}"
raise sqlalchemy.exc.DataError(message)
except sqlalchemy.exc.DatabaseError as err:
message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}"
raise sqlalchemy.exc.DataError(message)
except Exception as err:
raise Exception(f"General error: {err}")

return query_results[0][1]
def create_sync_postgres_engine():
engine = PostgresEngine.from_instance(
project_id=PROJECT_ID,
region=REGION,
instance=INSTANCE,
database=DB_NAME,
user=DB_USER,
password=DB_PASS,
ip_type=IPTypes.PRIVATE
)
engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME)
return engine

#TODO replace this with the Cloud SQL vector store for langchain,
# once the notebook also uses it (and creates the correct schema)
class CustomVectorStore(VectorStore):
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
):
raise NotImplementedError

def __init__(self, embedding: Embeddings, engine: Engine):
self.embedding = embedding
self.engine = engine

@property
def embeddings(self) -> Embeddings:
return self.embedding


# TODO implement
def add_texts(self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any) -> List[str]:
raise NotImplementedError

#TODO implement similarity search with cosine similarity threshold

def similarity_search(self, query: dict, k: int = 4, **kwargs: Any) -> List[Document]:
with self.engine.connect() as conn:
try:
q = query["question"]
# embed query & fetch matches
query_emb = self.embedding.embed_query(q)
emb_str = ",".join(map(str, query_emb))
query_request = f"""SELECT id, text, 1 - ('[{emb_str}]' <=> text_embedding) AS cosine_similarity
FROM {VECTOR_EMBEDDINGS_TABLE_NAME}
ORDER BY cosine_similarity DESC LIMIT {k};"""
query_results = conn.execute(sqlalchemy.text(query_request)).fetchall()
print(f"GOT {len(query_results)} results")
conn.commit()

if not query_results:
message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} returned empty result"
raise ValueError(message)
        # NOTE: `except A or B` only catches A; multiple exception types need a tuple
        except (sqlalchemy.exc.DBAPIError, pg8000.exceptions.DatabaseError) as err:
message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}"
raise sqlalchemy.exc.DataError(message)
except sqlalchemy.exc.DatabaseError as err:
message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}"
raise sqlalchemy.exc.DataError(message)
except Exception as err:
raise Exception(f"General error: {err}")

#convert query results into List[Document]
texts = [result[1] for result in query_results]
return [Document(page_content=text) for text in texts]
96 changes: 43 additions & 53 deletions applications/rag/frontend/container/main.py
Original file line number Diff line number Diff line change
@@ -16,60 +16,33 @@
import logging as log
import google.cloud.logging as logging
import traceback
import uuid

from flask import Flask, render_template, request, jsonify
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from flask import Flask, render_template, request, jsonify, session
from rai import dlp_filter # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp
from rai import nlp_filter # https://cloud.google.com/natural-language/docs/moderating-text
from cloud_sql import cloud_sql
import sqlalchemy
from rag_langchain.rag_chain import clear_chat_history, create_chain, take_chat_turn, engine
from datetime import datetime, timedelta, timezone

# Setup logging
logging_client = logging.Client()
logging_client.setup_logging()

# TODO: refactor the app startup code into a flask app factory
# TODO: include the chat history cache in the app lifecycle and ensure that it's threadsafe.
app = Flask(__name__, static_folder='static')
app.jinja_env.trim_blocks = True
app.jinja_env.lstrip_blocks = True
app.config['ENGINE'] = engine # force the connection pool to warm up eagerly

# initialize parameters
INFERENCE_ENDPOINT=os.environ.get('INFERENCE_ENDPOINT', '127.0.0.1:8081')

llm = HuggingFaceTextGenInference(
inference_server_url=f'http://{INFERENCE_ENDPOINT}/',
max_new_tokens=512,
top_k=10,
top_p=0.95,
typical_p=0.95,
temperature=0.01,
repetition_penalty=1.03,
)

prompt_template = """
### [INST]
Instruction: Always assist with care, respect, and truth. Respond with utmost utility yet securely.
Avoid harmful, unethical, prejudiced, or negative content.
Ensure replies promote fairness and positivity.
Here is context to help:

{context}

### QUESTION:
{user_prompt}

[/INST]
"""

# Create prompt from prompt template
prompt = PromptTemplate(
input_variables=["context", "user_prompt"],
template=prompt_template,
)
SESSION_TIMEOUT_MINUTES = 30
#TODO replace with real secret
SECRET_KEY = "TODO replace this with an actual secret that is stored and managed by kubernetes and added to the terraform configuration."
app.config['SECRET_KEY'] = SECRET_KEY

# Create llm chain
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = create_chain()

@app.route('/get_nlp_status', methods=['GET'])
def get_nlp_status():
@@ -80,6 +53,7 @@
def get_dlp_status():
dlp_enabled = dlp_filter.is_dlp_api_enabled()
return jsonify({"dlpEnabled": dlp_enabled})

@app.route('/get_inspect_templates')
def get_inspect_templates():
return jsonify(dlp_filter.list_inspect_templates_from_parent())
@@ -89,15 +63,36 @@
return jsonify(dlp_filter.list_deidentify_templates_from_parent())

@app.before_request
def init_db():
cloud_sql.init_db()
def check_new_session():
if 'session_id' not in session:
# instantiate a new session using a generated UUID
session_id = str(uuid.uuid4())
session['session_id'] = session_id

@app.before_request
def check_inactivity():
# Inactivity cleanup
if 'last_activity' in session:
time_elapsed = datetime.now(timezone.utc) - session['last_activity']

if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES):
print("Session inactive: Cleaning up resources...")
session_id = session['session_id']
# TODO: implement garbage collection process for idle sessions that have timed out
clear_chat_history(session_id)
session.clear()

# Always update the 'last_activity' data
session['last_activity'] = datetime.now(timezone.utc)

@app.route('/')
def index():
return render_template('index.html')

@app.route('/prompt', methods=['POST'])
def handlePrompt():
# TODO on page refresh, load chat history into browser.
session['last_activity'] = datetime.now(timezone.utc)
data = request.get_json()
warnings = []

@@ -107,19 +102,12 @@
user_prompt = data['prompt']
log.info(f"handle user prompt: {user_prompt}")

context = ""
try:
context = cloud_sql.fetchContext(user_prompt)
except Exception as err:
error_traceback = traceback.format_exc()
log.warn(f"Error: {err}\nTraceback:\n{error_traceback}")
warnings.append(f"Error: {err}\nTraceback:\n{error_traceback}")
response = {}
result = take_chat_turn(llm_chain, session['session_id'], user_prompt)
response['text'] = result

try:
response = llm_chain.invoke({
"context": context,
"user_prompt": user_prompt
})
# TODO: enable filtering in chain
if 'nlpFilterLevel' in data:
if nlp_filter.is_content_inappropriate(response['text'], data['nlpFilterLevel']):
response['text'] = 'The response is deemed inappropriate for display.'
@@ -149,4 +137,6 @@


if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
# TODO using gunicorn to start the server results in the first request being really slow.
# Sometimes, the worker thread has to restart due to an unknown error.
app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
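The `check_inactivity` hook added to main.py expires a session once more than `SESSION_TIMEOUT_MINUTES` have passed since its last activity. A hedged, standalone sketch of that check (Flask's `session` proxy is replaced here by a plain dict so the logic runs on its own; the helper name `is_session_expired` is hypothetical, not from the PR):

```python
from datetime import datetime, timedelta, timezone

SESSION_TIMEOUT_MINUTES = 30  # matches the constant introduced in main.py

def is_session_expired(session: dict, now: datetime) -> bool:
    # Mirrors the diff: compare elapsed time since 'last_activity'
    # against the timeout; a session with no recorded activity is new.
    last = session.get('last_activity')
    if last is None:
        return False
    return now - last > timedelta(minutes=SESSION_TIMEOUT_MINUTES)

now = datetime.now(timezone.utc)
fresh = {'last_activity': now - timedelta(minutes=5)}
stale = {'last_activity': now - timedelta(minutes=45)}
print(is_session_expired(fresh, now), is_session_expired(stale, now))  # → False True
```

In the PR itself, an expired session additionally triggers `clear_chat_history(session_id)` and `session.clear()` before `last_activity` is refreshed on the next request.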