From dda40b9f64c9d219362608c4081b929359c583fc Mon Sep 17 00:00:00 2001 From: Julie Amundson Date: Fri, 3 May 2024 16:05:57 -0700 Subject: [PATCH 01/46] Also introduced a basic session history mechanism in the browser to keep track of and retrieve chat history from Cloud SQL. main.py - removed old langchain and logic to retrieve context. replaced with new chain from rag_chain.py. Introduced browser session with 30 minute ttl. Storing session ID in the session cookie. Session ID is then used to retrieve chat history. Chat history is cleared when timeout is reached. cloud_sql.py - now includes a method to create a PostgresEngine for storing and retrieving history, plus a CustomVectorStore to perform the query embedding and vector search. Old code paths no longer needed were removed. rag_chain.py - contains helper method create_chain to create, update and delete the end-to-end RAG chain with history. various tf files: increased max input and total tokens on HF TGI for mistral. threadded through some parameters needed to instantiate the PostgresEngine. requirements.txt - added some dependencies needed for langchain --- .../frontend/container/cloud_sql/cloud_sql.py | 136 +++++++++++++----- applications/rag/frontend/container/main.py | 96 ++++++------- .../container/rag_langchain/rag_chain.py | 136 ++++++++++++++++++ .../rag/frontend/container/requirements.txt | 2 + applications/rag/frontend/main.tf | 19 ++- applications/rag/main.tf | 3 +- tutorials-and-examples/hf-tgi/main.tf | 16 ++- 7 files changed, 311 insertions(+), 97 deletions(-) create mode 100644 applications/rag/frontend/container/rag_langchain/rag_chain.py diff --git a/applications/rag/frontend/container/cloud_sql/cloud_sql.py b/applications/rag/frontend/container/cloud_sql/cloud_sql.py index 97bc5fe0c..4a9186b38 100644 --- a/applications/rag/frontend/container/cloud_sql/cloud_sql.py +++ b/applications/rag/frontend/container/cloud_sql/cloud_sql.py @@ -1,18 +1,43 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os +from typing import (List, Optional, Iterable, Any) from google.cloud.sql.connector import Connector, IPTypes import pymysql import sqlalchemy from sentence_transformers import SentenceTransformer +from langchain_core.vectorstores import VectorStore import pg8000 +from langchain_core.embeddings import Embeddings +from langchain_core.documents import Document +from sqlalchemy.engine import Engine +from langchain_google_cloud_sql_pg import PostgresEngine -db = None +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('TABLE_NAME', '') # CloudSQL table name for vector embeddings +# TODO make this configurable from tf +CHAT_HISTORY_TABLE_NAME = "message_store" # CloudSQL table name where chat history is stored -TABLE_NAME = os.environ.get('TABLE_NAME', '') # CloudSQL table name INSTANCE_CONNECTION_NAME = os.environ.get('INSTANCE_CONNECTION_NAME', '') SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings DB_NAME = "pgvector-database" +PROJECT_ID = os.environ.get('PROJECT_ID', '') +REGION = os.environ.get('REGION', '') +INSTANCE = os.environ.get('INSTANCE', '') + db_username_file = open("/etc/secret-volume/username", "r") DB_USER = db_username_file.read() db_username_file.close() @@ -23,14 +48,6 @@ transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL) -def init_db() -> sqlalchemy.engine.base.Engine: - """Initiates connection to database and its structure.""" - global db - connector = Connector() - if db is None: - db = init_connection_pool(connector) - - # helper function to return SQLAlchemy connection pool def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: # function used to generate database connection @@ -52,32 +69,73 @@ def getconn() -> pymysql.connections.Connection: ) return pool -def fetchContext(query_text): - with db.connect() as conn: - try: - results = conn.execute(sqlalchemy.text("SELECT * FROM " + TABLE_NAME)).fetchall() - print(f"query database results:") - for row in results: - print(row) - - # chunkify query & fetch matches - query_emb = transformer.encode(query_text).tolist() - query_request = "SELECT id, text, text_embedding, 1 - ('[" + ",".join(map(str, query_emb)) + "]' <=> text_embedding) AS cosine_similarity FROM " + TABLE_NAME + " ORDER BY cosine_similarity DESC LIMIT 5;" - query_results = conn.execute(sqlalchemy.text(query_request)).fetchall() - conn.commit() - - if not query_results: - message = f"Table {TABLE_NAME} returned empty result" - raise ValueError(message) - for row in query_results: - print(row) - except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: - message = f"Table {TABLE_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except sqlalchemy.exc.DatabaseError as err: - message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except Exception as err: - raise Exception(f"General error: {err}") - - return query_results[0][1] \ No newline at end of file +def create_sync_postgres_engine(): + engine = PostgresEngine.from_instance( + project_id=PROJECT_ID, + region=REGION, + instance=INSTANCE, + database=DB_NAME, + user=DB_USER, + password=DB_PASS, + ip_type=IPTypes.PRIVATE + ) + engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) + return engine + +#TODO replace this with the Cloud SQL vector store for langchain, +# once the notebook also uses it (and creates the correct schema) +class CustomVectorStore(VectorStore): + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ): + raise NotImplementedError + + def __init__(self, embedding: Embeddings, engine: Engine): + self.embedding = embedding + self.engine = engine + + @property + def embeddings(self) -> Embeddings: + return self.embedding + + + # TODO implement + def add_texts(self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any) -> List[str]: + raise NotImplementedError + + #TODO implement similarity search with cosine similarity threshold + + def similarity_search(self, query: dict, k: int = 4, **kwargs: Any) -> List[Document]: + with self.engine.connect() as conn: + try: + q = query["question"] + # embed query & fetch matches + query_emb = self.embedding.embed_query(q) + emb_str = ",".join(map(str, query_emb)) + query_request = f"""SELECT id, text, 1 - ('[{emb_str}]' <=> text_embedding) AS cosine_similarity + FROM {VECTOR_EMBEDDINGS_TABLE_NAME} + ORDER BY cosine_similarity DESC LIMIT {k};""" + query_results = conn.execute(sqlalchemy.text(query_request)).fetchall() + print(f"GOT {len(query_results)} results") + conn.commit() + + if not query_results: + message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} returned empty result" + raise ValueError(message) + except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: + message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except sqlalchemy.exc.DatabaseError as err: + message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except Exception as err: + raise Exception(f"General error: {err}") + + #convert query results into List[Document] + texts = [result[1] for result in query_results] + return [Document(page_content=text) for text in texts] \ No newline at end of file diff --git a/applications/rag/frontend/container/main.py b/applications/rag/frontend/container/main.py index 33c62d2f5..77a510d26 100644 --- a/applications/rag/frontend/container/main.py +++ b/applications/rag/frontend/container/main.py @@ -16,60 +16,33 @@ import logging as log import google.cloud.logging as logging import traceback +import uuid -from flask import Flask, render_template, request, jsonify -from langchain.chains import LLMChain -from langchain.llms import HuggingFaceTextGenInference -from langchain.prompts import PromptTemplate +from flask import Flask, render_template, request, jsonify, session from rai import dlp_filter # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp from rai import nlp_filter # https://cloud.google.com/natural-language/docs/moderating-text from cloud_sql import cloud_sql -import sqlalchemy +from rag_langchain.rag_chain import clear_chat_history, create_chain, take_chat_turn, engine +from datetime import datetime, timedelta, timezone # Setup logging logging_client = logging.Client() logging_client.setup_logging() +# TODO: refactor the app startup code into a flask app factory +# TODO: include the chat history cache in the app lifecycle and ensure that it's threadsafe. app = Flask(__name__, static_folder='static') app.jinja_env.trim_blocks = True app.jinja_env.lstrip_blocks = True +app.config['ENGINE'] = engine # force the connection pool to warm up eagerly -# initialize parameters -INFERENCE_ENDPOINT=os.environ.get('INFERENCE_ENDPOINT', '127.0.0.1:8081') - -llm = HuggingFaceTextGenInference( - inference_server_url=f'http://{INFERENCE_ENDPOINT}/', - max_new_tokens=512, - top_k=10, - top_p=0.95, - typical_p=0.95, - temperature=0.01, - repetition_penalty=1.03, -) - -prompt_template = """ -### [INST] -Instruction: Always assist with care, respect, and truth. Respond with utmost utility yet securely. -Avoid harmful, unethical, prejudiced, or negative content. -Ensure replies promote fairness and positivity. -Here is context to help: - -{context} - -### QUESTION: -{user_prompt} - -[/INST] - """ - -# Create prompt from prompt template -prompt = PromptTemplate( - input_variables=["context", "user_prompt"], - template=prompt_template, -) +SESSION_TIMEOUT_MINUTES = 30 +#TODO replace with real secret +SECRET_KEY = "TODO replace this with an actual secret that is stored and managed by kubernetes and added to the terraform configuration." +app.config['SECRET_KEY'] = SECRET_KEY # Create llm chain -llm_chain = LLMChain(llm=llm, prompt=prompt) +llm_chain = create_chain() @app.route('/get_nlp_status', methods=['GET']) def get_nlp_status(): @@ -80,6 +53,7 @@ def get_nlp_status(): def get_dlp_status(): dlp_enabled = dlp_filter.is_dlp_api_enabled() return jsonify({"dlpEnabled": dlp_enabled}) + @app.route('/get_inspect_templates') def get_inspect_templates(): return jsonify(dlp_filter.list_inspect_templates_from_parent()) @@ -89,8 +63,27 @@ def get_deidentify_templates(): return jsonify(dlp_filter.list_deidentify_templates_from_parent()) @app.before_request -def init_db(): - cloud_sql.init_db() +def check_new_session(): + if 'session_id' not in session: + # instantiate a new session using a generated UUID + session_id = str(uuid.uuid4()) + session['session_id'] = session_id + +@app.before_request +def check_inactivity(): + # Inactivity cleanup + if 'last_activity' in session: + time_elapsed = datetime.now(timezone.utc) - session['last_activity'] + + if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): + print("Session inactive: Cleaning up resources...") + session_id = session['session_id'] + # TODO: implement garbage collection process for idle sessions that have timed out + clear_chat_history(session_id) + session.clear() + + # Always update the 'last_activity' data + session['last_activity'] = datetime.now(timezone.utc) @app.route('/') def index(): @@ -98,6 +91,8 @@ def index(): @app.route('/prompt', methods=['POST']) def handlePrompt(): + # TODO on page refresh, load chat history into browser. + session['last_activity'] = datetime.now(timezone.utc) data = request.get_json() warnings = [] @@ -107,19 +102,12 @@ def handlePrompt(): user_prompt = data['prompt'] log.info(f"handle user prompt: {user_prompt}") - context = "" try: - context = cloud_sql.fetchContext(user_prompt) - except Exception as err: - error_traceback = traceback.format_exc() - log.warn(f"Error: {err}\nTraceback:\n{error_traceback}") - warnings.append(f"Error: {err}\nTraceback:\n{error_traceback}") + response = {} + result = take_chat_turn(llm_chain, session['session_id'], user_prompt) + response['text'] = result - try: - response = llm_chain.invoke({ - "context": context, - "user_prompt": user_prompt - }) + # TODO: enable filtering in chain if 'nlpFilterLevel' in data: if nlp_filter.is_content_inappropriate(response['text'], data['nlpFilterLevel']): response['text'] = 'The response is deemed inappropriate for display.' @@ -149,4 +137,6 @@ def handlePrompt(): if __name__ == '__main__': - app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080))) + # TODO using gunicorn to start the server results in the first request being really slow. + # Sometimes, the worker thread has to restart due to an unknown error. + app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080))) \ No newline at end of file diff --git a/applications/rag/frontend/container/rag_langchain/rag_chain.py b/applications/rag/frontend/container/rag_langchain/rag_chain.py new file mode 100644 index 000000000..807c875c8 --- /dev/null +++ b/applications/rag/frontend/container/rag_langchain/rag_chain.py @@ -0,0 +1,136 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import (Dict) +from cloud_sql.cloud_sql import CHAT_HISTORY_TABLE_NAME, init_connection_pool, create_sync_postgres_engine, CustomVectorStore +from google.cloud.sql.connector import Connector +from langchain_community.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnableParallel, RunnableLambda +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_google_cloud_sql_pg import PostgresChatMessageHistory + +QUESTION = "question" +HISTORY = "history" +CONTEXT = "context" + +INFERENCE_ENDPOINT=os.environ.get('INFERENCE_ENDPOINT', '127.0.0.1:8081') +SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings + + +# TODO use a chat model instead of an LLM in the chain. Convert the prompt to a chat prompt template +# prompt = ChatPromptTemplate.from_messages( +# [ +# ("system", """You help everyone by answering questions, and improve your answers from previous answers in history. +# You stick to the facts by basing your answers off of the context provided:"""), +# MessagesPlaceholder(variable_name="history"), +# MessagesPlaceholder(variable_name="context"), +# ("human", "{question}"), +# ] +# ) +template = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. +Improve upon your previous answers using History, a list of messages. +Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. +Stick to the facts by basing your answers off of the Context provided. +Be brief in answering. +History: {""" + HISTORY + "}\n\nContext: {" + CONTEXT + "}\n\nQuestion: {" + QUESTION + "}\n" + +prompt = PromptTemplate(template=template, input_variables=[HISTORY, CONTEXT, QUESTION]) + +engine = create_sync_postgres_engine() +# TODO: Dict is not safe for multiprocessing. Introduce a cache using Flask-caching or libcache +# The in-memory SimpleCache implementations for each of these libraries is not safe either. +# Consider redis or memcached (e.g., Memorystore) +# chat_history_map: Dict[str, PostgresChatMessageHistory] = {} + +def get_chat_history(session_id: str) -> PostgresChatMessageHistory: + history = PostgresChatMessageHistory.create_sync( + engine, + session_id=session_id, + table_name = CHAT_HISTORY_TABLE_NAME + ) + + print(f"Retrieving history for session {session_id} with {len(history.messages)}") + return history + +def clear_chat_history(session_id: str): + history = PostgresChatMessageHistory.create_sync( + engine, + session_id=session_id, + table_name = CHAT_HISTORY_TABLE_NAME + ) + history.clear() + + +#TODO: limit number of tokens in prompt to MAX_INPUT_LENGTH +# (as specified in hugging face TGI input parameter) + +def create_chain() -> RunnableWithMessageHistory: + # TODO HuggingFaceTextGenInference class is deprecated. + # The warning is: + # The class `langchain_community.llms.huggingface_text_gen_inference.HuggingFaceTextGenInference` + # was deprecated in langchain-community 0.0.21 and will be removed in 0.2.0. Use HuggingFaceEndpoint instead + # The replacement is HuggingFace Endoint, which requires a huggingface + # hub API token. Either need to add the token to the environment, or need to find a method to call TGI + # without the token. + # Example usage of HuggingFaceEndpoint: + # llm = HuggingFaceEndpoint( + # endpoint_url=f'http://{INFERENCE_ENDPOINT}/', + # max_new_tokens=512, + # top_k=10, + # top_p=0.95, + # typical_p=0.95, + # temperature=0.01, + # repetition_penalty=1.03, + # huggingfacehub_api_token="my-api-key" + # ) + # TODO: Give guidance on what these parameters should be and describe why these values were chosen. + model = HuggingFaceTextGenInference( + inference_server_url=f'http://{INFERENCE_ENDPOINT}/', + max_new_tokens=512, + top_k=10, + top_p=0.95, + typical_p=0.95, + temperature=0.01, + repetition_penalty=1.03, + ) + + langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) + vector_store = CustomVectorStore(langchain_embed, init_connection_pool(Connector())) + retriever = vector_store.as_retriever() + + setup_and_retrieval = RunnableParallel( + { + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]) + } + ) + chain = setup_and_retrieval | prompt | model + chain_with_history = RunnableWithMessageHistory( + chain, + get_chat_history, + input_messages_key=QUESTION, + history_messages_key=HISTORY, + output_messages_key="output" + ) + return chain_with_history + +def take_chat_turn(chain: RunnableWithMessageHistory, session_id: str, query_text: str) -> str: + #TODO limit the number of history messages + config = {"configurable": {"session_id": session_id}} + result = chain.invoke({"question": query_text}, config) + return str(result) \ No newline at end of file diff --git a/applications/rag/frontend/container/requirements.txt b/applications/rag/frontend/container/requirements.txt index 3f1ecd0c4..8b4a16779 100644 --- a/applications/rag/frontend/container/requirements.txt +++ b/applications/rag/frontend/container/requirements.txt @@ -32,3 +32,5 @@ google-cloud-logging==3.9.0 google-api-python-client==2.114.0 pymysql==1.1.0 cloud-sql-python-connector[pg8000]==1.7.0 +langchain-google-cloud-sql-pg==0.4.0 +langchain-community==0.0.31 \ No newline at end of file diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 732fe8a54..1b1fec2d6 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -110,6 +110,8 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { service_account_name = var.google_service_account container { image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" + # Built from local code. Revert before submitting. + # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:108bb16ee2278255c80524fce125ef349c494cb5bc4ca77dbde5048b8f9448c1" name = "rag-frontend" port { @@ -123,8 +125,19 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { } env { - name = "PROJECT_ID" - value = "projects/${var.project_id}" + name = "PROJECT_ID" + #value = "projects/${var.project_id}" + value = var.project_id + } + + env { + name = "REGION" + value = var.region + } + + env { + name = "INSTANCE" + value = var.cloudsql_instance } env { @@ -190,4 +203,4 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { } } } -} +} \ No newline at end of file diff --git a/applications/rag/main.tf b/applications/rag/main.tf index 1cf146df4..6ff27fe8f 100644 --- a/applications/rag/main.tf +++ b/applications/rag/main.tf @@ -284,6 +284,7 @@ module "frontend" { source = "./frontend" providers = { helm = helm.rag, kubernetes = kubernetes.rag } project_id = var.project_id + region = var.cluster_location create_service_account = var.create_rag_service_account google_service_account = local.rag_service_account namespace = local.kubernetes_namespace @@ -309,4 +310,4 @@ module "frontend" { domain = var.frontend_domain members_allowlist = var.frontend_members_allowlist != "" ? split(",", var.frontend_members_allowlist) : [] depends_on = [module.namespace] -} +} \ No newline at end of file diff --git a/tutorials-and-examples/hf-tgi/main.tf b/tutorials-and-examples/hf-tgi/main.tf index 557c5afd6..064e6f51d 100644 --- a/tutorials-and-examples/hf-tgi/main.tf +++ b/tutorials-and-examples/hf-tgi/main.tf @@ -105,6 +105,20 @@ resource "kubernetes_deployment" "inference_deployment" { value = "/model/Mistral-7B-Instruct-v0.1" } + env { + # Extends the max size of the prompt we can send to the service, + # so that we can augment prompts and add chat history without causing errors. + name = "MAX_INPUT_LENGTH" + value = 3072 + } + + env { + # Extends the overall context window (including length of prompt & response combined) + # Both this limit and MAX_INPUT_LENGTH need to be increased to enable RAG and chat history. + name = "MAX_TOTAL_TOKENS" + value = 4096 + } + env { name = "NUM_SHARD" value = "2" @@ -184,4 +198,4 @@ resource "kubernetes_deployment" "inference_deployment" { } } } -} +} \ No newline at end of file From 5cc85b99ff2057757363c28f8e63e4d2699d61aa Mon Sep 17 00:00:00 2001 From: Julie Amundson Date: Fri, 3 May 2024 16:18:13 -0700 Subject: [PATCH 02/46] tflint formatting fixes --- applications/rag/frontend/main.tf | 2 +- tutorials-and-examples/hf-tgi/main.tf | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 1b1fec2d6..a88081926 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -112,7 +112,7 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" # Built from local code. Revert before submitting. # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:108bb16ee2278255c80524fce125ef349c494cb5bc4ca77dbde5048b8f9448c1" - name = "rag-frontend" + name = "rag-frontend" port { container_port = 8080 diff --git a/tutorials-and-examples/hf-tgi/main.tf b/tutorials-and-examples/hf-tgi/main.tf index 064e6f51d..acc999275 100644 --- a/tutorials-and-examples/hf-tgi/main.tf +++ b/tutorials-and-examples/hf-tgi/main.tf @@ -108,14 +108,14 @@ resource "kubernetes_deployment" "inference_deployment" { env { # Extends the max size of the prompt we can send to the service, # so that we can augment prompts and add chat history without causing errors. - name = "MAX_INPUT_LENGTH" + name = "MAX_INPUT_LENGTH" value = 3072 } env { # Extends the overall context window (including length of prompt & response combined) # Both this limit and MAX_INPUT_LENGTH need to be increased to enable RAG and chat history. - name = "MAX_TOTAL_TOKENS" + name = "MAX_TOTAL_TOKENS" value = 4096 } From 68986667f8109e0ed99b0758808e042e88eaa88a Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Mon, 6 May 2024 13:35:27 -0400 Subject: [PATCH 03/46] TPU Provisioner: JobSet related fixes (#645) --- tpu-provisioner/cmd/main.go | 2 + tpu-provisioner/config/rbac/role.yaml | 8 +++ tpu-provisioner/go.mod | 6 +- tpu-provisioner/go.sum | 56 +------------------ .../controller/deletion_controller.go | 1 + 5 files changed, 14 insertions(+), 59 deletions(-) diff --git a/tpu-provisioner/cmd/main.go b/tpu-provisioner/cmd/main.go index 6b13abf9f..b876926cd 100644 --- a/tpu-provisioner/cmd/main.go +++ b/tpu-provisioner/cmd/main.go @@ -53,6 +53,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/metrics/server" "sigs.k8s.io/controller-runtime/pkg/webhook" + jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2" //+kubebuilder:scaffold:imports ) @@ -63,6 +64,7 @@ var ( func init() { utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(jobset.AddToScheme(scheme)) //+kubebuilder:scaffold:scheme } diff --git a/tpu-provisioner/config/rbac/role.yaml b/tpu-provisioner/config/rbac/role.yaml index 1a8bc10f1..73af5a5fe 100644 --- a/tpu-provisioner/config/rbac/role.yaml +++ b/tpu-provisioner/config/rbac/role.yaml @@ -64,3 +64,11 @@ rules: - get - patch - update +- apiGroups: + - jobset.x-k8s.io + resources: + - jobsets + verbs: + - get + - list + - watch diff --git a/tpu-provisioner/go.mod b/tpu-provisioner/go.mod index 6cbddff70..e2f7f6bfe 100644 --- a/tpu-provisioner/go.mod +++ b/tpu-provisioner/go.mod @@ -1,9 +1,10 @@ module github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner -go 1.22 +go 1.22.0 require ( cloud.google.com/go/compute/metadata v0.3.0 + github.com/google/go-cmp v0.6.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/onsi/ginkgo/v2 v2.17.1 github.com/onsi/gomega v1.32.0 @@ -38,7 +39,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect - github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect github.com/google/s2a-go v0.1.7 // indirect @@ -49,7 +49,6 @@ require ( github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect @@ -82,7 +81,6 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/apiextensions-apiserver v0.30.0 // indirect - k8s.io/component-base v0.30.0 // indirect k8s.io/kube-openapi v0.0.0-20240423202451-8948a665c108 // indirect k8s.io/utils v0.0.0-20240423183400-0849a56e8f22 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect diff --git a/tpu-provisioner/go.sum b/tpu-provisioner/go.sum index 3c348bd77..6599ce1cf 100644 --- a/tpu-provisioner/go.sum +++ b/tpu-provisioner/go.sum @@ -9,8 +9,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -18,12 +16,9 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= -github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emicklei/go-restful/v3 v3.12.0 h1:y2DdzBAURM29NFF94q6RaY4vjIH1rtwDapwQtU84iWk= github.com/emicklei/go-restful/v3 v3.12.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -32,8 +27,6 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v5.6.0+incompatible h1:jBYDEEiFBPxA0v50tFdvOzQQTCvpL6mnFh5mB2/l16U= github.com/evanphx/json-patch v5.6.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= -github.com/evanphx/json-patch/v5 v5.8.0 h1:lRj6N9Nci7MvzrXuX6HFzU8XjmhPiXPlsKEy1u0KQro= -github.com/evanphx/json-patch/v5 v5.8.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -41,23 +34,16 @@ github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= -github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= -github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= -github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE= -github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= -github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= @@ -116,18 +102,12 @@ github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dv github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -143,26 +123,17 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= -github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= -github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= -github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+aLCE= github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/prometheus/procfs v0.14.0 h1:Lw4VdGGoKEZilJsayHf0B+9YgLGREba2C6xr+Fdfq6s= github.com/prometheus/procfs v0.14.0/go.mod h1:XL+Iwz8k8ZabyZfMFHPiilCniixqQarAy5Mu67pHlNQ= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -199,8 +170,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -252,8 +221,6 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= -golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY= golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -305,41 +272,20 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -k8s.io/api v0.29.3 h1:2ORfZ7+bGC3YJqGpV0KSDDEVf8hdGQ6A03/50vj8pmw= -k8s.io/api v0.29.3/go.mod h1:y2yg2NTyHUUkIoTC+phinTnEa3KFM6RZ3szxt014a80= k8s.io/api v0.30.0 h1:siWhRq7cNjy2iHssOB9SCGNCl2spiF1dO3dABqZ8niA= k8s.io/api v0.30.0/go.mod h1:OPlaYhoHs8EQ1ql0R/TsUgaRPhpKNxIMrKQfWUp8QSE= -k8s.io/apiextensions-apiserver v0.29.2 h1:UK3xB5lOWSnhaCk0RFZ0LUacPZz9RY4wi/yt2Iu+btg= -k8s.io/apiextensions-apiserver v0.29.2/go.mod h1:aLfYjpA5p3OwtqNXQFkhJ56TB+spV8Gc4wfMhUA3/b8= k8s.io/apiextensions-apiserver v0.30.0 h1:jcZFKMqnICJfRxTgnC4E+Hpcq8UEhT8B2lhBcQ+6uAs= k8s.io/apiextensions-apiserver v0.30.0/go.mod h1:N9ogQFGcrbWqAY9p2mUAL5mGxsLqwgtUce127VtRX5Y= -k8s.io/apimachinery v0.29.3 h1:2tbx+5L7RNvqJjn7RIuIKu9XTsIZ9Z5wX2G22XAa5EU= -k8s.io/apimachinery v0.29.3/go.mod h1:hx/S4V2PNW4OMg3WizRrHutyB5la0iCUbZym+W0EQIU= k8s.io/apimachinery v0.30.0 h1:qxVPsyDM5XS96NIh9Oj6LavoVFYff/Pon9cZeDIkHHA= k8s.io/apimachinery v0.30.0/go.mod h1:iexa2somDaxdnj7bha06bhb43Zpa6eWH8N8dbqVjTUc= -k8s.io/client-go v0.29.3 h1:R/zaZbEAxqComZ9FHeQwOh3Y1ZUs7FaHKZdQtIc2WZg= -k8s.io/client-go v0.29.3/go.mod h1:tkDisCvgPfiRpxGnOORfkljmS+UrW+WtXAy2fTvXJB0= k8s.io/client-go v0.30.0 h1:sB1AGGlhY/o7KCyCEQ0bPWzYDL0pwOZO4vAtTSh/gJQ= k8s.io/client-go v0.30.0/go.mod h1:g7li5O5256qe6TYdAMyX/otJqMhIiGgTapdLchhmOaY= -k8s.io/component-base v0.29.2 h1:lpiLyuvPA9yV1aQwGLENYyK7n/8t6l3nn3zAtFTJYe8= -k8s.io/component-base v0.29.2/go.mod h1:BfB3SLrefbZXiBfbM+2H1dlat21Uewg/5qtKOl8degM= -k8s.io/component-base v0.30.0 h1:cj6bp38g0ainlfYtaOQuRELh5KSYjhKxM+io7AUIk4o= -k8s.io/component-base v0.30.0/go.mod h1:V9x/0ePFNaKeKYA3bOvIbrNoluTSG+fSJKjLdjOoeXQ= -k8s.io/klog v1.0.0 h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8= -k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= -k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= k8s.io/klog/v2 v2.120.1 h1:QXU6cPEOIslTGvZaXvFWiP9VKyeet3sawzTOvdXb4Vw= k8s.io/klog/v2 v2.120.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= -k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780= -k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00/go.mod h1:AsvuZPBlUDVuCdzJ87iajxtXuR9oktsTctW/R9wwouA= k8s.io/kube-openapi v0.0.0-20240423202451-8948a665c108 h1:Q8Z7VlGhcJgBHJHYugJ/K/7iB8a2eSxCyxdVjJp+lLY= k8s.io/kube-openapi v0.0.0-20240423202451-8948a665c108/go.mod h1:yD4MZYeKMBwQKVht279WycxKyM84kkAx2DPrTXaeb98= -k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI= -k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= k8s.io/utils v0.0.0-20240423183400-0849a56e8f22 h1:ao5hUqGhsqdm+bYbjH/pRkCs0unBGe9UyDahzs9zQzQ= k8s.io/utils v0.0.0-20240423183400-0849a56e8f22/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= -sigs.k8s.io/controller-runtime v0.17.3 h1:65QmN7r3FWgTxDMz9fvGnO1kbf2nu+acg9p2R9oYYYk= -sigs.k8s.io/controller-runtime v0.17.3/go.mod h1:N0jpP5Lo7lMTF9aL56Z/B2oWBJjey6StQM0jRbKQXtY= sigs.k8s.io/controller-runtime v0.18.0 h1:Z7jKuX784TQSUL1TIyeuF7j8KXZ4RtSX0YgtjKcSTME= sigs.k8s.io/controller-runtime v0.18.0/go.mod h1:tuAt1+wbVsXIT8lPtk5RURxqAnq7xkpv2Mhttslg7Hw= sigs.k8s.io/jobset v0.5.0 h1:IwsJNut1yhN74Iauk1aDR9P/vyMqTzJ0ErAls62iR5U= diff --git a/tpu-provisioner/internal/controller/deletion_controller.go b/tpu-provisioner/internal/controller/deletion_controller.go index b379649db..11e3112fe 100644 --- a/tpu-provisioner/internal/controller/deletion_controller.go +++ b/tpu-provisioner/internal/controller/deletion_controller.go @@ -49,6 +49,7 @@ type NodeCriteria struct { //+kubebuilder:rbac:groups="",resources=nodes,verbs=get;list;watch;create;update;patch;delete //+kubebuilder:rbac:groups="",resources=nodes/status,verbs=get;update;patch //+kubebuilder:rbac:groups="",resources=nodes/finalizers,verbs=update +//+kubebuilder:rbac:groups="jobset.x-k8s.io",resources=jobsets,verbs=get;list;watch func (r *DeletionReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { lg := ctrllog.FromContext(ctx) From 1d6c0522707d725e4411f7fb9726fb6589a4133b Mon Sep 17 00:00:00 2001 From: Julie Amundson Date: Mon, 6 May 2024 15:25:01 -0700 Subject: [PATCH 04/46] Updated image to use code in this branch Reverted breaking change to env var --- applications/rag/frontend/main.tf | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index a88081926..5bac455d9 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -109,9 +109,7 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { spec { service_account_name = var.google_service_account container { - image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" - # Built from local code. Revert before submitting. - # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:108bb16ee2278255c80524fce125ef349c494cb5bc4ca77dbde5048b8f9448c1" + image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:335b60a0775abecd7bfcdde4bd051196d692949952aa3afb76fc934fc8d38842" name = "rag-frontend" port { @@ -126,8 +124,7 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { env { name = "PROJECT_ID" - #value = "projects/${var.project_id}" - value = var.project_id + value = "projects/${var.project_id}" } env { From 981e777048376c406379f2ac9a4baf5139a8d313 Mon Sep 17 00:00:00 2001 From: Julie Amundson Date: Mon, 6 May 2024 15:31:20 -0700 Subject: [PATCH 05/46] making tflint happy --- applications/rag/frontend/main.tf | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 5bac455d9..4b7f73254 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -110,7 +110,7 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { service_account_name = var.google_service_account container { image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:335b60a0775abecd7bfcdde4bd051196d692949952aa3afb76fc934fc8d38842" - name = "rag-frontend" + name = "rag-frontend" port { container_port = 8080 @@ -123,7 +123,7 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { } env { - name = "PROJECT_ID" + name = "PROJECT_ID" value = "projects/${var.project_id}" } From d1d1211787bc699244b46d4d1e98fbc089917a7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Fri, 12 Jul 2024 15:50:56 -0500 Subject: [PATCH 06/46] Working on improvements for rag application (#731) * Working on improvements for rag application: - Working on missing TODO - Fixing issue with credentials - Refactoring vector_storages so you can add different vector storages TODO: Vector Storage factory - Unit test will be added on future PR * Updating changes with db * refactoring app so can be executed using gunicorn * refactory of the code as flask application package * Fixing Bugs - Reviewing issue with IPtypes, currently the fix is to validate if there's an development environment so a public cloud_sql instance can be use. - Fixing issue with Flask App Factory --- .../rag/frontend/container/Dockerfile | 7 +- .../rag/frontend/container/__init__.py | 13 +++ .../container/application/__init__.py | 25 ++++ .../{ => application}/cloud_sql/__init__.py | 0 .../application/cloud_sql/cloud_sql.py | 79 +++++++++++++ .../application/rag_langchain/__init__.py | 0 .../rag_langchain/rag_chain.py | 65 +++++------ .../{ => application}/rai/__init__.py | 0 .../{ => application}/rai/dlp_filter.py | 0 .../{ => application}/rai/nlp_filter.py | 0 .../container/{ => application}/rai/retry.py | 0 .../{ => application}/static/script.js | 0 .../{ => application}/static/styles.css | 0 .../{ => application}/templates/index.html | 0 .../application/vector_storages/__init__.py | 4 + .../application/vector_storages/cloud_sql.py | 69 +++++++++++ .../vector_storages/custom_vector_storage.py} | 107 +++++------------- applications/rag/frontend/container/main.py | 33 ++---- 18 files changed, 263 insertions(+), 139 deletions(-) create mode 100644 applications/rag/frontend/container/__init__.py create mode 100644 applications/rag/frontend/container/application/__init__.py rename applications/rag/frontend/container/{ => application}/cloud_sql/__init__.py (100%) create mode 100644 applications/rag/frontend/container/application/cloud_sql/cloud_sql.py create mode 100644 applications/rag/frontend/container/application/rag_langchain/__init__.py rename applications/rag/frontend/container/{ => application}/rag_langchain/rag_chain.py (61%) rename applications/rag/frontend/container/{ => application}/rai/__init__.py (100%) rename applications/rag/frontend/container/{ => application}/rai/dlp_filter.py (100%) rename applications/rag/frontend/container/{ => application}/rai/nlp_filter.py (100%) rename applications/rag/frontend/container/{ => application}/rai/retry.py (100%) rename applications/rag/frontend/container/{ => application}/static/script.js (100%) rename applications/rag/frontend/container/{ => application}/static/styles.css (100%) rename applications/rag/frontend/container/{ => application}/templates/index.html (100%) create mode 100644 applications/rag/frontend/container/application/vector_storages/__init__.py create mode 100644 applications/rag/frontend/container/application/vector_storages/cloud_sql.py rename applications/rag/frontend/container/{cloud_sql/cloud_sql.py => application/vector_storages/custom_vector_storage.py} (50%) diff --git a/applications/rag/frontend/container/Dockerfile b/applications/rag/frontend/container/Dockerfile index e333187a0..9d2f564a0 100644 --- a/applications/rag/frontend/container/Dockerfile +++ b/applications/rag/frontend/container/Dockerfile @@ -19,4 +19,9 @@ WORKDIR /workspace/frontend RUN pip install -r requirements.txt -CMD ["python", "main.py"] +EXPOSE 8080 + +ENV FLASK_APP=/workspace/frontend/main.py +ENV PYTHONPATH=. +# Run the application with Gunicorn +CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:8080", "main:app"] diff --git a/applications/rag/frontend/container/__init__.py b/applications/rag/frontend/container/__init__.py new file mode 100644 index 000000000..2d9f7b38b --- /dev/null +++ b/applications/rag/frontend/container/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/applications/rag/frontend/container/application/__init__.py b/applications/rag/frontend/container/application/__init__.py new file mode 100644 index 000000000..e732bf472 --- /dev/null +++ b/applications/rag/frontend/container/application/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from flask import Flask + +def create_app(): + app = Flask(__name__, static_folder='static', template_folder='templates') + app.jinja_env.trim_blocks = True + app.jinja_env.lstrip_blocks = True + app.config['SECRET_KEY'] = os.environ.get("SECRET_KEY") + + return app + diff --git a/applications/rag/frontend/container/cloud_sql/__init__.py b/applications/rag/frontend/container/application/cloud_sql/__init__.py similarity index 100% rename from applications/rag/frontend/container/cloud_sql/__init__.py rename to applications/rag/frontend/container/application/cloud_sql/__init__.py diff --git a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py new file mode 100644 index 000000000..135054178 --- /dev/null +++ b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py @@ -0,0 +1,79 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pymysql +import sqlalchemy +from google.cloud.sql.connector import Connector, IPTypes + +from langchain_google_cloud_sql_pg import PostgresEngine + +ENVIRONMENT= os.environ.get("ENVIRONMENT") + +GCP_PROJECT_ID= os.environ.get("PROJECT_ID") +GCP_CLOUD_SQL_REGION = os.environ.get("CLOUDSQL_INSTANCE_REGION") +GCP_CLOUD_SQL_INSTANCE = os.environ.get("CLOUDSQL_INSTANCE") + +INSTANCE_CONNECTION_NAME = f"{GCP_PROJECT_ID}:{GCP_CLOUD_SQL_REGION}:{GCP_CLOUD_SQL_INSTANCE}" + +DB_NAME = os.environ.get('DB_NAME', "pgvector-database") +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('EMBEDDINGS_TABLE_NAME', '') +CHAT_HISTORY_TABLE_NAME = os.environ.get('CHAT_HISTORY_TABLE_NAME', "message_store") + +try: + db_username_file = open("/etc/secret-volume/username", "r") + DB_USER = db_username_file.read() + db_username_file.close() + + db_password_file = open("/etc/secret-volume/password", "r") + DB_PASS = db_password_file.read() + db_password_file.close() +except: + DB_USER = os.environ.get("DB_USERNAME", "postgres") + DB_PASS = os.environ.get("DB_PASS", "postgres") + +# helper function to return SQLAlchemy connection pool +def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: + # function used to generate database connection + def getconn() -> pymysql.connections.Connection: + conn = connector.connect( + INSTANCE_CONNECTION_NAME, + "pg8000", + user=DB_USER, + password=DB_PASS, + db=DB_NAME, + ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE + ) + return conn + + # create connection pool + pool = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=getconn, + ) + return pool + +def create_sync_postgres_engine(): + engine = PostgresEngine.from_instance( + project_id=GCP_PROJECT_ID, + region=GCP_CLOUD_SQL_REGION, + instance=GCP_CLOUD_SQL_INSTANCE, + database=DB_NAME, + user=DB_USER, + password=DB_PASS, + ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE + ) + engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) + return engine diff --git a/applications/rag/frontend/container/application/rag_langchain/__init__.py b/applications/rag/frontend/container/application/rag_langchain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/rag/frontend/container/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py similarity index 61% rename from applications/rag/frontend/container/rag_langchain/rag_chain.py rename to applications/rag/frontend/container/application/rag_langchain/rag_chain.py index 807c875c8..445a8b1f4 100644 --- a/applications/rag/frontend/container/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -13,42 +13,47 @@ # limitations under the License. import os -from typing import (Dict) -from cloud_sql.cloud_sql import CHAT_HISTORY_TABLE_NAME, init_connection_pool, create_sync_postgres_engine, CustomVectorStore + from google.cloud.sql.connector import Connector -from langchain_community.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_core.prompts import PromptTemplate + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableParallel, RunnableLambda from langchain_core.runnables.history import RunnableWithMessageHistory + from langchain_google_cloud_sql_pg import PostgresChatMessageHistory +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint + +from application.cloud_sql.cloud_sql import CHAT_HISTORY_TABLE_NAME, init_connection_pool, create_sync_postgres_engine +from application.vector_storages import CustomVectorStore QUESTION = "question" HISTORY = "history" CONTEXT = "context" INFERENCE_ENDPOINT=os.environ.get('INFERENCE_ENDPOINT', '127.0.0.1:8081') +HUGGINGFACE_HUB_TOKEN = os.environ.get('HUGGINGFACE_HUB_TOKEN') + SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings +# TODO use a chat model instead of an LLM in the chain. -# TODO use a chat model instead of an LLM in the chain. Convert the prompt to a chat prompt template -# prompt = ChatPromptTemplate.from_messages( -# [ -# ("system", """You help everyone by answering questions, and improve your answers from previous answers in history. -# You stick to the facts by basing your answers off of the context provided:"""), -# MessagesPlaceholder(variable_name="history"), -# MessagesPlaceholder(variable_name="context"), -# ("human", "{question}"), -# ] -# ) -template = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. + +template_str = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. Improve upon your previous answers using History, a list of messages. Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. Stick to the facts by basing your answers off of the Context provided. Be brief in answering. -History: {""" + HISTORY + "}\n\nContext: {" + CONTEXT + "}\n\nQuestion: {" + QUESTION + "}\n" -prompt = PromptTemplate(template=template, input_variables=[HISTORY, CONTEXT, QUESTION]) +Question: {question} +Context: {context} +Answer:""" + +prompt = ChatPromptTemplate.from_messages([ + ("system",template_str), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), +]) engine = create_sync_postgres_engine() # TODO: Dict is not safe for multiprocessing. Introduce a cache using Flask-caching or libcache @@ -79,33 +84,15 @@ def clear_chat_history(session_id: str): # (as specified in hugging face TGI input parameter) def create_chain() -> RunnableWithMessageHistory: - # TODO HuggingFaceTextGenInference class is deprecated. - # The warning is: - # The class `langchain_community.llms.huggingface_text_gen_inference.HuggingFaceTextGenInference` - # was deprecated in langchain-community 0.0.21 and will be removed in 0.2.0. Use HuggingFaceEndpoint instead - # The replacement is HuggingFace Endoint, which requires a huggingface - # hub API token. Either need to add the token to the environment, or need to find a method to call TGI - # without the token. - # Example usage of HuggingFaceEndpoint: - # llm = HuggingFaceEndpoint( - # endpoint_url=f'http://{INFERENCE_ENDPOINT}/', - # max_new_tokens=512, - # top_k=10, - # top_p=0.95, - # typical_p=0.95, - # temperature=0.01, - # repetition_penalty=1.03, - # huggingfacehub_api_token="my-api-key" - # ) - # TODO: Give guidance on what these parameters should be and describe why these values were chosen. - model = HuggingFaceTextGenInference( - inference_server_url=f'http://{INFERENCE_ENDPOINT}/', + model = HuggingFaceEndpoint( + endpoint_url=f'http://{INFERENCE_ENDPOINT}/', max_new_tokens=512, top_k=10, top_p=0.95, typical_p=0.95, temperature=0.01, repetition_penalty=1.03, + huggingfacehub_api_token=HUGGINGFACE_HUB_TOKEN, ) langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) diff --git a/applications/rag/frontend/container/rai/__init__.py b/applications/rag/frontend/container/application/rai/__init__.py similarity index 100% rename from applications/rag/frontend/container/rai/__init__.py rename to applications/rag/frontend/container/application/rai/__init__.py diff --git a/applications/rag/frontend/container/rai/dlp_filter.py b/applications/rag/frontend/container/application/rai/dlp_filter.py similarity index 100% rename from applications/rag/frontend/container/rai/dlp_filter.py rename to applications/rag/frontend/container/application/rai/dlp_filter.py diff --git a/applications/rag/frontend/container/rai/nlp_filter.py b/applications/rag/frontend/container/application/rai/nlp_filter.py similarity index 100% rename from applications/rag/frontend/container/rai/nlp_filter.py rename to applications/rag/frontend/container/application/rai/nlp_filter.py diff --git a/applications/rag/frontend/container/rai/retry.py b/applications/rag/frontend/container/application/rai/retry.py similarity index 100% rename from applications/rag/frontend/container/rai/retry.py rename to applications/rag/frontend/container/application/rai/retry.py diff --git a/applications/rag/frontend/container/static/script.js b/applications/rag/frontend/container/application/static/script.js similarity index 100% rename from applications/rag/frontend/container/static/script.js rename to applications/rag/frontend/container/application/static/script.js diff --git a/applications/rag/frontend/container/static/styles.css b/applications/rag/frontend/container/application/static/styles.css similarity index 100% rename from applications/rag/frontend/container/static/styles.css rename to applications/rag/frontend/container/application/static/styles.css diff --git a/applications/rag/frontend/container/templates/index.html b/applications/rag/frontend/container/application/templates/index.html similarity index 100% rename from applications/rag/frontend/container/templates/index.html rename to applications/rag/frontend/container/application/templates/index.html diff --git a/applications/rag/frontend/container/application/vector_storages/__init__.py b/applications/rag/frontend/container/application/vector_storages/__init__.py new file mode 100644 index 000000000..5a7989678 --- /dev/null +++ b/applications/rag/frontend/container/application/vector_storages/__init__.py @@ -0,0 +1,4 @@ +from .custom_vector_storage import CustomVectorStore + + +__all__ = ["CustomVectorStore"] \ No newline at end of file diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py new file mode 100644 index 000000000..3e7e487ba --- /dev/null +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -0,0 +1,69 @@ +import os +from typing import (List, Optional, Iterable, Any) + +import pg8000 +import sqlalchemy +from sqlalchemy.engine import Engine + +from langchain_core.vectorstores import VectorStore +from langchain_core.embeddings import Embeddings +from langchain_core.documents import Document + +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('TABLE_NAME', '') +INSTANCE_CONNECTION_NAME = os.environ.get('INSTANCE_CONNECTION_NAME', '') + +class CloudSQLVectorStorage(VectorStore): + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ): + raise NotImplementedError + + def __init__(self, embedding: Embeddings, engine: Engine): + self.embedding = embedding + self.engine = engine + + @property + def embeddings(self) -> Embeddings: + return self.embedding + + + # TODO implement + def add_texts(self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any) -> List[str]: + raise NotImplementedError + + #TODO implement similarity search with cosine similarity threshold + + def similarity_search(self, query: dict, k: int = 4, **kwargs: Any) -> List[Document]: + with self.engine.connect() as conn: + try: + q = query["question"] + # embed query & fetch matches + query_emb = self.embedding.embed_query(q) + emb_str = ",".join(map(str, query_emb)) + query_request = f"""SELECT id, text, 1 - ('[{emb_str}]' <=> text_embedding) AS cosine_similarity + FROM {VECTOR_EMBEDDINGS_TABLE_NAME} + ORDER BY cosine_similarity DESC LIMIT {k};""" + query_results = conn.execute(sqlalchemy.text(query_request)).fetchall() + print(f"GOT {len(query_results)} results") + conn.commit() + + if not query_results: + message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} returned empty result" + raise ValueError(message) + except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: + message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except sqlalchemy.exc.DatabaseError as err: + message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except Exception as err: + raise Exception(f"General error: {err}") + + #convert query results into List[Document] + texts = [result[1] for result in query_results] + return [Document(page_content=text) for text in texts] \ No newline at end of file diff --git a/applications/rag/frontend/container/cloud_sql/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/custom_vector_storage.py similarity index 50% rename from applications/rag/frontend/container/cloud_sql/cloud_sql.py rename to applications/rag/frontend/container/application/vector_storages/custom_vector_storage.py index 4a9186b38..01a41b8c7 100644 --- a/applications/rag/frontend/container/cloud_sql/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/custom_vector_storage.py @@ -1,89 +1,18 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import os from typing import (List, Optional, Iterable, Any) -from google.cloud.sql.connector import Connector, IPTypes -import pymysql +import pg8000 import sqlalchemy -from sentence_transformers import SentenceTransformer +from sqlalchemy.engine import Engine + from langchain_core.vectorstores import VectorStore -import pg8000 from langchain_core.embeddings import Embeddings from langchain_core.documents import Document -from sqlalchemy.engine import Engine -from langchain_google_cloud_sql_pg import PostgresEngine - -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('TABLE_NAME', '') # CloudSQL table name for vector embeddings -# TODO make this configurable from tf -CHAT_HISTORY_TABLE_NAME = "message_store" # CloudSQL table name where chat history is stored +from langchain.text_splitter import CharacterTextSplitter +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('TABLE_NAME', '') INSTANCE_CONNECTION_NAME = os.environ.get('INSTANCE_CONNECTION_NAME', '') -SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings -DB_NAME = "pgvector-database" - -PROJECT_ID = os.environ.get('PROJECT_ID', '') -REGION = os.environ.get('REGION', '') -INSTANCE = os.environ.get('INSTANCE', '') - -db_username_file = open("/etc/secret-volume/username", "r") -DB_USER = db_username_file.read() -db_username_file.close() - -db_password_file = open("/etc/secret-volume/password", "r") -DB_PASS = db_password_file.read() -db_password_file.close() - -transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL) - -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: - # function used to generate database connection - def getconn() -> pymysql.connections.Connection: - conn = connector.connect( - INSTANCE_CONNECTION_NAME, - "pg8000", - user=DB_USER, - password=DB_PASS, - db=DB_NAME, - ip_type=IPTypes.PRIVATE - ) - return conn - - # create connection pool - pool = sqlalchemy.create_engine( - "postgresql+pg8000://", - creator=getconn, - ) - return pool - -def create_sync_postgres_engine(): - engine = PostgresEngine.from_instance( - project_id=PROJECT_ID, - region=REGION, - instance=INSTANCE, - database=DB_NAME, - user=DB_USER, - password=DB_PASS, - ip_type=IPTypes.PRIVATE - ) - engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) - return engine -#TODO replace this with the Cloud SQL vector store for langchain, -# once the notebook also uses it (and creates the correct schema) class CustomVectorStore(VectorStore): @classmethod def from_texts( @@ -98,7 +27,11 @@ def from_texts( def __init__(self, embedding: Embeddings, engine: Engine): self.embedding = embedding self.engine = engine - + self.text_splitter = CharacterTextSplitter( + separator="\n\n", + chunk_size=1024, + chunk_overlap=200, + ) @property def embeddings(self) -> Embeddings: return self.embedding @@ -106,7 +39,25 @@ def embeddings(self) -> Embeddings: # TODO implement def add_texts(self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any) -> List[str]: - raise NotImplementedError + with self.engine.connect() as conn: + try: + for raw_text in texts: + texts = self.text_splitter.split_text(raw_text) + + embeddings = self.embedding.encode(texts).tolist() + embeddings = embeddings.tobytes() + query_request = "INSERT INTO documents (text, embedding) VALUES (%s, %s)", + conn.execute(sqlalchemy.text(query_request),(texts, embeddings)) + conn.commit() + + except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: + message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except sqlalchemy.exc.DatabaseError as err: + message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except Exception as err: + raise Exception(f"General error: {err}") #TODO implement similarity search with cosine similarity threshold diff --git a/applications/rag/frontend/container/main.py b/applications/rag/frontend/container/main.py index 77a510d26..32c93e183 100644 --- a/applications/rag/frontend/container/main.py +++ b/applications/rag/frontend/container/main.py @@ -11,35 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os + import logging as log import google.cloud.logging as logging import traceback import uuid -from flask import Flask, render_template, request, jsonify, session -from rai import dlp_filter # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp -from rai import nlp_filter # https://cloud.google.com/natural-language/docs/moderating-text -from cloud_sql import cloud_sql -from rag_langchain.rag_chain import clear_chat_history, create_chain, take_chat_turn, engine +from flask import render_template, request, jsonify, session from datetime import datetime, timedelta, timezone +from application import create_app +from application.rai import dlp_filter # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp +from application.rai import nlp_filter # https://cloud.google.com/natural-language/docs/moderating-text + +from application.rag_langchain.rag_chain import clear_chat_history, create_chain, take_chat_turn + +SESSION_TIMEOUT_MINUTES = 30 + # Setup logging logging_client = logging.Client() logging_client.setup_logging() -# TODO: refactor the app startup code into a flask app factory -# TODO: include the chat history cache in the app lifecycle and ensure that it's threadsafe. -app = Flask(__name__, static_folder='static') -app.jinja_env.trim_blocks = True -app.jinja_env.lstrip_blocks = True -app.config['ENGINE'] = engine # force the connection pool to warm up eagerly - -SESSION_TIMEOUT_MINUTES = 30 -#TODO replace with real secret -SECRET_KEY = "TODO replace this with an actual secret that is stored and managed by kubernetes and added to the terraform configuration." -app.config['SECRET_KEY'] = SECRET_KEY +app = create_app() # Create llm chain llm_chain = create_chain() @@ -134,9 +128,6 @@ def handlePrompt(): }) response.status_code = 500 return response - - + if __name__ == '__main__': - # TODO using gunicorn to start the server results in the first request being really slow. - # Sometimes, the worker thread has to restart due to an unknown error. app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080))) \ No newline at end of file From 5a16b549a3a94c38af5cbc27a709fd520364fc2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 22 Jul 2024 10:29:32 -0500 Subject: [PATCH 07/46] Rag langchain chat history (#747) * Working on improvements for rag application: - Working on missing TODO - Fixing issue with credentials - Refactoring vector_storages so you can add different vector storages TODO: Vector Storage factory - Unit test will be added on future PR * Updating changes with db * refactoring app so can be executed using gunicorn * refactory of the code as flask application package * Fixing Bugs - Reviewing issue with IPtypes, currently the fix is to validate if there's an development environment so a public cloud_sql instance can be use. - Fixing issue with Flask App Factory * Working on Custom HuggingFace interface - Adding a custom chat model to send request to HuggingFace TGI API - Applying formatting to code. --- .../rag/frontend/container/__init__.py | 2 +- .../container/application/__init__.py | 1 - .../application/cloud_sql/__init__.py | 2 +- .../application/cloud_sql/cloud_sql.py | 70 +++++----- .../container/application/models/__init__.py | 17 +++ .../application/models/vector_embeddings.py | 31 ++++ .../application/rag_langchain/__init__.py | 13 ++ .../huggingface_inference_model.py | 131 +++++++++++++++++ .../application/rag_langchain/rag_chain.py | 102 +++++++------- .../container/application/rai/__init__.py | 3 +- .../container/application/rai/dlp_filter.py | 132 ++++++++++-------- .../container/application/rai/nlp_filter.py | 72 +++++----- .../container/application/rai/retry.py | 10 +- .../container/application/utils/__init__.py | 17 +++ .../utils/huggingface_tgi_helper.py | 34 +++++ .../application/vector_storages/__init__.py | 18 ++- .../application/vector_storages/cloud_sql.py | 113 ++++++++++++--- applications/rag/frontend/container/main.py | 113 +++++++++------ 18 files changed, 622 insertions(+), 259 deletions(-) create mode 100644 applications/rag/frontend/container/application/models/__init__.py create mode 100644 applications/rag/frontend/container/application/models/vector_embeddings.py create mode 100644 applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py create mode 100644 applications/rag/frontend/container/application/utils/__init__.py create mode 100644 applications/rag/frontend/container/application/utils/huggingface_tgi_helper.py diff --git a/applications/rag/frontend/container/__init__.py b/applications/rag/frontend/container/__init__.py index 2d9f7b38b..6d5e14bcf 100644 --- a/applications/rag/frontend/container/__init__.py +++ b/applications/rag/frontend/container/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/applications/rag/frontend/container/application/__init__.py b/applications/rag/frontend/container/application/__init__.py index e732bf472..64c71bccc 100644 --- a/applications/rag/frontend/container/application/__init__.py +++ b/applications/rag/frontend/container/application/__init__.py @@ -22,4 +22,3 @@ def create_app(): app.config['SECRET_KEY'] = os.environ.get("SECRET_KEY") return app - diff --git a/applications/rag/frontend/container/application/cloud_sql/__init__.py b/applications/rag/frontend/container/application/cloud_sql/__init__.py index efb24dbc0..11f30faf0 100644 --- a/applications/rag/frontend/container/application/cloud_sql/__init__.py +++ b/applications/rag/frontend/container/application/cloud_sql/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This file is required to make Python treat the subfolder as a package \ No newline at end of file +# This file is required to make Python treat the subfolder as a package diff --git a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py index 135054178..747ebe35a 100644 --- a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py +++ b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py @@ -20,60 +20,64 @@ from langchain_google_cloud_sql_pg import PostgresEngine -ENVIRONMENT= os.environ.get("ENVIRONMENT") +ENVIRONMENT = os.environ.get("ENVIRONMENT") -GCP_PROJECT_ID= os.environ.get("PROJECT_ID") +GCP_PROJECT_ID = os.environ.get("PROJECT_ID") GCP_CLOUD_SQL_REGION = os.environ.get("CLOUDSQL_INSTANCE_REGION") GCP_CLOUD_SQL_INSTANCE = os.environ.get("CLOUDSQL_INSTANCE") -INSTANCE_CONNECTION_NAME = f"{GCP_PROJECT_ID}:{GCP_CLOUD_SQL_REGION}:{GCP_CLOUD_SQL_INSTANCE}" +INSTANCE_CONNECTION_NAME = ( + f"{GCP_PROJECT_ID}:{GCP_CLOUD_SQL_REGION}:{GCP_CLOUD_SQL_INSTANCE}" +) -DB_NAME = os.environ.get('DB_NAME', "pgvector-database") -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('EMBEDDINGS_TABLE_NAME', '') -CHAT_HISTORY_TABLE_NAME = os.environ.get('CHAT_HISTORY_TABLE_NAME', "message_store") +DB_NAME = os.environ.get("DB_NAME", "pgvector-database") +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") +CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") try: - db_username_file = open("/etc/secret-volume/username", "r") - DB_USER = db_username_file.read() - db_username_file.close() + db_username_file = open("/etc/secret-volume/username", "r") + DB_USER = db_username_file.read() + db_username_file.close() - db_password_file = open("/etc/secret-volume/password", "r") - DB_PASS = db_password_file.read() - db_password_file.close() + db_password_file = open("/etc/secret-volume/password", "r") + DB_PASS = db_password_file.read() + db_password_file.close() except: - DB_USER = os.environ.get("DB_USERNAME", "postgres") - DB_PASS = os.environ.get("DB_PASS", "postgres") + DB_USER = os.environ.get("DB_USERNAME", "postgres") + DB_PASS = os.environ.get("DB_PASS", "postgres") + # helper function to return SQLAlchemy connection pool def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: - # function used to generate database connection - def getconn() -> pymysql.connections.Connection: - conn = connector.connect( - INSTANCE_CONNECTION_NAME, - "pg8000", - user=DB_USER, - password=DB_PASS, - db=DB_NAME, - ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE + # function used to generate database connection + def getconn() -> pymysql.connections.Connection: + conn = connector.connect( + INSTANCE_CONNECTION_NAME, + "pg8000", + user=DB_USER, + password=DB_PASS, + db=DB_NAME, + ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE, + ) + return conn + + # create connection pool + pool = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=getconn, ) - return conn - - # create connection pool - pool = sqlalchemy.create_engine( - "postgresql+pg8000://", - creator=getconn, - ) - return pool + return pool + def create_sync_postgres_engine(): engine = PostgresEngine.from_instance( project_id=GCP_PROJECT_ID, - region=GCP_CLOUD_SQL_REGION, + region=GCP_CLOUD_SQL_REGION, instance=GCP_CLOUD_SQL_INSTANCE, database=DB_NAME, user=DB_USER, password=DB_PASS, - ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE + ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE, ) engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) return engine diff --git a/applications/rag/frontend/container/application/models/__init__.py b/applications/rag/frontend/container/application/models/__init__.py new file mode 100644 index 000000000..8d1d56862 --- /dev/null +++ b/applications/rag/frontend/container/application/models/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vector_embeddings import VectorEmbeddings + +__all__ = ["VectorEmbeddings"] diff --git a/applications/rag/frontend/container/application/models/vector_embeddings.py b/applications/rag/frontend/container/application/models/vector_embeddings.py new file mode 100644 index 000000000..fd42ea5b2 --- /dev/null +++ b/applications/rag/frontend/container/application/models/vector_embeddings.py @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from sqlalchemy import Column, String, Text +from sqlalchemy.orm import mapped_column, declarative_base +from pgvector.sqlalchemy import Vector + +Base = declarative_base() + +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") + + +class VectorEmbeddings(Base): + __tablename__ = VECTOR_EMBEDDINGS_TABLE_NAME + + id = Column(String(255), primary_key=True) + text = Column(Text) + text_embedding = mapped_column(Vector(384)) diff --git a/applications/rag/frontend/container/application/rag_langchain/__init__.py b/applications/rag/frontend/container/application/rag_langchain/__init__.py index e69de29bb..6d5e14bcf 100644 --- a/applications/rag/frontend/container/application/rag_langchain/__init__.py +++ b/applications/rag/frontend/container/application/rag_langchain/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py b/applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py new file mode 100644 index 000000000..2902270aa --- /dev/null +++ b/applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py @@ -0,0 +1,131 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +from typing import Any, Dict, Iterator, List, Optional + +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk + +from application.utils import post_request + +INFERENCE_ENDPOINT = os.environ.get("INFERENCE_ENDPOINT", "127.0.0.1:8081") + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class HuggingFaceCustomChatModel(LLM): + """A custom chat model that calls to an HuggingFace TGI API and returns the generated + content based on the given message. + + Example: + + .. code-block:: python + + model = HuggingFaceCustomChatModel() + result = model.invoke([HumanMessage(content="hello")]) + result = model.batch([[HumanMessage(content="hello")], + [HumanMessage(content="world")]]) + """ + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Run the LLM on the given input. + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + If stop tokens are not supported consider raising NotImplementedError. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output as a string. + """ + if stop is not None: + raise ValueError("stop kwargs are not permitted.") + + api_endpoint = f"http://{INFERENCE_ENDPOINT}/generate" + body = {"inputs": prompt} + headers = {"Content-Type": "application/json"} + generated_output = post_request(api_endpoint, body, headers) + generated_text = generated_output.get("generated_text", "") + + return generated_text + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + """Stream the LLM on the given prompt. + + This method should be overridden by subclasses that support streaming. + + If not implemented, the default behavior of calls to stream will be to + fallback to the non-streaming version of the model and return + the output as a single chunk. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + An iterator of GenerationChunks. + """ + api_endpoint = f"http://{INFERENCE_ENDPOINT}/generate_stream" + body = {"inputs": prompt} + headers = {"Content-Type": "application/json"} + logging.info("Calling external model") + generated_output = post_request(api_endpoint, body, headers) + generated_text = generated_output.get("generated_text", "") + + for char in generated_text: + chunk = GenerationChunk(text=char) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + yield chunk + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Return a dictionary of identifying parameters.""" + return { + # The model name allows users to specify custom token counting + # rules in LLM monitoring applications (e.g., in LangSmith users + # can provide per token pricing for their model and monitor + # costs for the given LLM.) + "model_name": "HuggingFaceTGI", + } + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model. Used for logging purposes only.""" + return "custom" diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index 445a8b1f4..4972375fd 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import logging from google.cloud.sql.connector import Connector @@ -22,88 +22,81 @@ from langchain_google_cloud_sql_pg import PostgresChatMessageHistory from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint -from application.cloud_sql.cloud_sql import CHAT_HISTORY_TABLE_NAME, init_connection_pool, create_sync_postgres_engine -from application.vector_storages import CustomVectorStore - -QUESTION = "question" -HISTORY = "history" +from application.cloud_sql.cloud_sql import ( + CHAT_HISTORY_TABLE_NAME, + init_connection_pool, + create_sync_postgres_engine, +) +from application.rag_langchain.huggingface_inference_model import ( + HuggingFaceCustomChatModel, +) +from application.vector_storages import CloudSQLVectorStore + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +QUESTION = "input" +HISTORY = "chat_history" CONTEXT = "context" -INFERENCE_ENDPOINT=os.environ.get('INFERENCE_ENDPOINT', '127.0.0.1:8081') -HUGGINGFACE_HUB_TOKEN = os.environ.get('HUGGINGFACE_HUB_TOKEN') - -SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings - -# TODO use a chat model instead of an LLM in the chain. - +SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings template_str = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. Improve upon your previous answers using History, a list of messages. Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. Stick to the facts by basing your answers off of the Context provided. Be brief in answering. - -Question: {question} +\n\n Context: {context} -Answer:""" +""" -prompt = ChatPromptTemplate.from_messages([ - ("system",template_str), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), -]) +prompt = ChatPromptTemplate.from_messages( + [ + ("system", template_str), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) engine = create_sync_postgres_engine() # TODO: Dict is not safe for multiprocessing. Introduce a cache using Flask-caching or libcache -# The in-memory SimpleCache implementations for each of these libraries is not safe either. +# The in-memory SimpleCache implementations for each of these libraries is not safe either. # Consider redis or memcached (e.g., Memorystore) # chat_history_map: Dict[str, PostgresChatMessageHistory] = {} + def get_chat_history(session_id: str) -> PostgresChatMessageHistory: history = PostgresChatMessageHistory.create_sync( - engine, - session_id=session_id, - table_name = CHAT_HISTORY_TABLE_NAME + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME ) print(f"Retrieving history for session {session_id} with {len(history.messages)}") return history + def clear_chat_history(session_id: str): history = PostgresChatMessageHistory.create_sync( - engine, - session_id=session_id, - table_name = CHAT_HISTORY_TABLE_NAME - ) + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) history.clear() -#TODO: limit number of tokens in prompt to MAX_INPUT_LENGTH -# (as specified in hugging face TGI input parameter) - -def create_chain() -> RunnableWithMessageHistory: - model = HuggingFaceEndpoint( - endpoint_url=f'http://{INFERENCE_ENDPOINT}/', - max_new_tokens=512, - top_k=10, - top_p=0.95, - typical_p=0.95, - temperature=0.01, - repetition_penalty=1.03, - huggingfacehub_api_token=HUGGINGFACE_HUB_TOKEN, - ) +def create_chain() -> RunnableWithMessageHistory: + model = HuggingFaceCustomChatModel() langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CustomVectorStore(langchain_embed, init_connection_pool(Connector())) + vector_store = CloudSQLVectorStore( + langchain_embed, init_connection_pool(Connector()) + ) retriever = vector_store.as_retriever() setup_and_retrieval = RunnableParallel( { - "context": retriever, - QUESTION: RunnableLambda(lambda d: d[QUESTION]), - HISTORY: RunnableLambda(lambda d: d[HISTORY]) + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]), } ) chain = setup_and_retrieval | prompt | model @@ -112,12 +105,15 @@ def create_chain() -> RunnableWithMessageHistory: get_chat_history, input_messages_key=QUESTION, history_messages_key=HISTORY, - output_messages_key="output" + output_messages_key="output", ) return chain_with_history -def take_chat_turn(chain: RunnableWithMessageHistory, session_id: str, query_text: str) -> str: - #TODO limit the number of history messages + +def take_chat_turn( + chain: RunnableWithMessageHistory, session_id: str, query_text: str +) -> str: + # TODO limit the number of history messages config = {"configurable": {"session_id": session_id}} - result = chain.invoke({"question": query_text}, config) - return str(result) \ No newline at end of file + result = chain.invoke({"input": query_text}, config=config) + return str(result) diff --git a/applications/rag/frontend/container/application/rai/__init__.py b/applications/rag/frontend/container/application/rai/__init__.py index efb24dbc0..cb97ceb52 100644 --- a/applications/rag/frontend/container/application/rai/__init__.py +++ b/applications/rag/frontend/container/application/rai/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This file is required to make Python treat the subfolder as a package \ No newline at end of file +# This file is required to make Python treat the subfolder as a package + diff --git a/applications/rag/frontend/container/application/rai/dlp_filter.py b/applications/rag/frontend/container/application/rai/dlp_filter.py index c7e24d206..4a22b976f 100644 --- a/applications/rag/frontend/container/application/rai/dlp_filter.py +++ b/applications/rag/frontend/container/application/rai/dlp_filter.py @@ -17,87 +17,95 @@ from . import retry # Convert the project id into a full resource id. -parent = os.environ.get('PROJECT_ID', 'NULL') +parent = os.environ.get("PROJECT_ID", "NULL") # Instantiate a dlp client. dlp_client = google.cloud.dlp_v2.DlpServiceClient() + def is_dlp_api_enabled(): - if parent == 'NULL': - return False - # Check if the DLP API is enabled - try: - dlp_client.list_info_types( - request={"parent": 'en-US'}, retry=retry.retry_policy - ) - return True - except Exception as e: - print(f"Error: {e}") - return False + if parent == "NULL": + return False + # Check if the DLP API is enabled + try: + dlp_client.list_info_types( + request={"parent": "en-US"}, retry=retry.retry_policy + ) + return True + except Exception as e: + print(f"Error: {e}") + return False + def list_inspect_templates_from_parent(): - # Initialize request argument(s) - request = google.cloud.dlp_v2.ListInspectTemplatesRequest( - parent=parent, - ) + # Initialize request argument(s) + request = google.cloud.dlp_v2.ListInspectTemplatesRequest( + parent=parent, + ) - # Make the request - page_result = dlp_client.list_inspect_templates(request=request, retry=retry.retry_policy) + # Make the request + page_result = dlp_client.list_inspect_templates( + request=request, retry=retry.retry_policy + ) + + name_list = [] + # Handle the response + for response in page_result: + name_list.append(response.name) + return name_list - name_list = [] - # Handle the response - for response in page_result: - name_list.append(response.name) - return name_list def get_inspect_templates_from_name(name): - request = google.cloud.dlp_v2.GetInspectTemplateRequest( - name=name, - ) + request = google.cloud.dlp_v2.GetInspectTemplateRequest( + name=name, + ) - return dlp_client.get_inspect_template(request=request) + return dlp_client.get_inspect_template(request=request) def list_deidentify_templates_from_parent(): - # Initialize request argument(s) - request = google.cloud.dlp_v2.ListDeidentifyTemplatesRequest( - parent=parent, - ) + # Initialize request argument(s) + request = google.cloud.dlp_v2.ListDeidentifyTemplatesRequest( + parent=parent, + ) + + # Make the request + page_result = dlp_client.list_deidentify_templates(request=request) - # Make the request - page_result = dlp_client.list_deidentify_templates(request=request) + name_list = [] + # Handle the response + for response in page_result: + name_list.append(response.name) + return name_list - name_list = [] - # Handle the response - for response in page_result: - name_list.append(response.name) - return name_list def get_deidentify_templates_from_name(name): - request = google.cloud.dlp_v2.GetDeidentifyTemplateRequest( - name=name, - ) + request = google.cloud.dlp_v2.GetDeidentifyTemplateRequest( + name=name, + ) + + return dlp_client.get_deidentify_template(request=request, retry=retry.retry_policy) - return dlp_client.get_deidentify_template(request=request, retry=retry.retry_policy) def inspect_content(inspect_template_path, deidentify_template_path, input): - inspect_templates = get_inspect_templates_from_name(inspect_template_path) - deidentify_template = get_deidentify_templates_from_name(deidentify_template_path) - - - # Construct item - item = {"value": input} - - # Call the API - response = dlp_client.deidentify_content( - request={ - "parent": parent, - "deidentify_config": deidentify_template.deidentify_config, - "inspect_config": inspect_templates.inspect_config, - "item": item, - }, retry=retry.retry_policy - ) - - # Print out the results. - print(response.item.value) - return response.item.value \ No newline at end of file + inspect_templates = get_inspect_templates_from_name(inspect_template_path) + deidentify_template = get_deidentify_templates_from_name(deidentify_template_path) + + # Construct item + item = {"value": input} + + # Call the API + response = dlp_client.deidentify_content( + request={ + "parent": parent, + "deidentify_config": deidentify_template.deidentify_config, + "inspect_config": inspect_templates.inspect_config, + "item": item, + }, + retry=retry.retry_policy, + ) + + # Print out the results. + print(response.item.value) + return response.item.value + diff --git a/applications/rag/frontend/container/application/rai/nlp_filter.py b/applications/rag/frontend/container/application/rai/nlp_filter.py index ca5c91649..0fbf2688e 100644 --- a/applications/rag/frontend/container/application/rai/nlp_filter.py +++ b/applications/rag/frontend/container/application/rai/nlp_filter.py @@ -17,46 +17,50 @@ from . import retry # Convert the project id into a full resource id. -parent = os.environ.get('PROJECT_ID', 'NULL') +parent = os.environ.get("PROJECT_ID", "NULL") # Instantiate a nlp client. nature_language_client = language.LanguageServiceClient() + def is_nlp_api_enabled(): - if parent == 'NULL': - return False - # Check if the DLP API is enabled - try: - sum_moderation_confidences("test") - return True - except Exception as e: - print(f"Error: {e}") - return False + if parent == "NULL": + return False + # Check if the DLP API is enabled + try: + sum_moderation_confidences("test") + return True + except Exception as e: + print(f"Error: {e}") + return False + def sum_moderation_confidences(text): - document = language.types.Document( - content=text, type_=language.types.Document.Type.PLAIN_TEXT - ) - - request = language.ModerateTextRequest( - document=document, - ) - # Detects the sentiment of the text - response = nature_language_client.moderate_text( - request=request, retry=retry.retry_policy - ) - print(f'get response: {response}') - # Parse response and sum the confidences of moderation, the categories are from https://cloud.google.com/natural-language/docs/moderating-text - largest_confidence = 0.0 - excluding_names = ["Health", "Politics", "Finance", "Legal"] - for category in response.moderation_categories: - if category.name in excluding_names: - continue - if category.confidence > largest_confidence: - largest_confidence = category.confidence - - print(f'largest confidence is: {largest_confidence}') - return int(largest_confidence * 100) + document = language.types.Document( + content=text, type_=language.types.Document.Type.PLAIN_TEXT + ) + + request = language.ModerateTextRequest( + document=document, + ) + # Detects the sentiment of the text + response = nature_language_client.moderate_text( + request=request, retry=retry.retry_policy + ) + print(f"get response: {response}") + # Parse response and sum the confidences of moderation, the categories are from https://cloud.google.com/natural-language/docs/moderating-text + largest_confidence = 0.0 + excluding_names = ["Health", "Politics", "Finance", "Legal"] + for category in response.moderation_categories: + if category.name in excluding_names: + continue + if category.confidence > largest_confidence: + largest_confidence = category.confidence + + print(f"largest confidence is: {largest_confidence}") + return int(largest_confidence * 100) + def is_content_inappropriate(text, nlp_filter_level): - return sum_moderation_confidences(text) > (100 - int(nlp_filter_level)) + return sum_moderation_confidences(text) > (100 - int(nlp_filter_level)) + diff --git a/applications/rag/frontend/container/application/rai/retry.py b/applications/rag/frontend/container/application/rai/retry.py index e3901ac22..b9092b683 100644 --- a/applications/rag/frontend/container/application/rai/retry.py +++ b/applications/rag/frontend/container/application/rai/retry.py @@ -17,13 +17,15 @@ _RETRIABLE_TYPES = [ -exceptions.TooManyRequests, # 429 -exceptions.InternalServerError, # 500 -exceptions.BadGateway, # 502 -exceptions.ServiceUnavailable, # 503 + exceptions.TooManyRequests, # 429 + exceptions.InternalServerError, # 500 + exceptions.BadGateway, # 502 + exceptions.ServiceUnavailable, # 503 ] + def is_retryable(exc): return isinstance(exc, _RETRIABLE_TYPES) + retry_policy = Retry(predicate=is_retryable) \ No newline at end of file diff --git a/applications/rag/frontend/container/application/utils/__init__.py b/applications/rag/frontend/container/application/utils/__init__.py new file mode 100644 index 000000000..8034550de --- /dev/null +++ b/applications/rag/frontend/container/application/utils/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .huggingface_tgi_helper import post_request + + +__all__ = ["post_request"] diff --git a/applications/rag/frontend/container/application/utils/huggingface_tgi_helper.py b/applications/rag/frontend/container/application/utils/huggingface_tgi_helper.py new file mode 100644 index 000000000..0a7d3f6b4 --- /dev/null +++ b/applications/rag/frontend/container/application/utils/huggingface_tgi_helper.py @@ -0,0 +1,34 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import requests + + +def post_request(endpoint, body_params, headers): + """ + Perform a POST request to a given endpoint with specified body parameters and headers. + + Args: + endpoint (str): The URL endpoint for the POST request. + body_params (dict): The body parameters to be sent in the POST request. + headers (dict): The headers to be included in the POST request. + + Returns: + dict: The response from the POST request. + """ + try: + response = requests.post(endpoint, json=body_params, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + return {"error": str(e)} diff --git a/applications/rag/frontend/container/application/vector_storages/__init__.py b/applications/rag/frontend/container/application/vector_storages/__init__.py index 5a7989678..94cde2d79 100644 --- a/applications/rag/frontend/container/application/vector_storages/__init__.py +++ b/applications/rag/frontend/container/application/vector_storages/__init__.py @@ -1,4 +1,18 @@ -from .custom_vector_storage import CustomVectorStore +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .cloud_sql import CloudSQLVectorStore -__all__ = ["CustomVectorStore"] \ No newline at end of file + +__all__ = ["CloudSQLVectorStore"] diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py index 3e7e487ba..545f72259 100644 --- a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -1,18 +1,45 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os -from typing import (List, Optional, Iterable, Any) +import uuid +import logging + +from typing import List, Optional, Iterable, Any import pg8000 import sqlalchemy from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.sql import func, text from langchain_core.vectorstores import VectorStore from langchain_core.embeddings import Embeddings from langchain_core.documents import Document +from langchain.text_splitter import CharacterTextSplitter + +from application.models import VectorEmbeddings + +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") +INSTANCE_CONNECTION_NAME = os.environ.get("INSTANCE_CONNECTION_NAME", "") -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('TABLE_NAME', '') -INSTANCE_CONNECTION_NAME = os.environ.get('INSTANCE_CONNECTION_NAME', '') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) -class CloudSQLVectorStorage(VectorStore): + +class CloudSQLVectorStore(VectorStore): @classmethod def from_texts( cls, @@ -26,36 +53,80 @@ def from_texts( def __init__(self, embedding: Embeddings, engine: Engine): self.embedding = embedding self.engine = engine + self.text_splitter = CharacterTextSplitter( + separator="\n\n", + chunk_size=1024, + chunk_overlap=200, + ) @property def embeddings(self) -> Embeddings: return self.embedding - # TODO implement - def add_texts(self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any) -> List[str]: - raise NotImplementedError - - #TODO implement similarity search with cosine similarity threshold + def add_texts( + self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any + ) -> List[str]: + with self.engine.connect() as conn: + try: + Session = sessionmaker(bind=conn) + session = Session(bind=conn) + for raw_text in texts: + id = uuid.uuid4() + + texts = self.text_splitter.split_text(raw_text) + embeddings = self.embedding.encode(texts).tolist() + vector_embedding = VectorEmbeddings( + id=id, text=texts, text_embedding=embeddings[0] + ) + session.add(vector_embedding) + conn.commit() - def similarity_search(self, query: dict, k: int = 4, **kwargs: Any) -> List[Document]: + except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: + message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except sqlalchemy.exc.DatabaseError as err: + message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" + raise sqlalchemy.exc.DataError(message) + except Exception as err: + raise Exception(f"General error: {err}") + + # TODO implement similarity search with cosine similarity threshold + + def similarity_search( + self, query: dict, k: int = 4, **kwargs: Any + ) -> List[Document]: with self.engine.connect() as conn: try: - q = query["question"] + Session = sessionmaker(bind=conn) + session = Session(bind=conn) + + q = query.get("input") # embed query & fetch matches query_emb = self.embedding.embed_query(q) - emb_str = ",".join(map(str, query_emb)) - query_request = f"""SELECT id, text, 1 - ('[{emb_str}]' <=> text_embedding) AS cosine_similarity - FROM {VECTOR_EMBEDDINGS_TABLE_NAME} - ORDER BY cosine_similarity DESC LIMIT {k};""" - query_results = conn.execute(sqlalchemy.text(query_request)).fetchall() + query_request = ( + "SELECT id, text, text_embedding, 1 - ('[" + + ",".join(map(str, query_emb)) + + "]' <=> text_embedding) AS cosine_similarity FROM " + + VECTOR_EMBEDDINGS_TABLE_NAME + + " ORDER BY cosine_similarity DESC LIMIT " + + str(k) + + ";" + ) + query_results = session.execute(text(query_request)).fetchall() + print(f"GOT {len(query_results)} results") - conn.commit() + + session.commit() + session.close() if not query_results: - message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} returned empty result" + message = ( + f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} returned empty result" + ) raise ValueError(message) - except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: + + except sqlalchemy.exc.DataError or pg8000.exceptions.DatabaseError as err: message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" raise sqlalchemy.exc.DataError(message) except sqlalchemy.exc.DatabaseError as err: @@ -64,6 +135,6 @@ def similarity_search(self, query: dict, k: int = 4, **kwargs: Any) -> List[Docu except Exception as err: raise Exception(f"General error: {err}") - #convert query results into List[Document] + # convert query results into List[Document] texts = [result[1] for result in query_results] - return [Document(page_content=text) for text in texts] \ No newline at end of file + return [Document(page_content=text) for text in texts] diff --git a/applications/rag/frontend/container/main.py b/applications/rag/frontend/container/main.py index 32c93e183..b029451d6 100644 --- a/applications/rag/frontend/container/main.py +++ b/applications/rag/frontend/container/main.py @@ -1,6 +1,3 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -11,21 +8,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os +import os +import uuid +import traceback import logging as log + import google.cloud.logging as logging -import traceback -import uuid from flask import render_template, request, jsonify, session from datetime import datetime, timedelta, timezone from application import create_app -from application.rai import dlp_filter # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp -from application.rai import nlp_filter # https://cloud.google.com/natural-language/docs/moderating-text +from application.rai import ( + dlp_filter, +) # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp +from application.rai import ( + nlp_filter, +) # https://cloud.google.com/natural-language/docs/moderating-text + +from application.rag_langchain.rag_chain import ( + clear_chat_history, + create_chain, + take_chat_turn, +) -from application.rag_langchain.rag_chain import clear_chat_history, create_chain, take_chat_turn +log.basicConfig(level=log.INFO, format="%(asctime)s - %(levelname)s - %(message)s") SESSION_TIMEOUT_MINUTES = 30 @@ -38,96 +46,109 @@ # Create llm chain llm_chain = create_chain() -@app.route('/get_nlp_status', methods=['GET']) + +@app.route("/get_nlp_status", methods=["GET"]) def get_nlp_status(): nlp_enabled = nlp_filter.is_nlp_api_enabled() return jsonify({"nlpEnabled": nlp_enabled}) -@app.route('/get_dlp_status', methods=['GET']) + +@app.route("/get_dlp_status", methods=["GET"]) def get_dlp_status(): dlp_enabled = dlp_filter.is_dlp_api_enabled() return jsonify({"dlpEnabled": dlp_enabled}) -@app.route('/get_inspect_templates') + +@app.route("/get_inspect_templates") def get_inspect_templates(): return jsonify(dlp_filter.list_inspect_templates_from_parent()) -@app.route('/get_deidentify_templates') + +@app.route("/get_deidentify_templates") def get_deidentify_templates(): return jsonify(dlp_filter.list_deidentify_templates_from_parent()) + @app.before_request def check_new_session(): - if 'session_id' not in session: + if "session_id" not in session: # instantiate a new session using a generated UUID session_id = str(uuid.uuid4()) - session['session_id'] = session_id + session["session_id"] = session_id + @app.before_request def check_inactivity(): # Inactivity cleanup - if 'last_activity' in session: - time_elapsed = datetime.now(timezone.utc) - session['last_activity'] + if "last_activity" in session: + time_elapsed = datetime.now(timezone.utc) - session["last_activity"] if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): print("Session inactive: Cleaning up resources...") - session_id = session['session_id'] + session_id = session["session_id"] # TODO: implement garbage collection process for idle sessions that have timed out clear_chat_history(session_id) session.clear() # Always update the 'last_activity' data - session['last_activity'] = datetime.now(timezone.utc) + session["last_activity"] = datetime.now(timezone.utc) -@app.route('/') + +@app.route("/") def index(): - return render_template('index.html') + return render_template("index.html") + -@app.route('/prompt', methods=['POST']) +@app.route("/prompt", methods=["POST"]) def handlePrompt(): # TODO on page refresh, load chat history into browser. - session['last_activity'] = datetime.now(timezone.utc) + session["last_activity"] = datetime.now(timezone.utc) data = request.get_json() warnings = [] - if 'prompt' not in data: - return 'missing required prompt', 400 + if "prompt" not in data: + return "missing required prompt", 400 - user_prompt = data['prompt'] + user_prompt = data["prompt"] log.info(f"handle user prompt: {user_prompt}") try: response = {} - result = take_chat_turn(llm_chain, session['session_id'], user_prompt) - response['text'] = result + result = take_chat_turn(llm_chain, session["session_id"], user_prompt) + log.info("After the result") + log.info(result) + response["text"] = result # TODO: enable filtering in chain - if 'nlpFilterLevel' in data: - if nlp_filter.is_content_inappropriate(response['text'], data['nlpFilterLevel']): - response['text'] = 'The response is deemed inappropriate for display.' - return {'response': response} - if 'inspectTemplate' in data and 'deidentifyTemplate' in data: - inspect_template_path = data['inspectTemplate'] - deidentify_template_path = data['deidentifyTemplate'] + if "nlpFilterLevel" in data: + if nlp_filter.is_content_inappropriate( + response["text"], data["nlpFilterLevel"] + ): + response["text"] = "The response is deemed inappropriate for display." + return {"response": response} + if "inspectTemplate" in data and "deidentifyTemplate" in data: + inspect_template_path = data["inspectTemplate"] + deidentify_template_path = data["deidentifyTemplate"] if inspect_template_path != "" and deidentify_template_path != "": # filter the output with inspect setting. Customer can pick any category from https://cloud.google.com/dlp/docs/concepts-infotypes - response['text'] = dlp_filter.inspect_content(inspect_template_path, deidentify_template_path, response['text']) + response["text"] = dlp_filter.inspect_content( + inspect_template_path, deidentify_template_path, response["text"] + ) if warnings: - response['warnings'] = warnings + response["warnings"] = warnings log.info(f"response: {response}") - return {'response': response} + return {"response": response} except Exception as err: log.info(f"exception from llm: {err}") traceback.print_exc() error_traceback = traceback.format_exc() - response = jsonify({ - "warnings": warnings, - "error": "An error occurred", - "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}" - }) + response = jsonify( + { + "warnings": warnings, + "error": "An error occurred", + "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", + } + ) response.status_code = 500 return response - -if __name__ == '__main__': - app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080))) \ No newline at end of file From 9e416c878ed32764cc9f5ef44891d25ee4583aeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 29 Jul 2024 15:57:54 -0500 Subject: [PATCH 08/46] Rag langchain chat history (#755) * Working on improvements for rag application: - Working on missing TODO - Fixing issue with credentials - Refactoring vector_storages so you can add different vector storages TODO: Vector Storage factory - Unit test will be added on future PR * Updating changes with db * refactoring app so can be executed using gunicorn * refactory of the code as flask application package * Fixing Bugs - Reviewing issue with IPtypes, currently the fix is to validate if there's an development environment so a public cloud_sql instance can be use. - Fixing issue with Flask App Factory * Working on Custom HuggingFace interface - Adding a custom chat model to send request to HuggingFace TGI API - Applying formatting to code. * Improving the CloudSQL vector vector_storage --- .../application/cloud_sql/cloud_sql.py | 63 ++++++---- .../application/rag_langchain/rag_chain.py | 24 ++-- .../container/application/rai/__init__.py | 1 - .../container/application/rai/retry.py | 3 +- .../application/vector_storages/__init__.py | 1 + .../application/vector_storages/cloud_sql.py | 116 +++++------------- applications/rag/frontend/container/main.py | 80 ++++++++---- .../rag/frontend/container/requirements.txt | 1 - 8 files changed, 135 insertions(+), 154 deletions(-) diff --git a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py index 747ebe35a..cf3c4d7ed 100644 --- a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py +++ b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py @@ -5,6 +5,13 @@ # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +20,15 @@ # limitations under the License. import os +import logging + +from google.cloud.sql.connector import IPTypes -import pymysql -import sqlalchemy -from google.cloud.sql.connector import Connector, IPTypes +from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore -from langchain_google_cloud_sql_pg import PostgresEngine +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) ENVIRONMENT = os.environ.get("ENVIRONMENT") @@ -34,6 +44,8 @@ VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") +VECTOR_DIMENSION = os.environ.get("VECTOR_DIMENSION", 384) + try: db_username_file = open("/etc/secret-volume/username", "r") DB_USER = db_username_file.read() @@ -47,28 +59,6 @@ DB_PASS = os.environ.get("DB_PASS", "postgres") -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: - # function used to generate database connection - def getconn() -> pymysql.connections.Connection: - conn = connector.connect( - INSTANCE_CONNECTION_NAME, - "pg8000", - user=DB_USER, - password=DB_PASS, - db=DB_NAME, - ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE, - ) - return conn - - # create connection pool - pool = sqlalchemy.create_engine( - "postgresql+pg8000://", - creator=getconn, - ) - return pool - - def create_sync_postgres_engine(): engine = PostgresEngine.from_instance( project_id=GCP_PROJECT_ID, @@ -79,5 +69,24 @@ def create_sync_postgres_engine(): password=DB_PASS, ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE, ) - engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) + try: + engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) + engine.init_vectorstore_table( + VECTOR_EMBEDDINGS_TABLE_NAME, + vector_size=VECTOR_DIMENSION, + overwrite_existing=False, + ) + except Exception as e: + logging.info(f"Error: {e}") + return engine + + +def create_sync_postgres_vector_store(engine, embedding_provider): + vector_store = PostgresVectorStore.create_sync( + engine=engine, + embedding_service=embedding_provider, + table_name=VECTOR_EMBEDDINGS_TABLE_NAME, + ) + + return vector_store diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index 4972375fd..8e9a80999 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -14,18 +14,16 @@ import logging -from google.cloud.sql.connector import Connector - from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableParallel, RunnableLambda from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_google_cloud_sql_pg import PostgresChatMessageHistory from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_google_cloud_sql_pg import PostgresChatMessageHistory + from application.cloud_sql.cloud_sql import ( CHAT_HISTORY_TABLE_NAME, - init_connection_pool, create_sync_postgres_engine, ) from application.rag_langchain.huggingface_inference_model import ( @@ -61,10 +59,6 @@ ) engine = create_sync_postgres_engine() -# TODO: Dict is not safe for multiprocessing. Introduce a cache using Flask-caching or libcache -# The in-memory SimpleCache implementations for each of these libraries is not safe either. -# Consider redis or memcached (e.g., Memorystore) -# chat_history_map: Dict[str, PostgresChatMessageHistory] = {} def get_chat_history(session_id: str) -> PostgresChatMessageHistory: @@ -72,7 +66,9 @@ def get_chat_history(session_id: str) -> PostgresChatMessageHistory: engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME ) - print(f"Retrieving history for session {session_id} with {len(history.messages)}") + logging.info( + f"Retrieving history for session {session_id} with {len(history.messages)}" + ) return history @@ -87,9 +83,8 @@ def create_chain() -> RunnableWithMessageHistory: model = HuggingFaceCustomChatModel() langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CloudSQLVectorStore( - langchain_embed, init_connection_pool(Connector()) - ) + vector_store = CloudSQLVectorStore(langchain_embed, engine) + retriever = vector_store.as_retriever() setup_and_retrieval = RunnableParallel( @@ -99,6 +94,7 @@ def create_chain() -> RunnableWithMessageHistory: HISTORY: RunnableLambda(lambda d: d[HISTORY]), } ) + chain = setup_and_retrieval | prompt | model chain_with_history = RunnableWithMessageHistory( chain, @@ -113,7 +109,7 @@ def create_chain() -> RunnableWithMessageHistory: def take_chat_turn( chain: RunnableWithMessageHistory, session_id: str, query_text: str ) -> str: - # TODO limit the number of history messages config = {"configurable": {"session_id": session_id}} result = chain.invoke({"input": query_text}, config=config) - return str(result) + return result + diff --git a/applications/rag/frontend/container/application/rai/__init__.py b/applications/rag/frontend/container/application/rai/__init__.py index cb97ceb52..11f30faf0 100644 --- a/applications/rag/frontend/container/application/rai/__init__.py +++ b/applications/rag/frontend/container/application/rai/__init__.py @@ -13,4 +13,3 @@ # limitations under the License. # This file is required to make Python treat the subfolder as a package - diff --git a/applications/rag/frontend/container/application/rai/retry.py b/applications/rag/frontend/container/application/rai/retry.py index b9092b683..977b9ff28 100644 --- a/applications/rag/frontend/container/application/rai/retry.py +++ b/applications/rag/frontend/container/application/rai/retry.py @@ -27,5 +27,4 @@ def is_retryable(exc): return isinstance(exc, _RETRIABLE_TYPES) - -retry_policy = Retry(predicate=is_retryable) \ No newline at end of file +retry_policy = Retry(predicate=is_retryable) diff --git a/applications/rag/frontend/container/application/vector_storages/__init__.py b/applications/rag/frontend/container/application/vector_storages/__init__.py index 94cde2d79..130806772 100644 --- a/applications/rag/frontend/container/application/vector_storages/__init__.py +++ b/applications/rag/frontend/container/application/vector_storages/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from .cloud_sql import CloudSQLVectorStore diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py index 545f72259..11e08ed28 100644 --- a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -18,21 +18,17 @@ from typing import List, Optional, Iterable, Any -import pg8000 -import sqlalchemy -from sqlalchemy.engine import Engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.sql import func, text - from langchain_core.vectorstores import VectorStore from langchain_core.embeddings import Embeddings from langchain_core.documents import Document -from langchain.text_splitter import CharacterTextSplitter +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from langchain_google_cloud_sql_pg import PostgresVectorStore -from application.models import VectorEmbeddings VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") -INSTANCE_CONNECTION_NAME = os.environ.get("INSTANCE_CONNECTION_NAME", "") +CHUNK_SIZE = 1000 +CHUNK_OVERLAP = 10 logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" @@ -50,91 +46,41 @@ def from_texts( ): raise NotImplementedError - def __init__(self, embedding: Embeddings, engine: Engine): - self.embedding = embedding - self.engine = engine - self.text_splitter = CharacterTextSplitter( - separator="\n\n", - chunk_size=1024, - chunk_overlap=200, + def __init__(self, embedding_provider, engine): + self.vector_store = PostgresVectorStore.create_sync( + engine=engine, + embedding_service=embedding_provider, + table_name=VECTOR_EMBEDDINGS_TABLE_NAME, ) - - @property - def embeddings(self) -> Embeddings: - return self.embedding + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len + ) + self.embeddings_service = embedding_provider # TODO implement def add_texts( self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any ) -> List[str]: - with self.engine.connect() as conn: - try: - Session = sessionmaker(bind=conn) - session = Session(bind=conn) - for raw_text in texts: - id = uuid.uuid4() - - texts = self.text_splitter.split_text(raw_text) - embeddings = self.embedding.encode(texts).tolist() - vector_embedding = VectorEmbeddings( - id=id, text=texts, text_embedding=embeddings[0] - ) - session.add(vector_embedding) - conn.commit() - - except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: - message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except sqlalchemy.exc.DatabaseError as err: - message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except Exception as err: - raise Exception(f"General error: {err}") + try: + splits = self.splitter.split_documents(texts) + ids = [str(uuid.uuid4()) for _ in range(len(splits))] + self.vector_store.add_documents(splits, ids) + except Exception as e: + logging.info(f"Error: {e}") + raise e # TODO implement similarity search with cosine similarity threshold def similarity_search( self, query: dict, k: int = 4, **kwargs: Any ) -> List[Document]: - with self.engine.connect() as conn: - try: - Session = sessionmaker(bind=conn) - session = Session(bind=conn) - - q = query.get("input") - # embed query & fetch matches - query_emb = self.embedding.embed_query(q) - query_request = ( - "SELECT id, text, text_embedding, 1 - ('[" - + ",".join(map(str, query_emb)) - + "]' <=> text_embedding) AS cosine_similarity FROM " - + VECTOR_EMBEDDINGS_TABLE_NAME - + " ORDER BY cosine_similarity DESC LIMIT " - + str(k) - + ";" - ) - query_results = session.execute(text(query_request)).fetchall() - - print(f"GOT {len(query_results)} results") - - session.commit() - session.close() - - if not query_results: - message = ( - f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} returned empty result" - ) - raise ValueError(message) - - except sqlalchemy.exc.DataError or pg8000.exceptions.DatabaseError as err: - message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except sqlalchemy.exc.DatabaseError as err: - message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except Exception as err: - raise Exception(f"General error: {err}") - - # convert query results into List[Document] - texts = [result[1] for result in query_results] - return [Document(page_content=text) for text in texts] + try: + + query_input = query.get("input") + query_vector = self.embeddings_service.embed_query(query_input) + docs = self.vector_store.similarity_search_by_vector(query_vector, k=4) + return docs + + except Exception as err: + raise Exception(f"General error: {err}") + diff --git a/applications/rag/frontend/container/main.py b/applications/rag/frontend/container/main.py index b029451d6..32a880f89 100644 --- a/applications/rag/frontend/container/main.py +++ b/applications/rag/frontend/container/main.py @@ -1,3 +1,6 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -9,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import uuid import traceback import logging as log import google.cloud.logging as logging -from flask import render_template, request, jsonify, session +from flask import render_template, request, jsonify, session, redirect, url_for from datetime import datetime, timedelta, timezone from application import create_app @@ -31,11 +33,12 @@ clear_chat_history, create_chain, take_chat_turn, + get_chat_history, ) log.basicConfig(level=log.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -SESSION_TIMEOUT_MINUTES = 30 +SESSION_TIMEOUT_MINUTES = 10 # Setup logging logging_client = logging.Client() @@ -47,6 +50,31 @@ llm_chain = create_chain() +@app.before_request +def check_new_session(): + if "session_id" not in session: + # instantiate a new session using a generated UUID + session_id = str(uuid.uuid4()) + session["session_id"] = session_id + + +@app.before_request +def check_inactivity(): + # Inactivity cleanup + if "last_activity" in session: + time_elapsed = datetime.now(timezone.utc) - session["last_activity"] + + if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): + print("Session inactive: Cleaning up resources...") + session_id = session["session_id"] + # TODO: implement garbage collection process for idle sessions that have timed out + clear_chat_history(session_id) + session.clear() + + # Always update the 'last_activity' data + session["last_activity"] = datetime.now(timezone.utc) + + @app.route("/get_nlp_status", methods=["GET"]) def get_nlp_status(): nlp_enabled = nlp_filter.is_nlp_api_enabled() @@ -69,29 +97,30 @@ def get_deidentify_templates(): return jsonify(dlp_filter.list_deidentify_templates_from_parent()) -@app.before_request -def check_new_session(): - if "session_id" not in session: - # instantiate a new session using a generated UUID - session_id = str(uuid.uuid4()) - session["session_id"] = session_id - +@app.route("/get_chat_history", methods=["GET"]) +def get_chat_history_endpoint(): + try: + session_id = session.get("session_id") + history = get_chat_history(session_id) + log.info(history) -@app.before_request -def check_inactivity(): - # Inactivity cleanup - if "last_activity" in session: - time_elapsed = datetime.now(timezone.utc) - session["last_activity"] + response = jsonify({"history_messages": []}) + response.status_code = 200 - if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): - print("Session inactive: Cleaning up resources...") - session_id = session["session_id"] - # TODO: implement garbage collection process for idle sessions that have timed out - clear_chat_history(session_id) - session.clear() + return response - # Always update the 'last_activity' data - session["last_activity"] = datetime.now(timezone.utc) + except Exception as err: + log.info(f"exception from llm: {err}") + traceback.print_exc() + error_traceback = traceback.format_exc() + response = jsonify( + { + "error": "An error occurred", + "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", + } + ) + response.status_code = 500 + return response @app.route("/") @@ -113,8 +142,11 @@ def handlePrompt(): log.info(f"handle user prompt: {user_prompt}") try: + session_id = session.get("session_id") + if not session_id: + return redirect(url_for("index")) response = {} - result = take_chat_turn(llm_chain, session["session_id"], user_prompt) + result = take_chat_turn(llm_chain, session_id, user_prompt) log.info("After the result") log.info(result) response["text"] = result diff --git a/applications/rag/frontend/container/requirements.txt b/applications/rag/frontend/container/requirements.txt index 8b4a16779..d88e4393a 100644 --- a/applications/rag/frontend/container/requirements.txt +++ b/applications/rag/frontend/container/requirements.txt @@ -17,7 +17,6 @@ gunicorn==22.0.0 Werkzeug==3.0.1 langchain==0.1.9 sentence-transformers==2.5.1 -text_generation==0.6.1 google-cloud-dlp==3.12.2 google-cloud-storage==2.9.0 google-cloud-pubsub==2.17.0 From e750d12ddd84edf0d95d3ee6f34f6381b46cdf90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 31 Jul 2024 11:06:01 -0500 Subject: [PATCH 09/46] Fixing issues and updating chat history on frontend --- cloud_sql.py | 85 +++++++++++++++ dlp_filter.py | 143 ++++++++++++++++++++++++ main.py | 191 +++++++++++++++++++++++++++++++++ nlp_filter.py | 81 ++++++++++++++ rag_chain.py | 114 ++++++++++++++++++++ script.js | 292 ++++++++++++++++++++++++++++++++++++++++++++++++++ styles.css | 209 ++++++++++++++++++++++++++++++++++++ 7 files changed, 1115 insertions(+) create mode 100644 cloud_sql.py create mode 100644 dlp_filter.py create mode 100644 main.py create mode 100644 nlp_filter.py create mode 100644 rag_chain.py create mode 100644 script.js create mode 100644 styles.css diff --git a/cloud_sql.py b/cloud_sql.py new file mode 100644 index 000000000..1c0da565b --- /dev/null +++ b/cloud_sql.py @@ -0,0 +1,85 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +import logging + +from typing import List, Optional, Iterable, Any + +from langchain_core.vectorstores import VectorStore +from langchain_core.embeddings import Embeddings +from langchain_core.documents import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from langchain_google_cloud_sql_pg import PostgresVectorStore + + +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") +CHUNK_SIZE = 1000 +CHUNK_OVERLAP = 10 + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class CloudSQLVectorStore(VectorStore): + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ): + raise NotImplementedError + + def __init__(self, embedding_provider, engine): + self.vector_store = PostgresVectorStore.create_sync( + engine=engine, + embedding_service=embedding_provider, + table_name=VECTOR_EMBEDDINGS_TABLE_NAME, + ) + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len + ) + self.embeddings_service = embedding_provider + + # TODO implement + def add_texts( + self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any + ) -> List[str]: + try: + splits = self.splitter.split_documents(texts) + ids = [str(uuid.uuid4()) for _ in range(len(splits))] + self.vector_store.add_documents(splits, ids) + except Exception as e: + logging.info(f"Error: {e}") + raise e + + # TODO implement similarity search with cosine similarity threshold + + def similarity_search( + self, query: dict, k: int = 4, **kwargs: Any + ) -> List[Document]: + try: + + query_input = query.get("input") + query_vector = self.embeddings_service.embed_query(query_input) + docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) + return docs + + except Exception as err: + raise Exception(f"General error: {err}") diff --git a/dlp_filter.py b/dlp_filter.py new file mode 100644 index 000000000..ad69a30c3 --- /dev/null +++ b/dlp_filter.py @@ -0,0 +1,143 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging + +import google.cloud.dlp + +from . import retry + +# Convert the project id into a full resource id. +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "NULL") +parent = f"projects/{GCP_PROJECT_ID}" + +# Instantiate a dlp client. +dlp_client = google.cloud.dlp_v2.DlpServiceClient() + +logging.basicConfig( + level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def is_dlp_api_enabled(): + if parent == "NULL": + return False + # Check if the DLP API is enabled + try: + dlp_client.list_info_types( + request={"parent": "en-US"}, retry=retry.retry_policy + ) + return True + except Exception as e: + print(f"Error: {e}") + return False + + +def list_inspect_templates_from_parent(): + try: + # Initialize request argument(s) + request = google.cloud.dlp_v2.ListInspectTemplatesRequest( + parent=parent, + ) + + # Make the request + page_result = dlp_client.list_inspect_templates( + request=request, retry=retry.retry_policy + ) + + name_list = [] + # Handle the response + for response in page_result: + name_list.append(response.name) + return name_list + except Exception as e: + logging.error(e) + raise e + + +def get_inspect_templates_from_name(name): + try: + request = google.cloud.dlp_v2.GetInspectTemplateRequest( + name=name, + ) + + return dlp_client.get_inspect_template(request=request) + except Exception as e: + logging.error(e) + raise e + + +def list_deidentify_templates_from_parent(): + try: + # Initialize request argument(s) + request = google.cloud.dlp_v2.ListDeidentifyTemplatesRequest( + parent=parent, + ) + + # Make the request + page_result = dlp_client.list_deidentify_templates(request=request) + + name_list = [] + # Handle the response + for response in page_result: + name_list.append(response.name) + return name_list + + except Exception as e: + logging.error(e) + raise e + + +def get_deidentify_templates_from_name(name): + try: + request = google.cloud.dlp_v2.GetDeidentifyTemplateRequest( + name=name, + ) + + return dlp_client.get_deidentify_template( + request=request, retry=retry.retry_policy + ) + except Exception as e: + logging.error(e) + raise e + + +def inspect_content(inspect_template_path, deidentify_template_path, input): + try: + inspect_templates = get_inspect_templates_from_name(inspect_template_path) + deidentify_template = get_deidentify_templates_from_name( + deidentify_template_path + ) + + # Construct item + item = {"value": input} + + # Call the API + response = dlp_client.deidentify_content( + request={ + "parent": parent, + "deidentify_config": deidentify_template.deidentify_config, + "inspect_config": inspect_templates.inspect_config, + "item": item, + }, + retry=retry.retry_policy, + ) + + # Print out the results. + print(response.item.value) + return response.item.value + except Exception as e: + logging.error(e) + raise e diff --git a/main.py b/main.py new file mode 100644 index 000000000..987555552 --- /dev/null +++ b/main.py @@ -0,0 +1,191 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +import traceback +import logging as log + +import google.cloud.logging as logging + +from flask import render_template, request, jsonify, session, redirect, url_for +from datetime import datetime, timedelta, timezone + +from application import create_app +from application.rai import ( + dlp_filter, +) # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp +from application.rai import ( + nlp_filter, +) # https://cloud.google.com/natural-language/docs/moderating-text + +from application.rag_langchain.rag_chain import ( + clear_chat_history, + create_chain, + take_chat_turn, + get_chat_history, +) + +log.basicConfig(level=log.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + +SESSION_TIMEOUT_MINUTES = 20 + +# Setup logging +logging_client = logging.Client() +logging_client.setup_logging() + +app = create_app() + +# Create llm chain +llm_chain = create_chain() + + +@app.before_request +def check_new_session(): + if "session_id" not in session: + # instantiate a new session using a generated UUID + session_id = str(uuid.uuid4()) + session["session_id"] = session_id + + +@app.before_request +def check_inactivity(): + # Inactivity cleanup + if "last_activity" in session: + time_elapsed = datetime.now(timezone.utc) - session["last_activity"] + + if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): + print("Session inactive: Cleaning up resources...") + session_id = session["session_id"] + # TODO: implement garbage collection process for idle sessions that have timed out + clear_chat_history(session_id) + session.clear() + + # Always update the 'last_activity' data + session["last_activity"] = datetime.now(timezone.utc) + + +@app.route("/get_nlp_status", methods=["GET"]) +def get_nlp_status(): + nlp_enabled = nlp_filter.is_nlp_api_enabled() + return jsonify({"nlpEnabled": nlp_enabled}) + + +@app.route("/get_dlp_status", methods=["GET"]) +def get_dlp_status(): + dlp_enabled = dlp_filter.is_dlp_api_enabled() + return jsonify({"dlpEnabled": dlp_enabled}) + + +@app.route("/get_inspect_templates") +def get_inspect_templates(): + return jsonify(dlp_filter.list_inspect_templates_from_parent()) + + +@app.route("/get_deidentify_templates") +def get_deidentify_templates(): + return jsonify(dlp_filter.list_deidentify_templates_from_parent()) + + +@app.route("/get_chat_history", methods=["GET"]) +def get_chat_history_endpoint(): + try: + session_id = session.get("session_id") + if not session_id: + return redirect(url_for("index")) + + history = get_chat_history(session_id) + + messages_response = [] + for message in history.messages: + data = {"prompt": message.type, "message": message.content} + messages_response.append(data) + + response = jsonify({"history_messages": messages_response}) + response.status_code = 200 + + return response + + except Exception as err: + log.info(f"exception from llm: {err}") + traceback.print_exc() + error_traceback = traceback.format_exc() + response = jsonify( + { + "error": "An error occurred", + "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", + } + ) + response.status_code = 500 + return response + + +@app.route("/") +def index(): + return render_template("index.html") + + +@app.route("/prompt", methods=["POST"]) +def handlePrompt(): + # TODO on page refresh, load chat history into browser. + session["last_activity"] = datetime.now(timezone.utc) + data = request.get_json() + warnings = [] + + if "prompt" not in data: + return "missing required prompt", 400 + + user_prompt = data["prompt"] + log.info(f"handle user prompt: {user_prompt}") + + try: + session_id = session.get("session_id") + if not session_id: + return redirect(url_for("index")) + response = {} + result = take_chat_turn(llm_chain, session_id, user_prompt) + response["text"] = result + + # TODO: enable filtering in chain + if "nlpFilterLevel" in data: + if nlp_filter.is_content_inappropriate( + response["text"], data["nlpFilterLevel"] + ): + response["text"] = "The response is deemed inappropriate for display." + return {"response": response} + if "inspectTemplate" in data and "deidentifyTemplate" in data: + inspect_template_path = data["inspectTemplate"] + deidentify_template_path = data["deidentifyTemplate"] + if inspect_template_path != "" and deidentify_template_path != "": + # filter the output with inspect setting. Customer can pick any category from https://cloud.google.com/dlp/docs/concepts-infotypes + response["text"] = dlp_filter.inspect_content( + inspect_template_path, deidentify_template_path, response["text"] + ) + + if warnings: + response["warnings"] = warnings + return {"response": response} + + except Exception as err: + log.info(f"exception from llm: {err}") + traceback.print_exc() + error_traceback = traceback.format_exc() + response = jsonify( + { + "warnings": warnings, + "error": "An error occurred", + "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", + } + ) + response.status_code = 500 + return response diff --git a/nlp_filter.py b/nlp_filter.py new file mode 100644 index 000000000..3b384a910 --- /dev/null +++ b/nlp_filter.py @@ -0,0 +1,81 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging + +import google.cloud.language_v1 as language + +from . import retry + +# Convert the project id into a full resource id. +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "NULL") +parent = f"projects/{GCP_PROJECT_ID}" + +# Instantiate a nlp client. +nature_language_client = language.LanguageServiceClient() + +logging.basicConfig( + level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def is_nlp_api_enabled(): + if parent == "NULL": + return False + # Check if the DLP API is enabled + try: + sum_moderation_confidences("test") + return True + except Exception as e: + print(f"Error: {e}") + return False + + +def sum_moderation_confidences(text): + try: + document = language.types.Document( + content=text, type_=language.types.Document.Type.PLAIN_TEXT + ) + + request = language.ModerateTextRequest( + document=document, + ) + # Detects the sentiment of the text + response = nature_language_client.moderate_text( + request=request, retry=retry.retry_policy + ) + print(f"get response: {response}") + # Parse response and sum the confidences of moderation, the categories are from https://cloud.google.com/natural-language/docs/moderating-text + largest_confidence = 0.0 + excluding_names = ["Health", "Politics", "Finance", "Legal"] + for category in response.moderation_categories: + if category.name in excluding_names: + continue + if category.confidence > largest_confidence: + largest_confidence = category.confidence + + print(f"largest confidence is: {largest_confidence}") + return int(largest_confidence * 100) + except Exception as e: + logging.error(e) + raise e + + +def is_content_inappropriate(text, nlp_filter_level): + try: + return sum_moderation_confidences(text) > (100 - int(nlp_filter_level)) + except Exception as e: + logging.error(e) + raise e diff --git a/rag_chain.py b/rag_chain.py new file mode 100644 index 000000000..f290021ff --- /dev/null +++ b/rag_chain.py @@ -0,0 +1,114 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableParallel, RunnableLambda +from langchain_core.runnables.history import RunnableWithMessageHistory + +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_google_cloud_sql_pg import PostgresChatMessageHistory + + +from application.cloud_sql.cloud_sql import ( + CHAT_HISTORY_TABLE_NAME, + create_sync_postgres_engine, +) +from application.rag_langchain.huggingface_inference_model import ( + HuggingFaceCustomChatModel, +) +from application.vector_storages import CloudSQLVectorStore + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +QUESTION = "input" +HISTORY = "chat_history" +CONTEXT = "context" + +SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings + +template_str = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. +Improve upon your previous answers using History, a list of messages. +Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. +Stick to the facts by basing your answers off of the Context provided. +Be brief in answering. +\n\n +Context: {context} +""" + +prompt = ChatPromptTemplate.from_messages( + [ + ("system", template_str), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) + +engine = create_sync_postgres_engine() + + +def get_chat_history(session_id: str) -> PostgresChatMessageHistory: + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + + logging.info( + f"Retrieving history for session {session_id} with {len(history.messages)}" + ) + return history + + +def clear_chat_history(session_id: str): + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + history.clear() + + +def create_chain() -> RunnableWithMessageHistory: + model = HuggingFaceCustomChatModel() + + langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) + vector_store = CloudSQLVectorStore(langchain_embed, engine) + + retriever = vector_store.as_retriever() + + setup_and_retrieval = RunnableParallel( + { + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]), + } + ) + + chain = setup_and_retrieval | prompt | model + chain_with_history = RunnableWithMessageHistory( + chain, + get_chat_history, + input_messages_key=QUESTION, + history_messages_key=HISTORY, + output_messages_key="output", + ) + return chain_with_history + + +def take_chat_turn( + chain: RunnableWithMessageHistory, session_id: str, query_text: str +) -> str: + config = {"configurable": {"session_id": session_id}} + result = chain.invoke({"input": query_text}, config=config) + return result diff --git a/script.js b/script.js new file mode 100644 index 000000000..6fec15f99 --- /dev/null +++ b/script.js @@ -0,0 +1,292 @@ +/* Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +function onReady() { + autoResizeTextarea(); + populateDropdowns(); + updateNLPValue(); + loadPreviousMessages(); + + document.getElementById("prompt").addEventListener("keydown", (e) => { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + e.target.form.requestSubmit(); + } + }); + + // Handle the chat form submission + document.getElementById("form").addEventListener("submit", function (e) { + e.preventDefault(); + + var promptInput = document.getElementById("prompt"); + var prompt = promptInput.value; + if (prompt === "") { + return; + } + promptInput.value = ""; + + var chatEl = document.getElementById("chat"); + var promptEl = Object.assign(document.createElement("p"), { + classList: ["prompt"], + }); + promptEl.textContent = prompt; + chatEl.appendChild(promptEl); + + var responseEl = Object.assign(document.createElement("p"), { + classList: ["response"], + }); + chatEl.appendChild(responseEl); + chatEl.scrollTop = chatEl.scrollHeight; // Scroll to bottom + enableForm(false); + + // Collect filter data + let data = { + prompt: prompt, + }; + + if (document.getElementById("toggle-nlp-filter-section").checked) { + data.nlpFilterLevel = document.getElementById("nlp-range").value; + } + + if (document.getElementById("toggle-dlp-filter-section").checked) { + data.inspectTemplate = document.getElementById( + "inspect-template-dropdown" + ).value; + data.deidentifyTemplate = document.getElementById( + "deidentify-template-dropdown" + ).value; + } + var body = JSON.stringify(data); + + // Send data to the server + fetch("/prompt", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: body, + }) + .then((response) => { + if (!response.ok) { + return response.json().then((errorData) => { + throw new Error(errorData.errorMessage); + }); + } + return response.json(); + }) + .then((data) => { + var content = data.response.text; + if (data.response.warnings && data.response.warnings.length > 0) { + responseEl.classList.replace("response", "warning"); + content += "\n\nWarning: " + data.response.warnings.join("\n") + "\n"; + } + responseEl.textContent = content; + }) + .catch((err) => { + responseEl.classList.replace("response", "error"); + responseEl.textContent = err.message; + }) + .finally(() => enableForm(true)); + }); + + document + .getElementById("toggle-dlp-filter-section") + .addEventListener("change", function () { + fetchDLPEnabled(); + var inspectDropdown = document.getElementById( + "inspect-template-dropdown" + ); + var deidentifyDropdown = document.getElementById( + "deidentify-template-dropdown" + ); + + // Check the Inspect Template Dropdown + if (inspectDropdown.options.length <= 0) { + inspectDropdown.style.display = "none"; // Hide Dropdown + document.getElementById("inspect-template-msg").style.display = "block"; // Show Message + } else { + inspectDropdown.style.display = "block"; // Show Dropdown + document.getElementById("inspect-template-msg").style.display = "none"; // Hide Message + } + + // Check the De-identify Template Dropdown + if (deidentifyDropdown.options.length <= 0) { + deidentifyDropdown.style.display = "none"; // Hide Dropdown + document.getElementById("deidentify-template-msg").style.display = + "block"; // Show Message + } else { + deidentifyDropdown.style.display = "block"; // Show Dropdown + document.getElementById("deidentify-template-msg").style.display = + "none"; // Hide Message + } + }); + + document + .getElementById("toggle-nlp-filter-section") + .addEventListener("change", function () { + fetchNLPEnabled(); + }); +} +if (document.readyState != "loading") onReady(); +else document.addEventListener("DOMContentLoaded", onReady); + +function enableForm(enabled) { + var promptEl = document.getElementById("prompt"); + promptEl.toggleAttribute("disabled", !enabled); + if (enabled) setTimeout(() => promptEl.focus(), 0); + + var submitEl = document.getElementById("submit"); + submitEl.toggleAttribute("disabled", !enabled); + submitEl.textContent = enabled ? "Submit" : "..."; +} + +function autoResizeTextarea() { + var textarea = document.getElementById("prompt"); + textarea.addEventListener("input", function () { + this.style.height = "auto"; + this.style.height = this.scrollHeight + "px"; + }); +} + +// Function to handle the visibility of filter section +function toggleNlpFilterSection(nlpEnabled) { + var filterOptions = document.getElementById("nlp-filter-section"); + var nlpCheckbox = document.getElementById("toggle-nlp-filter-section"); + + if (nlpEnabled && nlpCheckbox.checked) { + filterOptions.style.display = "block"; + } else { + filterOptions.style.display = "none"; + } +} + +function updateNLPValue() { + const rangeInput = document.getElementById("nlp-range"); + const valueDisplay = document.getElementById("nlp-value"); + + // Function to update the slider's display value and color + const updateSliderAppearance = (value) => { + // Update the display text + valueDisplay.textContent = value; + + // Determine the color based on the value + let color; + if (value <= 25) { + color = "#4285F4"; // Blue + } else if (value <= 50) { + color = "#34A853"; // Green + } else if (value <= 75) { + color = "#FBBC05"; // Yellow + } else { + color = "#EA4335"; // Red + } + + // Apply the color to the slider through a gradient + // This gradient visually fills the track up to the thumb's current position + const percentage = + ((value - rangeInput.min) / (rangeInput.max - rangeInput.min)) * 100; + rangeInput.style.background = `linear-gradient(90deg, ${color} ${percentage}%, #ddd ${percentage}%)`; + rangeInput.style.setProperty("--thumb-color", color); + }; + + // Initialize the slider's appearance + updateSliderAppearance(rangeInput.value); + + // Update slider's appearance whenever its value changes + rangeInput.addEventListener("input", (event) => { + updateSliderAppearance(event.target.value); + }); +} + +function fetchNLPEnabled() { + fetch("/get_nlp_status") + .then((response) => response.json()) + .then((data) => { + var nlpEnabled = data.nlpEnabled; + + toggleNlpFilterSection(nlpEnabled); + }) + .catch((error) => console.error("Error fetching NLP status:", error)); +} + +// Function to handle the visibility of filter section +function toggleDLPFilterSection(dlpEnabled) { + var filterOptions = document.getElementById("dlp-filter-section"); + var dlpCheckbox = document.getElementById("toggle-dlp-filter-section"); + if (dlpEnabled && dlpCheckbox.checked) { + filterOptions.style.display = "block"; + } else { + filterOptions.style.display = "none"; + } +} + +function fetchDLPEnabled() { + fetch("/get_dlp_status") + .then((response) => response.json()) + .then((data) => { + var dlpEnabled = data.dlpEnabled; + + toggleDLPFilterSection(dlpEnabled); + }) + .catch((error) => console.error("Error fetching DLP status:", error)); +} + +// Function to populate dropdowns +function populateDropdowns() { + fetch("/get_inspect_templates") + .then((response) => response.json()) + .then((data) => { + const inspectDropdown = document.getElementById( + "inspect-template-dropdown" + ); + data.forEach((template) => { + let option = new Option(template, template); + inspectDropdown.add(option); + }); + }) + .catch((error) => console.error("Error loading inspect templates:", error)); + + fetch("/get_deidentify_templates") + .then((response) => response.json()) + .then((data) => { + const deidentifyDropdown = document.getElementById( + "deidentify-template-dropdown" + ); + data.forEach((template) => { + let option = new Option(template, template); + deidentifyDropdown.add(option); + }); + }) + .catch((error) => + console.error("Error loading deidentify templates:", error) + ); +} + +function loadPreviousMessages() { + fetch("/get_chat_history") + .then((response) => response.json()) + .then((data) => { + const { history_messages } = data; + var chatEl = document.getElementById("chat"); + history_messages.map(({ prompt, message }) => { + var promptEl = Object.assign(document.createElement("p"), { + classList: ["previous_message"], + }); + promptEl.textContent = `${prompt} : ${message}`; + chatEl.appendChild(promptEl); + }); + }) + .catch((error) => + console.error("Error getting previous chat messages:", error) + ); +} diff --git a/styles.css b/styles.css new file mode 100644 index 000000000..d2bf11352 --- /dev/null +++ b/styles.css @@ -0,0 +1,209 @@ +/* Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* styles.css */ +body { + font-family: Arial, sans-serif; + text-align: center; + + max-width: 2000px; + margin: 0 auto; +} + +.content-container { + display: flex; + max-width: 2000px; + margin: 0 auto; +} + +#chat-and-form { + width: 70%; +} + +#chat { + height: 80vh; + padding: 10px; + margin-top: 20px; + margin-bottom: 10px; + display: block; + + text-align: left; + overflow-y: scroll; + + border: 1px solid #222; + border-radius: 2px; +} + +#chat p { + white-space: pre-wrap; + + padding: 8px; + border-radius: 4px; +} + +#chat p:first-child { + margin-top: 0; +} + +p.prompt { + background-color: rgba(50, 255, 0, 0.1); +} + +p.prompt::before { + content: 'Prompt: '; + font-weight: bold; +} + +p.instruct{ + background-color: rgba(255, 255, 0, 0.2); +} + + +p.instruct::before { + content: 'Instructions: '; + font-weight: bold; +} + +p.previous_message { + background-color: rgba(73, 240, 240, 0.753); +} + +p.previous_message::before { + content: 'Previous message: '; + font-weight: bold; +} + +p.response { + background-color: rgba(0, 134, 255, 0.1); +} + +p.response::before { + content: 'Response: '; + font-weight: bold; +} + +p.warning { + background-color: rgba(255, 229, 100, 0.1); +} + +p.error::before { + content: 'Warning: '; + font-weight: bold; +} + +p.error { + background-color: rgba(255, 0, 0, 0.1); +} + +p.error::before { + content: 'Error: '; + font-weight: bold; +} + +form { + display: flex; + width: 100%; +} + +#prompt { + flex: 1; + padding: 10px; + margin-right: 10px; + border-radius: 2px; +} + +input[type="submit"] { + padding: 10px 20px; + background-color: rgb(76,175,80); + color: white; + border: none; + border-radius: 2px; +} + +input[type="submit"]:disabled { + background-color: rgba(0, 0, 0, 0.2); +} + +/* Filter Section Styles */ +#filter-section { + width: 30%; + margin-right: 2%; /* Space between filter section and the next section, adjust as needed */ + border: 1px solid #ccc; /* Optional: adds a border around the filter section */ + padding: 20px; /* Adds some space inside the filter section */ + box-shadow: 0 2px 4px rgba(0,0,0,0.1); /* Optional: adds a slight shadow for depth */ +} + +#dlp-filter-section, +#nlp-filter-section { + margin-top: 20px; +} + +select { + width: 100%; + padding: 10px; + margin-top: 10px; + border: 1px solid #ccc; /* Adds a border to the dropdowns */ + border-radius: 4px; /* Rounds the corners of the dropdowns */ +} + +label { + display: block; + margin-top: 15px; + font-weight: bold; +} + +input[type="checkbox"] { + margin-right: 10px; /* Space between checkbox and its label */ +} + +#nlp-range { + width: 100%; + margin-top: 10px; /* Adjusted for consistency */ +} + +input[type="range"] { + -webkit-appearance: none; + appearance: none; + width: 100%; + height: 8px; + background: #ddd; + outline: none; + opacity: 1; + -webkit-transition: .2s; + transition: opacity .2s; +} + +input[type="range"]::-webkit-slider-thumb { + -webkit-appearance: none; + appearance: none; + width: 25px; + height: 25px; + background: var(--thumb-color); + cursor: pointer; +} + +input[type="range"]::-moz-range-thumb { + width: 25px; + height: 25px; + background: var(--thumb-color); + cursor: pointer; +} + +input[type="range"]::-ms-thumb { + width: 25px; + height: 25px; + background: var(--thumb-color); + cursor: pointer; +} \ No newline at end of file From a000c4637c1260c60a7c7b5ab7f81e0874a49778 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 31 Jul 2024 16:10:47 +0000 Subject: [PATCH 10/46] Fixing files on working tree --- .../application/cloud_sql/cloud_sql.py | 177 +++--- .../application/rag_langchain/rag_chain.py | 229 ++++---- .../container/application/rai/dlp_filter.py | 254 +++++---- .../container/application/rai/nlp_filter.py | 147 ++--- .../container/application/static/script.js | 539 ++++++++++-------- .../container/application/static/styles.css | 406 ++++++------- applications/rag/frontend/container/main.py | 377 ++++++------ cloud_sql.py | 85 --- dlp_filter.py | 143 ----- main.py | 191 ------- nlp_filter.py | 81 --- rag_chain.py | 114 ---- script.js | 292 ---------- styles.css | 209 ------- 14 files changed, 1114 insertions(+), 2130 deletions(-) delete mode 100644 cloud_sql.py delete mode 100644 dlp_filter.py delete mode 100644 main.py delete mode 100644 nlp_filter.py delete mode 100644 rag_chain.py delete mode 100644 script.js delete mode 100644 styles.css diff --git a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py index cf3c4d7ed..1c0da565b 100644 --- a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py +++ b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py @@ -1,92 +1,85 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import logging - -from google.cloud.sql.connector import IPTypes - -from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - -ENVIRONMENT = os.environ.get("ENVIRONMENT") - -GCP_PROJECT_ID = os.environ.get("PROJECT_ID") -GCP_CLOUD_SQL_REGION = os.environ.get("CLOUDSQL_INSTANCE_REGION") -GCP_CLOUD_SQL_INSTANCE = os.environ.get("CLOUDSQL_INSTANCE") - -INSTANCE_CONNECTION_NAME = ( - f"{GCP_PROJECT_ID}:{GCP_CLOUD_SQL_REGION}:{GCP_CLOUD_SQL_INSTANCE}" -) - -DB_NAME = os.environ.get("DB_NAME", "pgvector-database") -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") -CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") - -VECTOR_DIMENSION = os.environ.get("VECTOR_DIMENSION", 384) - -try: - db_username_file = open("/etc/secret-volume/username", "r") - DB_USER = db_username_file.read() - db_username_file.close() - - db_password_file = open("/etc/secret-volume/password", "r") - DB_PASS = db_password_file.read() - db_password_file.close() -except: - DB_USER = os.environ.get("DB_USERNAME", "postgres") - DB_PASS = os.environ.get("DB_PASS", "postgres") - - -def create_sync_postgres_engine(): - engine = PostgresEngine.from_instance( - project_id=GCP_PROJECT_ID, - region=GCP_CLOUD_SQL_REGION, - instance=GCP_CLOUD_SQL_INSTANCE, - database=DB_NAME, - user=DB_USER, - password=DB_PASS, - ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE, - ) - try: - engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) - engine.init_vectorstore_table( - VECTOR_EMBEDDINGS_TABLE_NAME, - vector_size=VECTOR_DIMENSION, - overwrite_existing=False, - ) - except Exception as e: - logging.info(f"Error: {e}") - - return engine - - -def create_sync_postgres_vector_store(engine, embedding_provider): - vector_store = PostgresVectorStore.create_sync( - engine=engine, - embedding_service=embedding_provider, - table_name=VECTOR_EMBEDDINGS_TABLE_NAME, - ) - - return vector_store +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +import logging + +from typing import List, Optional, Iterable, Any + +from langchain_core.vectorstores import VectorStore +from langchain_core.embeddings import Embeddings +from langchain_core.documents import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from langchain_google_cloud_sql_pg import PostgresVectorStore + + +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") +CHUNK_SIZE = 1000 +CHUNK_OVERLAP = 10 + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class CloudSQLVectorStore(VectorStore): + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ): + raise NotImplementedError + + def __init__(self, embedding_provider, engine): + self.vector_store = PostgresVectorStore.create_sync( + engine=engine, + embedding_service=embedding_provider, + table_name=VECTOR_EMBEDDINGS_TABLE_NAME, + ) + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len + ) + self.embeddings_service = embedding_provider + + # TODO implement + def add_texts( + self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any + ) -> List[str]: + try: + splits = self.splitter.split_documents(texts) + ids = [str(uuid.uuid4()) for _ in range(len(splits))] + self.vector_store.add_documents(splits, ids) + except Exception as e: + logging.info(f"Error: {e}") + raise e + + # TODO implement similarity search with cosine similarity threshold + + def similarity_search( + self, query: dict, k: int = 4, **kwargs: Any + ) -> List[Document]: + try: + + query_input = query.get("input") + query_vector = self.embeddings_service.embed_query(query_input) + docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) + return docs + + except Exception as err: + raise Exception(f"General error: {err}") diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index 8e9a80999..f290021ff 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -1,115 +1,114 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import RunnableParallel, RunnableLambda -from langchain_core.runnables.history import RunnableWithMessageHistory - -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_google_cloud_sql_pg import PostgresChatMessageHistory - - -from application.cloud_sql.cloud_sql import ( - CHAT_HISTORY_TABLE_NAME, - create_sync_postgres_engine, -) -from application.rag_langchain.huggingface_inference_model import ( - HuggingFaceCustomChatModel, -) -from application.vector_storages import CloudSQLVectorStore - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - -QUESTION = "input" -HISTORY = "chat_history" -CONTEXT = "context" - -SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings - -template_str = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. -Improve upon your previous answers using History, a list of messages. -Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. -Stick to the facts by basing your answers off of the Context provided. -Be brief in answering. -\n\n -Context: {context} -""" - -prompt = ChatPromptTemplate.from_messages( - [ - ("system", template_str), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] -) - -engine = create_sync_postgres_engine() - - -def get_chat_history(session_id: str) -> PostgresChatMessageHistory: - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - - logging.info( - f"Retrieving history for session {session_id} with {len(history.messages)}" - ) - return history - - -def clear_chat_history(session_id: str): - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - history.clear() - - -def create_chain() -> RunnableWithMessageHistory: - model = HuggingFaceCustomChatModel() - - langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CloudSQLVectorStore(langchain_embed, engine) - - retriever = vector_store.as_retriever() - - setup_and_retrieval = RunnableParallel( - { - "context": retriever, - QUESTION: RunnableLambda(lambda d: d[QUESTION]), - HISTORY: RunnableLambda(lambda d: d[HISTORY]), - } - ) - - chain = setup_and_retrieval | prompt | model - chain_with_history = RunnableWithMessageHistory( - chain, - get_chat_history, - input_messages_key=QUESTION, - history_messages_key=HISTORY, - output_messages_key="output", - ) - return chain_with_history - - -def take_chat_turn( - chain: RunnableWithMessageHistory, session_id: str, query_text: str -) -> str: - config = {"configurable": {"session_id": session_id}} - result = chain.invoke({"input": query_text}, config=config) - return result - +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableParallel, RunnableLambda +from langchain_core.runnables.history import RunnableWithMessageHistory + +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_google_cloud_sql_pg import PostgresChatMessageHistory + + +from application.cloud_sql.cloud_sql import ( + CHAT_HISTORY_TABLE_NAME, + create_sync_postgres_engine, +) +from application.rag_langchain.huggingface_inference_model import ( + HuggingFaceCustomChatModel, +) +from application.vector_storages import CloudSQLVectorStore + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +QUESTION = "input" +HISTORY = "chat_history" +CONTEXT = "context" + +SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings + +template_str = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. +Improve upon your previous answers using History, a list of messages. +Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. +Stick to the facts by basing your answers off of the Context provided. +Be brief in answering. +\n\n +Context: {context} +""" + +prompt = ChatPromptTemplate.from_messages( + [ + ("system", template_str), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) + +engine = create_sync_postgres_engine() + + +def get_chat_history(session_id: str) -> PostgresChatMessageHistory: + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + + logging.info( + f"Retrieving history for session {session_id} with {len(history.messages)}" + ) + return history + + +def clear_chat_history(session_id: str): + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + history.clear() + + +def create_chain() -> RunnableWithMessageHistory: + model = HuggingFaceCustomChatModel() + + langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) + vector_store = CloudSQLVectorStore(langchain_embed, engine) + + retriever = vector_store.as_retriever() + + setup_and_retrieval = RunnableParallel( + { + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]), + } + ) + + chain = setup_and_retrieval | prompt | model + chain_with_history = RunnableWithMessageHistory( + chain, + get_chat_history, + input_messages_key=QUESTION, + history_messages_key=HISTORY, + output_messages_key="output", + ) + return chain_with_history + + +def take_chat_turn( + chain: RunnableWithMessageHistory, session_id: str, query_text: str +) -> str: + config = {"configurable": {"session_id": session_id}} + result = chain.invoke({"input": query_text}, config=config) + return result diff --git a/applications/rag/frontend/container/application/rai/dlp_filter.py b/applications/rag/frontend/container/application/rai/dlp_filter.py index 4a22b976f..ad69a30c3 100644 --- a/applications/rag/frontend/container/application/rai/dlp_filter.py +++ b/applications/rag/frontend/container/application/rai/dlp_filter.py @@ -1,111 +1,143 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import google.cloud.dlp -from . import retry - -# Convert the project id into a full resource id. -parent = os.environ.get("PROJECT_ID", "NULL") - -# Instantiate a dlp client. -dlp_client = google.cloud.dlp_v2.DlpServiceClient() - - -def is_dlp_api_enabled(): - if parent == "NULL": - return False - # Check if the DLP API is enabled - try: - dlp_client.list_info_types( - request={"parent": "en-US"}, retry=retry.retry_policy - ) - return True - except Exception as e: - print(f"Error: {e}") - return False - - -def list_inspect_templates_from_parent(): - # Initialize request argument(s) - request = google.cloud.dlp_v2.ListInspectTemplatesRequest( - parent=parent, - ) - - # Make the request - page_result = dlp_client.list_inspect_templates( - request=request, retry=retry.retry_policy - ) - - name_list = [] - # Handle the response - for response in page_result: - name_list.append(response.name) - return name_list - - -def get_inspect_templates_from_name(name): - request = google.cloud.dlp_v2.GetInspectTemplateRequest( - name=name, - ) - - return dlp_client.get_inspect_template(request=request) - - -def list_deidentify_templates_from_parent(): - # Initialize request argument(s) - request = google.cloud.dlp_v2.ListDeidentifyTemplatesRequest( - parent=parent, - ) - - # Make the request - page_result = dlp_client.list_deidentify_templates(request=request) - - name_list = [] - # Handle the response - for response in page_result: - name_list.append(response.name) - return name_list - - -def get_deidentify_templates_from_name(name): - request = google.cloud.dlp_v2.GetDeidentifyTemplateRequest( - name=name, - ) - - return dlp_client.get_deidentify_template(request=request, retry=retry.retry_policy) - - -def inspect_content(inspect_template_path, deidentify_template_path, input): - inspect_templates = get_inspect_templates_from_name(inspect_template_path) - deidentify_template = get_deidentify_templates_from_name(deidentify_template_path) - - # Construct item - item = {"value": input} - - # Call the API - response = dlp_client.deidentify_content( - request={ - "parent": parent, - "deidentify_config": deidentify_template.deidentify_config, - "inspect_config": inspect_templates.inspect_config, - "item": item, - }, - retry=retry.retry_policy, - ) - - # Print out the results. - print(response.item.value) - return response.item.value - +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging + +import google.cloud.dlp + +from . import retry + +# Convert the project id into a full resource id. +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "NULL") +parent = f"projects/{GCP_PROJECT_ID}" + +# Instantiate a dlp client. +dlp_client = google.cloud.dlp_v2.DlpServiceClient() + +logging.basicConfig( + level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def is_dlp_api_enabled(): + if parent == "NULL": + return False + # Check if the DLP API is enabled + try: + dlp_client.list_info_types( + request={"parent": "en-US"}, retry=retry.retry_policy + ) + return True + except Exception as e: + print(f"Error: {e}") + return False + + +def list_inspect_templates_from_parent(): + try: + # Initialize request argument(s) + request = google.cloud.dlp_v2.ListInspectTemplatesRequest( + parent=parent, + ) + + # Make the request + page_result = dlp_client.list_inspect_templates( + request=request, retry=retry.retry_policy + ) + + name_list = [] + # Handle the response + for response in page_result: + name_list.append(response.name) + return name_list + except Exception as e: + logging.error(e) + raise e + + +def get_inspect_templates_from_name(name): + try: + request = google.cloud.dlp_v2.GetInspectTemplateRequest( + name=name, + ) + + return dlp_client.get_inspect_template(request=request) + except Exception as e: + logging.error(e) + raise e + + +def list_deidentify_templates_from_parent(): + try: + # Initialize request argument(s) + request = google.cloud.dlp_v2.ListDeidentifyTemplatesRequest( + parent=parent, + ) + + # Make the request + page_result = dlp_client.list_deidentify_templates(request=request) + + name_list = [] + # Handle the response + for response in page_result: + name_list.append(response.name) + return name_list + + except Exception as e: + logging.error(e) + raise e + + +def get_deidentify_templates_from_name(name): + try: + request = google.cloud.dlp_v2.GetDeidentifyTemplateRequest( + name=name, + ) + + return dlp_client.get_deidentify_template( + request=request, retry=retry.retry_policy + ) + except Exception as e: + logging.error(e) + raise e + + +def inspect_content(inspect_template_path, deidentify_template_path, input): + try: + inspect_templates = get_inspect_templates_from_name(inspect_template_path) + deidentify_template = get_deidentify_templates_from_name( + deidentify_template_path + ) + + # Construct item + item = {"value": input} + + # Call the API + response = dlp_client.deidentify_content( + request={ + "parent": parent, + "deidentify_config": deidentify_template.deidentify_config, + "inspect_config": inspect_templates.inspect_config, + "item": item, + }, + retry=retry.retry_policy, + ) + + # Print out the results. + print(response.item.value) + return response.item.value + except Exception as e: + logging.error(e) + raise e diff --git a/applications/rag/frontend/container/application/rai/nlp_filter.py b/applications/rag/frontend/container/application/rai/nlp_filter.py index 0fbf2688e..3b384a910 100644 --- a/applications/rag/frontend/container/application/rai/nlp_filter.py +++ b/applications/rag/frontend/container/application/rai/nlp_filter.py @@ -1,66 +1,81 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import google.cloud.language_v1 as language -from . import retry - -# Convert the project id into a full resource id. -parent = os.environ.get("PROJECT_ID", "NULL") - -# Instantiate a nlp client. -nature_language_client = language.LanguageServiceClient() - - -def is_nlp_api_enabled(): - if parent == "NULL": - return False - # Check if the DLP API is enabled - try: - sum_moderation_confidences("test") - return True - except Exception as e: - print(f"Error: {e}") - return False - - -def sum_moderation_confidences(text): - document = language.types.Document( - content=text, type_=language.types.Document.Type.PLAIN_TEXT - ) - - request = language.ModerateTextRequest( - document=document, - ) - # Detects the sentiment of the text - response = nature_language_client.moderate_text( - request=request, retry=retry.retry_policy - ) - print(f"get response: {response}") - # Parse response and sum the confidences of moderation, the categories are from https://cloud.google.com/natural-language/docs/moderating-text - largest_confidence = 0.0 - excluding_names = ["Health", "Politics", "Finance", "Legal"] - for category in response.moderation_categories: - if category.name in excluding_names: - continue - if category.confidence > largest_confidence: - largest_confidence = category.confidence - - print(f"largest confidence is: {largest_confidence}") - return int(largest_confidence * 100) - - -def is_content_inappropriate(text, nlp_filter_level): - return sum_moderation_confidences(text) > (100 - int(nlp_filter_level)) - +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging + +import google.cloud.language_v1 as language + +from . import retry + +# Convert the project id into a full resource id. +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "NULL") +parent = f"projects/{GCP_PROJECT_ID}" + +# Instantiate a nlp client. +nature_language_client = language.LanguageServiceClient() + +logging.basicConfig( + level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def is_nlp_api_enabled(): + if parent == "NULL": + return False + # Check if the DLP API is enabled + try: + sum_moderation_confidences("test") + return True + except Exception as e: + print(f"Error: {e}") + return False + + +def sum_moderation_confidences(text): + try: + document = language.types.Document( + content=text, type_=language.types.Document.Type.PLAIN_TEXT + ) + + request = language.ModerateTextRequest( + document=document, + ) + # Detects the sentiment of the text + response = nature_language_client.moderate_text( + request=request, retry=retry.retry_policy + ) + print(f"get response: {response}") + # Parse response and sum the confidences of moderation, the categories are from https://cloud.google.com/natural-language/docs/moderating-text + largest_confidence = 0.0 + excluding_names = ["Health", "Politics", "Finance", "Legal"] + for category in response.moderation_categories: + if category.name in excluding_names: + continue + if category.confidence > largest_confidence: + largest_confidence = category.confidence + + print(f"largest confidence is: {largest_confidence}") + return int(largest_confidence * 100) + except Exception as e: + logging.error(e) + raise e + + +def is_content_inappropriate(text, nlp_filter_level): + try: + return sum_moderation_confidences(text) > (100 - int(nlp_filter_level)) + except Exception as e: + logging.error(e) + raise e diff --git a/applications/rag/frontend/container/application/static/script.js b/applications/rag/frontend/container/application/static/script.js index 3c43c8588..6fec15f99 100644 --- a/applications/rag/frontend/container/application/static/script.js +++ b/applications/rag/frontend/container/application/static/script.js @@ -1,247 +1,292 @@ -/* Copyright 2024 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -function onReady() { - autoResizeTextarea() - populateDropdowns() - updateNLPValue() - - document.getElementById("prompt").addEventListener("keydown", e => { - if (e.key === "Enter" && !e.shiftKey) { - e.preventDefault(); - e.target.form.requestSubmit(); - } - }); - - - // Handle the chat form submission - document.getElementById("form").addEventListener("submit", function(e) { - e.preventDefault(); - - var promptInput = document.getElementById("prompt"); - var prompt = promptInput.value; - if (prompt === "") { - return; - } - promptInput.value = ""; - - var chatEl = document.getElementById("chat"); - var promptEl = Object.assign(document.createElement("p"), {classList: ["prompt"]}); - promptEl.textContent = prompt; - chatEl.appendChild(promptEl); - - var responseEl = Object.assign(document.createElement("p"), {classList: ["response"]}); - chatEl.appendChild(responseEl); - chatEl.scrollTop = chatEl.scrollHeight; // Scroll to bottom - enableForm(false); - - // Collect filter data - let data = { - prompt: prompt, - } - - if (document.getElementById('toggle-nlp-filter-section').checked) { - data.nlpFilterLevel = document.getElementById("nlp-range").value - } - - if (document.getElementById('toggle-dlp-filter-section').checked) { - data.inspectTemplate = document.getElementById('inspect-template-dropdown').value; - data.deidentifyTemplate = document.getElementById('deidentify-template-dropdown').value; - } - var body = JSON.stringify(data) - - - // Send data to the server - fetch("/prompt", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: body - }).then(response => { - if (!response.ok) { - return response.json().then(errorData => { - throw new Error(errorData.errorMessage); - }); - } - return response.json(); - }).then(data => { - var content = data.response.text - if (data.response.warnings && data.response.warnings.length > 0) { - responseEl.classList.replace("response", "warning"); - content += "\n\nWarning: " + data.response.warnings.join("\n") + "\n" - } - responseEl.textContent = content; - }).catch(err => { - responseEl.classList.replace("response", "error"); - responseEl.textContent = err.message; - }).finally(() => enableForm(true)); - }); - - document.getElementById("toggle-dlp-filter-section").addEventListener("change", function() { - fetchDLPEnabled() - var inspectDropdown = document.getElementById('inspect-template-dropdown'); - var deidentifyDropdown = document.getElementById('deidentify-template-dropdown'); - - // Check the Inspect Template Dropdown - if (inspectDropdown.options.length <= 0) { - inspectDropdown.style.display = 'none'; // Hide Dropdown - document.getElementById('inspect-template-msg').style.display = 'block'; // Show Message - } else { - inspectDropdown.style.display = 'block'; // Show Dropdown - document.getElementById('inspect-template-msg').style.display = 'none'; // Hide Message - } - - // Check the De-identify Template Dropdown - if (deidentifyDropdown.options.length <= 0) { - deidentifyDropdown.style.display = 'none'; // Hide Dropdown - document.getElementById('deidentify-template-msg').style.display = 'block'; // Show Message - } else { - deidentifyDropdown.style.display = 'block'; // Show Dropdown - document.getElementById('deidentify-template-msg').style.display = 'none'; // Hide Message - } - }); - - document.getElementById("toggle-nlp-filter-section").addEventListener("change", function() { - fetchNLPEnabled() - }); -} -if (document.readyState != "loading") onReady(); -else document.addEventListener("DOMContentLoaded", onReady); - - -function enableForm(enabled) { - var promptEl = document.getElementById("prompt"); - promptEl.toggleAttribute("disabled", !enabled); - if (enabled) setTimeout(() => promptEl.focus(), 0); - - var submitEl = document.getElementById("submit"); - submitEl.toggleAttribute("disabled", !enabled); - submitEl.textContent = enabled ? "Submit" : "..."; -} - -function autoResizeTextarea() { - var textarea = document.getElementById('prompt'); - textarea.addEventListener('input', function() { - this.style.height = 'auto'; - this.style.height = this.scrollHeight + 'px'; - }); -} - -// Function to handle the visibility of filter section -function toggleNlpFilterSection(nlpEnabled) { - var filterOptions = document.getElementById("nlp-filter-section"); - var nlpCheckbox = document.getElementById('toggle-nlp-filter-section'); - - if (nlpEnabled && nlpCheckbox.checked) { - filterOptions.style.display = "block"; - } else { - filterOptions.style.display = "none"; - } -} - -function updateNLPValue() { - const rangeInput = document.getElementById('nlp-range'); - const valueDisplay = document.getElementById('nlp-value'); - - // Function to update the slider's display value and color - const updateSliderAppearance = (value) => { - // Update the display text - valueDisplay.textContent = value; - - // Determine the color based on the value - let color; - if (value <= 25) { - color = '#4285F4'; // Blue - } else if (value <= 50) { - color = '#34A853'; // Green - } else if (value <= 75) { - color = '#FBBC05'; // Yellow - } else { - color = '#EA4335'; // Red - } - - // Apply the color to the slider through a gradient - // This gradient visually fills the track up to the thumb's current position - const percentage = (value - rangeInput.min) / (rangeInput.max - rangeInput.min) * 100; - rangeInput.style.background = `linear-gradient(90deg, ${color} ${percentage}%, #ddd ${percentage}%)`; - rangeInput.style.setProperty('--thumb-color', color); - }; - - // Initialize the slider's appearance - updateSliderAppearance(rangeInput.value); - - // Update slider's appearance whenever its value changes - rangeInput.addEventListener('input', (event) => { - updateSliderAppearance(event.target.value); - }); -} - -function fetchNLPEnabled() { - fetch('/get_nlp_status') - .then(response => response.json()) - .then(data => { - var nlpEnabled = data.nlpEnabled; - - toggleNlpFilterSection(nlpEnabled); - }) - .catch(error => console.error('Error fetching NLP status:', error)) -} - -// Function to handle the visibility of filter section -function toggleDLPFilterSection(dlpEnabled) { - var filterOptions = document.getElementById("dlp-filter-section"); - var dlpCheckbox = document.getElementById('toggle-dlp-filter-section'); - if (dlpEnabled && dlpCheckbox.checked) { - filterOptions.style.display = "block"; - } else { - filterOptions.style.display = "none"; - } -} - - -function fetchDLPEnabled() { - fetch('/get_dlp_status') - .then(response => response.json()) - .then(data => { - var dlpEnabled = data.dlpEnabled; - - toggleDLPFilterSection(dlpEnabled); - }) - .catch(error => console.error('Error fetching DLP status:', error)) -} - -// Function to populate dropdowns -function populateDropdowns() { - fetch('/get_inspect_templates') - .then(response => response.json()) - .then(data => { - const inspectDropdown = document.getElementById('inspect-template-dropdown'); - data.forEach(template => { - let option = new Option(template, template); - inspectDropdown.add(option); - }); - }) - .catch(error => console.error('Error loading inspect templates:', error)); - - fetch('/get_deidentify_templates') - .then(response => response.json()) - .then(data => { - const deidentifyDropdown = document.getElementById('deidentify-template-dropdown'); - data.forEach(template => { - let option = new Option(template, template); - deidentifyDropdown.add(option); - }); - }) - .catch(error => console.error('Error loading deidentify templates:', error)); -} \ No newline at end of file +/* Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +function onReady() { + autoResizeTextarea(); + populateDropdowns(); + updateNLPValue(); + loadPreviousMessages(); + + document.getElementById("prompt").addEventListener("keydown", (e) => { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + e.target.form.requestSubmit(); + } + }); + + // Handle the chat form submission + document.getElementById("form").addEventListener("submit", function (e) { + e.preventDefault(); + + var promptInput = document.getElementById("prompt"); + var prompt = promptInput.value; + if (prompt === "") { + return; + } + promptInput.value = ""; + + var chatEl = document.getElementById("chat"); + var promptEl = Object.assign(document.createElement("p"), { + classList: ["prompt"], + }); + promptEl.textContent = prompt; + chatEl.appendChild(promptEl); + + var responseEl = Object.assign(document.createElement("p"), { + classList: ["response"], + }); + chatEl.appendChild(responseEl); + chatEl.scrollTop = chatEl.scrollHeight; // Scroll to bottom + enableForm(false); + + // Collect filter data + let data = { + prompt: prompt, + }; + + if (document.getElementById("toggle-nlp-filter-section").checked) { + data.nlpFilterLevel = document.getElementById("nlp-range").value; + } + + if (document.getElementById("toggle-dlp-filter-section").checked) { + data.inspectTemplate = document.getElementById( + "inspect-template-dropdown" + ).value; + data.deidentifyTemplate = document.getElementById( + "deidentify-template-dropdown" + ).value; + } + var body = JSON.stringify(data); + + // Send data to the server + fetch("/prompt", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: body, + }) + .then((response) => { + if (!response.ok) { + return response.json().then((errorData) => { + throw new Error(errorData.errorMessage); + }); + } + return response.json(); + }) + .then((data) => { + var content = data.response.text; + if (data.response.warnings && data.response.warnings.length > 0) { + responseEl.classList.replace("response", "warning"); + content += "\n\nWarning: " + data.response.warnings.join("\n") + "\n"; + } + responseEl.textContent = content; + }) + .catch((err) => { + responseEl.classList.replace("response", "error"); + responseEl.textContent = err.message; + }) + .finally(() => enableForm(true)); + }); + + document + .getElementById("toggle-dlp-filter-section") + .addEventListener("change", function () { + fetchDLPEnabled(); + var inspectDropdown = document.getElementById( + "inspect-template-dropdown" + ); + var deidentifyDropdown = document.getElementById( + "deidentify-template-dropdown" + ); + + // Check the Inspect Template Dropdown + if (inspectDropdown.options.length <= 0) { + inspectDropdown.style.display = "none"; // Hide Dropdown + document.getElementById("inspect-template-msg").style.display = "block"; // Show Message + } else { + inspectDropdown.style.display = "block"; // Show Dropdown + document.getElementById("inspect-template-msg").style.display = "none"; // Hide Message + } + + // Check the De-identify Template Dropdown + if (deidentifyDropdown.options.length <= 0) { + deidentifyDropdown.style.display = "none"; // Hide Dropdown + document.getElementById("deidentify-template-msg").style.display = + "block"; // Show Message + } else { + deidentifyDropdown.style.display = "block"; // Show Dropdown + document.getElementById("deidentify-template-msg").style.display = + "none"; // Hide Message + } + }); + + document + .getElementById("toggle-nlp-filter-section") + .addEventListener("change", function () { + fetchNLPEnabled(); + }); +} +if (document.readyState != "loading") onReady(); +else document.addEventListener("DOMContentLoaded", onReady); + +function enableForm(enabled) { + var promptEl = document.getElementById("prompt"); + promptEl.toggleAttribute("disabled", !enabled); + if (enabled) setTimeout(() => promptEl.focus(), 0); + + var submitEl = document.getElementById("submit"); + submitEl.toggleAttribute("disabled", !enabled); + submitEl.textContent = enabled ? "Submit" : "..."; +} + +function autoResizeTextarea() { + var textarea = document.getElementById("prompt"); + textarea.addEventListener("input", function () { + this.style.height = "auto"; + this.style.height = this.scrollHeight + "px"; + }); +} + +// Function to handle the visibility of filter section +function toggleNlpFilterSection(nlpEnabled) { + var filterOptions = document.getElementById("nlp-filter-section"); + var nlpCheckbox = document.getElementById("toggle-nlp-filter-section"); + + if (nlpEnabled && nlpCheckbox.checked) { + filterOptions.style.display = "block"; + } else { + filterOptions.style.display = "none"; + } +} + +function updateNLPValue() { + const rangeInput = document.getElementById("nlp-range"); + const valueDisplay = document.getElementById("nlp-value"); + + // Function to update the slider's display value and color + const updateSliderAppearance = (value) => { + // Update the display text + valueDisplay.textContent = value; + + // Determine the color based on the value + let color; + if (value <= 25) { + color = "#4285F4"; // Blue + } else if (value <= 50) { + color = "#34A853"; // Green + } else if (value <= 75) { + color = "#FBBC05"; // Yellow + } else { + color = "#EA4335"; // Red + } + + // Apply the color to the slider through a gradient + // This gradient visually fills the track up to the thumb's current position + const percentage = + ((value - rangeInput.min) / (rangeInput.max - rangeInput.min)) * 100; + rangeInput.style.background = `linear-gradient(90deg, ${color} ${percentage}%, #ddd ${percentage}%)`; + rangeInput.style.setProperty("--thumb-color", color); + }; + + // Initialize the slider's appearance + updateSliderAppearance(rangeInput.value); + + // Update slider's appearance whenever its value changes + rangeInput.addEventListener("input", (event) => { + updateSliderAppearance(event.target.value); + }); +} + +function fetchNLPEnabled() { + fetch("/get_nlp_status") + .then((response) => response.json()) + .then((data) => { + var nlpEnabled = data.nlpEnabled; + + toggleNlpFilterSection(nlpEnabled); + }) + .catch((error) => console.error("Error fetching NLP status:", error)); +} + +// Function to handle the visibility of filter section +function toggleDLPFilterSection(dlpEnabled) { + var filterOptions = document.getElementById("dlp-filter-section"); + var dlpCheckbox = document.getElementById("toggle-dlp-filter-section"); + if (dlpEnabled && dlpCheckbox.checked) { + filterOptions.style.display = "block"; + } else { + filterOptions.style.display = "none"; + } +} + +function fetchDLPEnabled() { + fetch("/get_dlp_status") + .then((response) => response.json()) + .then((data) => { + var dlpEnabled = data.dlpEnabled; + + toggleDLPFilterSection(dlpEnabled); + }) + .catch((error) => console.error("Error fetching DLP status:", error)); +} + +// Function to populate dropdowns +function populateDropdowns() { + fetch("/get_inspect_templates") + .then((response) => response.json()) + .then((data) => { + const inspectDropdown = document.getElementById( + "inspect-template-dropdown" + ); + data.forEach((template) => { + let option = new Option(template, template); + inspectDropdown.add(option); + }); + }) + .catch((error) => console.error("Error loading inspect templates:", error)); + + fetch("/get_deidentify_templates") + .then((response) => response.json()) + .then((data) => { + const deidentifyDropdown = document.getElementById( + "deidentify-template-dropdown" + ); + data.forEach((template) => { + let option = new Option(template, template); + deidentifyDropdown.add(option); + }); + }) + .catch((error) => + console.error("Error loading deidentify templates:", error) + ); +} + +function loadPreviousMessages() { + fetch("/get_chat_history") + .then((response) => response.json()) + .then((data) => { + const { history_messages } = data; + var chatEl = document.getElementById("chat"); + history_messages.map(({ prompt, message }) => { + var promptEl = Object.assign(document.createElement("p"), { + classList: ["previous_message"], + }); + promptEl.textContent = `${prompt} : ${message}`; + chatEl.appendChild(promptEl); + }); + }) + .catch((error) => + console.error("Error getting previous chat messages:", error) + ); +} diff --git a/applications/rag/frontend/container/application/static/styles.css b/applications/rag/frontend/container/application/static/styles.css index 4d6ae030a..d2bf11352 100644 --- a/applications/rag/frontend/container/application/static/styles.css +++ b/applications/rag/frontend/container/application/static/styles.css @@ -1,199 +1,209 @@ -/* Copyright 2024 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -/* styles.css */ -body { - font-family: Arial, sans-serif; - text-align: center; - - max-width: 2000px; - margin: 0 auto; -} - -.content-container { - display: flex; - max-width: 2000px; - margin: 0 auto; -} - -#chat-and-form { - width: 70%; -} - -#chat { - height: 80vh; - padding: 10px; - margin-top: 20px; - margin-bottom: 10px; - display: block; - - text-align: left; - overflow-y: scroll; - - border: 1px solid #222; - border-radius: 2px; -} - -#chat p { - white-space: pre-wrap; - - padding: 8px; - border-radius: 4px; -} - -#chat p:first-child { - margin-top: 0; -} - -p.prompt { - background-color: rgba(50, 255, 0, 0.1); -} - -p.prompt::before { - content: 'Prompt: '; - font-weight: bold; -} - -p.instruct { - background-color: rgba(255, 255, 0, 0.2); -} - -p.instruct::before { - content: 'Instructions: '; - font-weight: bold; -} - -p.response { - background-color: rgba(0, 134, 255, 0.1); -} - -p.response::before { - content: 'Response: '; - font-weight: bold; -} - -p.warning { - background-color: rgba(255, 229, 100, 0.1); -} - -p.error::before { - content: 'Warning: '; - font-weight: bold; -} - -p.error { - background-color: rgba(255, 0, 0, 0.1); -} - -p.error::before { - content: 'Error: '; - font-weight: bold; -} - -form { - display: flex; - width: 100%; -} - -#prompt { - flex: 1; - padding: 10px; - margin-right: 10px; - border-radius: 2px; -} - -input[type="submit"] { - padding: 10px 20px; - background-color: rgb(76,175,80); - color: white; - border: none; - border-radius: 2px; -} - -input[type="submit"]:disabled { - background-color: rgba(0, 0, 0, 0.2); -} - -/* Filter Section Styles */ -#filter-section { - width: 30%; - margin-right: 2%; /* Space between filter section and the next section, adjust as needed */ - border: 1px solid #ccc; /* Optional: adds a border around the filter section */ - padding: 20px; /* Adds some space inside the filter section */ - box-shadow: 0 2px 4px rgba(0,0,0,0.1); /* Optional: adds a slight shadow for depth */ -} - -#dlp-filter-section, -#nlp-filter-section { - margin-top: 20px; -} - -select { - width: 100%; - padding: 10px; - margin-top: 10px; - border: 1px solid #ccc; /* Adds a border to the dropdowns */ - border-radius: 4px; /* Rounds the corners of the dropdowns */ -} - -label { - display: block; - margin-top: 15px; - font-weight: bold; -} - -input[type="checkbox"] { - margin-right: 10px; /* Space between checkbox and its label */ -} - -#nlp-range { - width: 100%; - margin-top: 10px; /* Adjusted for consistency */ -} - -input[type="range"] { - -webkit-appearance: none; - appearance: none; - width: 100%; - height: 8px; - background: #ddd; - outline: none; - opacity: 1; - -webkit-transition: .2s; - transition: opacity .2s; -} - -input[type="range"]::-webkit-slider-thumb { - -webkit-appearance: none; - appearance: none; - width: 25px; - height: 25px; - background: var(--thumb-color); - cursor: pointer; -} - -input[type="range"]::-moz-range-thumb { - width: 25px; - height: 25px; - background: var(--thumb-color); - cursor: pointer; -} - -input[type="range"]::-ms-thumb { - width: 25px; - height: 25px; - background: var(--thumb-color); - cursor: pointer; +/* Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* styles.css */ +body { + font-family: Arial, sans-serif; + text-align: center; + + max-width: 2000px; + margin: 0 auto; +} + +.content-container { + display: flex; + max-width: 2000px; + margin: 0 auto; +} + +#chat-and-form { + width: 70%; +} + +#chat { + height: 80vh; + padding: 10px; + margin-top: 20px; + margin-bottom: 10px; + display: block; + + text-align: left; + overflow-y: scroll; + + border: 1px solid #222; + border-radius: 2px; +} + +#chat p { + white-space: pre-wrap; + + padding: 8px; + border-radius: 4px; +} + +#chat p:first-child { + margin-top: 0; +} + +p.prompt { + background-color: rgba(50, 255, 0, 0.1); +} + +p.prompt::before { + content: 'Prompt: '; + font-weight: bold; +} + +p.instruct{ + background-color: rgba(255, 255, 0, 0.2); +} + + +p.instruct::before { + content: 'Instructions: '; + font-weight: bold; +} + +p.previous_message { + background-color: rgba(73, 240, 240, 0.753); +} + +p.previous_message::before { + content: 'Previous message: '; + font-weight: bold; +} + +p.response { + background-color: rgba(0, 134, 255, 0.1); +} + +p.response::before { + content: 'Response: '; + font-weight: bold; +} + +p.warning { + background-color: rgba(255, 229, 100, 0.1); +} + +p.error::before { + content: 'Warning: '; + font-weight: bold; +} + +p.error { + background-color: rgba(255, 0, 0, 0.1); +} + +p.error::before { + content: 'Error: '; + font-weight: bold; +} + +form { + display: flex; + width: 100%; +} + +#prompt { + flex: 1; + padding: 10px; + margin-right: 10px; + border-radius: 2px; +} + +input[type="submit"] { + padding: 10px 20px; + background-color: rgb(76,175,80); + color: white; + border: none; + border-radius: 2px; +} + +input[type="submit"]:disabled { + background-color: rgba(0, 0, 0, 0.2); +} + +/* Filter Section Styles */ +#filter-section { + width: 30%; + margin-right: 2%; /* Space between filter section and the next section, adjust as needed */ + border: 1px solid #ccc; /* Optional: adds a border around the filter section */ + padding: 20px; /* Adds some space inside the filter section */ + box-shadow: 0 2px 4px rgba(0,0,0,0.1); /* Optional: adds a slight shadow for depth */ +} + +#dlp-filter-section, +#nlp-filter-section { + margin-top: 20px; +} + +select { + width: 100%; + padding: 10px; + margin-top: 10px; + border: 1px solid #ccc; /* Adds a border to the dropdowns */ + border-radius: 4px; /* Rounds the corners of the dropdowns */ +} + +label { + display: block; + margin-top: 15px; + font-weight: bold; +} + +input[type="checkbox"] { + margin-right: 10px; /* Space between checkbox and its label */ +} + +#nlp-range { + width: 100%; + margin-top: 10px; /* Adjusted for consistency */ +} + +input[type="range"] { + -webkit-appearance: none; + appearance: none; + width: 100%; + height: 8px; + background: #ddd; + outline: none; + opacity: 1; + -webkit-transition: .2s; + transition: opacity .2s; +} + +input[type="range"]::-webkit-slider-thumb { + -webkit-appearance: none; + appearance: none; + width: 25px; + height: 25px; + background: var(--thumb-color); + cursor: pointer; +} + +input[type="range"]::-moz-range-thumb { + width: 25px; + height: 25px; + background: var(--thumb-color); + cursor: pointer; +} + +input[type="range"]::-ms-thumb { + width: 25px; + height: 25px; + background: var(--thumb-color); + cursor: pointer; } \ No newline at end of file diff --git a/applications/rag/frontend/container/main.py b/applications/rag/frontend/container/main.py index 32a880f89..987555552 100644 --- a/applications/rag/frontend/container/main.py +++ b/applications/rag/frontend/container/main.py @@ -1,186 +1,191 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import uuid -import traceback -import logging as log - -import google.cloud.logging as logging - -from flask import render_template, request, jsonify, session, redirect, url_for -from datetime import datetime, timedelta, timezone - -from application import create_app -from application.rai import ( - dlp_filter, -) # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp -from application.rai import ( - nlp_filter, -) # https://cloud.google.com/natural-language/docs/moderating-text - -from application.rag_langchain.rag_chain import ( - clear_chat_history, - create_chain, - take_chat_turn, - get_chat_history, -) - -log.basicConfig(level=log.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - -SESSION_TIMEOUT_MINUTES = 10 - -# Setup logging -logging_client = logging.Client() -logging_client.setup_logging() - -app = create_app() - -# Create llm chain -llm_chain = create_chain() - - -@app.before_request -def check_new_session(): - if "session_id" not in session: - # instantiate a new session using a generated UUID - session_id = str(uuid.uuid4()) - session["session_id"] = session_id - - -@app.before_request -def check_inactivity(): - # Inactivity cleanup - if "last_activity" in session: - time_elapsed = datetime.now(timezone.utc) - session["last_activity"] - - if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): - print("Session inactive: Cleaning up resources...") - session_id = session["session_id"] - # TODO: implement garbage collection process for idle sessions that have timed out - clear_chat_history(session_id) - session.clear() - - # Always update the 'last_activity' data - session["last_activity"] = datetime.now(timezone.utc) - - -@app.route("/get_nlp_status", methods=["GET"]) -def get_nlp_status(): - nlp_enabled = nlp_filter.is_nlp_api_enabled() - return jsonify({"nlpEnabled": nlp_enabled}) - - -@app.route("/get_dlp_status", methods=["GET"]) -def get_dlp_status(): - dlp_enabled = dlp_filter.is_dlp_api_enabled() - return jsonify({"dlpEnabled": dlp_enabled}) - - -@app.route("/get_inspect_templates") -def get_inspect_templates(): - return jsonify(dlp_filter.list_inspect_templates_from_parent()) - - -@app.route("/get_deidentify_templates") -def get_deidentify_templates(): - return jsonify(dlp_filter.list_deidentify_templates_from_parent()) - - -@app.route("/get_chat_history", methods=["GET"]) -def get_chat_history_endpoint(): - try: - session_id = session.get("session_id") - history = get_chat_history(session_id) - log.info(history) - - response = jsonify({"history_messages": []}) - response.status_code = 200 - - return response - - except Exception as err: - log.info(f"exception from llm: {err}") - traceback.print_exc() - error_traceback = traceback.format_exc() - response = jsonify( - { - "error": "An error occurred", - "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", - } - ) - response.status_code = 500 - return response - - -@app.route("/") -def index(): - return render_template("index.html") - - -@app.route("/prompt", methods=["POST"]) -def handlePrompt(): - # TODO on page refresh, load chat history into browser. - session["last_activity"] = datetime.now(timezone.utc) - data = request.get_json() - warnings = [] - - if "prompt" not in data: - return "missing required prompt", 400 - - user_prompt = data["prompt"] - log.info(f"handle user prompt: {user_prompt}") - - try: - session_id = session.get("session_id") - if not session_id: - return redirect(url_for("index")) - response = {} - result = take_chat_turn(llm_chain, session_id, user_prompt) - log.info("After the result") - log.info(result) - response["text"] = result - - # TODO: enable filtering in chain - if "nlpFilterLevel" in data: - if nlp_filter.is_content_inappropriate( - response["text"], data["nlpFilterLevel"] - ): - response["text"] = "The response is deemed inappropriate for display." - return {"response": response} - if "inspectTemplate" in data and "deidentifyTemplate" in data: - inspect_template_path = data["inspectTemplate"] - deidentify_template_path = data["deidentifyTemplate"] - if inspect_template_path != "" and deidentify_template_path != "": - # filter the output with inspect setting. Customer can pick any category from https://cloud.google.com/dlp/docs/concepts-infotypes - response["text"] = dlp_filter.inspect_content( - inspect_template_path, deidentify_template_path, response["text"] - ) - - if warnings: - response["warnings"] = warnings - log.info(f"response: {response}") - return {"response": response} - except Exception as err: - log.info(f"exception from llm: {err}") - traceback.print_exc() - error_traceback = traceback.format_exc() - response = jsonify( - { - "warnings": warnings, - "error": "An error occurred", - "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", - } - ) - response.status_code = 500 - return response +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +import traceback +import logging as log + +import google.cloud.logging as logging + +from flask import render_template, request, jsonify, session, redirect, url_for +from datetime import datetime, timedelta, timezone + +from application import create_app +from application.rai import ( + dlp_filter, +) # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp +from application.rai import ( + nlp_filter, +) # https://cloud.google.com/natural-language/docs/moderating-text + +from application.rag_langchain.rag_chain import ( + clear_chat_history, + create_chain, + take_chat_turn, + get_chat_history, +) + +log.basicConfig(level=log.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + +SESSION_TIMEOUT_MINUTES = 20 + +# Setup logging +logging_client = logging.Client() +logging_client.setup_logging() + +app = create_app() + +# Create llm chain +llm_chain = create_chain() + + +@app.before_request +def check_new_session(): + if "session_id" not in session: + # instantiate a new session using a generated UUID + session_id = str(uuid.uuid4()) + session["session_id"] = session_id + + +@app.before_request +def check_inactivity(): + # Inactivity cleanup + if "last_activity" in session: + time_elapsed = datetime.now(timezone.utc) - session["last_activity"] + + if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): + print("Session inactive: Cleaning up resources...") + session_id = session["session_id"] + # TODO: implement garbage collection process for idle sessions that have timed out + clear_chat_history(session_id) + session.clear() + + # Always update the 'last_activity' data + session["last_activity"] = datetime.now(timezone.utc) + + +@app.route("/get_nlp_status", methods=["GET"]) +def get_nlp_status(): + nlp_enabled = nlp_filter.is_nlp_api_enabled() + return jsonify({"nlpEnabled": nlp_enabled}) + + +@app.route("/get_dlp_status", methods=["GET"]) +def get_dlp_status(): + dlp_enabled = dlp_filter.is_dlp_api_enabled() + return jsonify({"dlpEnabled": dlp_enabled}) + + +@app.route("/get_inspect_templates") +def get_inspect_templates(): + return jsonify(dlp_filter.list_inspect_templates_from_parent()) + + +@app.route("/get_deidentify_templates") +def get_deidentify_templates(): + return jsonify(dlp_filter.list_deidentify_templates_from_parent()) + + +@app.route("/get_chat_history", methods=["GET"]) +def get_chat_history_endpoint(): + try: + session_id = session.get("session_id") + if not session_id: + return redirect(url_for("index")) + + history = get_chat_history(session_id) + + messages_response = [] + for message in history.messages: + data = {"prompt": message.type, "message": message.content} + messages_response.append(data) + + response = jsonify({"history_messages": messages_response}) + response.status_code = 200 + + return response + + except Exception as err: + log.info(f"exception from llm: {err}") + traceback.print_exc() + error_traceback = traceback.format_exc() + response = jsonify( + { + "error": "An error occurred", + "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", + } + ) + response.status_code = 500 + return response + + +@app.route("/") +def index(): + return render_template("index.html") + + +@app.route("/prompt", methods=["POST"]) +def handlePrompt(): + # TODO on page refresh, load chat history into browser. + session["last_activity"] = datetime.now(timezone.utc) + data = request.get_json() + warnings = [] + + if "prompt" not in data: + return "missing required prompt", 400 + + user_prompt = data["prompt"] + log.info(f"handle user prompt: {user_prompt}") + + try: + session_id = session.get("session_id") + if not session_id: + return redirect(url_for("index")) + response = {} + result = take_chat_turn(llm_chain, session_id, user_prompt) + response["text"] = result + + # TODO: enable filtering in chain + if "nlpFilterLevel" in data: + if nlp_filter.is_content_inappropriate( + response["text"], data["nlpFilterLevel"] + ): + response["text"] = "The response is deemed inappropriate for display." + return {"response": response} + if "inspectTemplate" in data and "deidentifyTemplate" in data: + inspect_template_path = data["inspectTemplate"] + deidentify_template_path = data["deidentifyTemplate"] + if inspect_template_path != "" and deidentify_template_path != "": + # filter the output with inspect setting. Customer can pick any category from https://cloud.google.com/dlp/docs/concepts-infotypes + response["text"] = dlp_filter.inspect_content( + inspect_template_path, deidentify_template_path, response["text"] + ) + + if warnings: + response["warnings"] = warnings + return {"response": response} + + except Exception as err: + log.info(f"exception from llm: {err}") + traceback.print_exc() + error_traceback = traceback.format_exc() + response = jsonify( + { + "warnings": warnings, + "error": "An error occurred", + "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", + } + ) + response.status_code = 500 + return response diff --git a/cloud_sql.py b/cloud_sql.py deleted file mode 100644 index 1c0da565b..000000000 --- a/cloud_sql.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import uuid -import logging - -from typing import List, Optional, Iterable, Any - -from langchain_core.vectorstores import VectorStore -from langchain_core.embeddings import Embeddings -from langchain_core.documents import Document -from langchain.text_splitter import RecursiveCharacterTextSplitter - -from langchain_google_cloud_sql_pg import PostgresVectorStore - - -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") -CHUNK_SIZE = 1000 -CHUNK_OVERLAP = 10 - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - - -class CloudSQLVectorStore(VectorStore): - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ): - raise NotImplementedError - - def __init__(self, embedding_provider, engine): - self.vector_store = PostgresVectorStore.create_sync( - engine=engine, - embedding_service=embedding_provider, - table_name=VECTOR_EMBEDDINGS_TABLE_NAME, - ) - self.splitter = RecursiveCharacterTextSplitter( - chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len - ) - self.embeddings_service = embedding_provider - - # TODO implement - def add_texts( - self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any - ) -> List[str]: - try: - splits = self.splitter.split_documents(texts) - ids = [str(uuid.uuid4()) for _ in range(len(splits))] - self.vector_store.add_documents(splits, ids) - except Exception as e: - logging.info(f"Error: {e}") - raise e - - # TODO implement similarity search with cosine similarity threshold - - def similarity_search( - self, query: dict, k: int = 4, **kwargs: Any - ) -> List[Document]: - try: - - query_input = query.get("input") - query_vector = self.embeddings_service.embed_query(query_input) - docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) - return docs - - except Exception as err: - raise Exception(f"General error: {err}") diff --git a/dlp_filter.py b/dlp_filter.py deleted file mode 100644 index ad69a30c3..000000000 --- a/dlp_filter.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import logging - -import google.cloud.dlp - -from . import retry - -# Convert the project id into a full resource id. -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "NULL") -parent = f"projects/{GCP_PROJECT_ID}" - -# Instantiate a dlp client. -dlp_client = google.cloud.dlp_v2.DlpServiceClient() - -logging.basicConfig( - level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s" -) - - -def is_dlp_api_enabled(): - if parent == "NULL": - return False - # Check if the DLP API is enabled - try: - dlp_client.list_info_types( - request={"parent": "en-US"}, retry=retry.retry_policy - ) - return True - except Exception as e: - print(f"Error: {e}") - return False - - -def list_inspect_templates_from_parent(): - try: - # Initialize request argument(s) - request = google.cloud.dlp_v2.ListInspectTemplatesRequest( - parent=parent, - ) - - # Make the request - page_result = dlp_client.list_inspect_templates( - request=request, retry=retry.retry_policy - ) - - name_list = [] - # Handle the response - for response in page_result: - name_list.append(response.name) - return name_list - except Exception as e: - logging.error(e) - raise e - - -def get_inspect_templates_from_name(name): - try: - request = google.cloud.dlp_v2.GetInspectTemplateRequest( - name=name, - ) - - return dlp_client.get_inspect_template(request=request) - except Exception as e: - logging.error(e) - raise e - - -def list_deidentify_templates_from_parent(): - try: - # Initialize request argument(s) - request = google.cloud.dlp_v2.ListDeidentifyTemplatesRequest( - parent=parent, - ) - - # Make the request - page_result = dlp_client.list_deidentify_templates(request=request) - - name_list = [] - # Handle the response - for response in page_result: - name_list.append(response.name) - return name_list - - except Exception as e: - logging.error(e) - raise e - - -def get_deidentify_templates_from_name(name): - try: - request = google.cloud.dlp_v2.GetDeidentifyTemplateRequest( - name=name, - ) - - return dlp_client.get_deidentify_template( - request=request, retry=retry.retry_policy - ) - except Exception as e: - logging.error(e) - raise e - - -def inspect_content(inspect_template_path, deidentify_template_path, input): - try: - inspect_templates = get_inspect_templates_from_name(inspect_template_path) - deidentify_template = get_deidentify_templates_from_name( - deidentify_template_path - ) - - # Construct item - item = {"value": input} - - # Call the API - response = dlp_client.deidentify_content( - request={ - "parent": parent, - "deidentify_config": deidentify_template.deidentify_config, - "inspect_config": inspect_templates.inspect_config, - "item": item, - }, - retry=retry.retry_policy, - ) - - # Print out the results. - print(response.item.value) - return response.item.value - except Exception as e: - logging.error(e) - raise e diff --git a/main.py b/main.py deleted file mode 100644 index 987555552..000000000 --- a/main.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import uuid -import traceback -import logging as log - -import google.cloud.logging as logging - -from flask import render_template, request, jsonify, session, redirect, url_for -from datetime import datetime, timedelta, timezone - -from application import create_app -from application.rai import ( - dlp_filter, -) # Google's Cloud Data Loss Prevention (DLP) API. https://cloud.google.com/security/products/dlp -from application.rai import ( - nlp_filter, -) # https://cloud.google.com/natural-language/docs/moderating-text - -from application.rag_langchain.rag_chain import ( - clear_chat_history, - create_chain, - take_chat_turn, - get_chat_history, -) - -log.basicConfig(level=log.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - -SESSION_TIMEOUT_MINUTES = 20 - -# Setup logging -logging_client = logging.Client() -logging_client.setup_logging() - -app = create_app() - -# Create llm chain -llm_chain = create_chain() - - -@app.before_request -def check_new_session(): - if "session_id" not in session: - # instantiate a new session using a generated UUID - session_id = str(uuid.uuid4()) - session["session_id"] = session_id - - -@app.before_request -def check_inactivity(): - # Inactivity cleanup - if "last_activity" in session: - time_elapsed = datetime.now(timezone.utc) - session["last_activity"] - - if time_elapsed > timedelta(minutes=SESSION_TIMEOUT_MINUTES): - print("Session inactive: Cleaning up resources...") - session_id = session["session_id"] - # TODO: implement garbage collection process for idle sessions that have timed out - clear_chat_history(session_id) - session.clear() - - # Always update the 'last_activity' data - session["last_activity"] = datetime.now(timezone.utc) - - -@app.route("/get_nlp_status", methods=["GET"]) -def get_nlp_status(): - nlp_enabled = nlp_filter.is_nlp_api_enabled() - return jsonify({"nlpEnabled": nlp_enabled}) - - -@app.route("/get_dlp_status", methods=["GET"]) -def get_dlp_status(): - dlp_enabled = dlp_filter.is_dlp_api_enabled() - return jsonify({"dlpEnabled": dlp_enabled}) - - -@app.route("/get_inspect_templates") -def get_inspect_templates(): - return jsonify(dlp_filter.list_inspect_templates_from_parent()) - - -@app.route("/get_deidentify_templates") -def get_deidentify_templates(): - return jsonify(dlp_filter.list_deidentify_templates_from_parent()) - - -@app.route("/get_chat_history", methods=["GET"]) -def get_chat_history_endpoint(): - try: - session_id = session.get("session_id") - if not session_id: - return redirect(url_for("index")) - - history = get_chat_history(session_id) - - messages_response = [] - for message in history.messages: - data = {"prompt": message.type, "message": message.content} - messages_response.append(data) - - response = jsonify({"history_messages": messages_response}) - response.status_code = 200 - - return response - - except Exception as err: - log.info(f"exception from llm: {err}") - traceback.print_exc() - error_traceback = traceback.format_exc() - response = jsonify( - { - "error": "An error occurred", - "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", - } - ) - response.status_code = 500 - return response - - -@app.route("/") -def index(): - return render_template("index.html") - - -@app.route("/prompt", methods=["POST"]) -def handlePrompt(): - # TODO on page refresh, load chat history into browser. - session["last_activity"] = datetime.now(timezone.utc) - data = request.get_json() - warnings = [] - - if "prompt" not in data: - return "missing required prompt", 400 - - user_prompt = data["prompt"] - log.info(f"handle user prompt: {user_prompt}") - - try: - session_id = session.get("session_id") - if not session_id: - return redirect(url_for("index")) - response = {} - result = take_chat_turn(llm_chain, session_id, user_prompt) - response["text"] = result - - # TODO: enable filtering in chain - if "nlpFilterLevel" in data: - if nlp_filter.is_content_inappropriate( - response["text"], data["nlpFilterLevel"] - ): - response["text"] = "The response is deemed inappropriate for display." - return {"response": response} - if "inspectTemplate" in data and "deidentifyTemplate" in data: - inspect_template_path = data["inspectTemplate"] - deidentify_template_path = data["deidentifyTemplate"] - if inspect_template_path != "" and deidentify_template_path != "": - # filter the output with inspect setting. Customer can pick any category from https://cloud.google.com/dlp/docs/concepts-infotypes - response["text"] = dlp_filter.inspect_content( - inspect_template_path, deidentify_template_path, response["text"] - ) - - if warnings: - response["warnings"] = warnings - return {"response": response} - - except Exception as err: - log.info(f"exception from llm: {err}") - traceback.print_exc() - error_traceback = traceback.format_exc() - response = jsonify( - { - "warnings": warnings, - "error": "An error occurred", - "errorMessage": f"Error: {err}\nTraceback:\n{error_traceback}", - } - ) - response.status_code = 500 - return response diff --git a/nlp_filter.py b/nlp_filter.py deleted file mode 100644 index 3b384a910..000000000 --- a/nlp_filter.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import logging - -import google.cloud.language_v1 as language - -from . import retry - -# Convert the project id into a full resource id. -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "NULL") -parent = f"projects/{GCP_PROJECT_ID}" - -# Instantiate a nlp client. -nature_language_client = language.LanguageServiceClient() - -logging.basicConfig( - level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s" -) - - -def is_nlp_api_enabled(): - if parent == "NULL": - return False - # Check if the DLP API is enabled - try: - sum_moderation_confidences("test") - return True - except Exception as e: - print(f"Error: {e}") - return False - - -def sum_moderation_confidences(text): - try: - document = language.types.Document( - content=text, type_=language.types.Document.Type.PLAIN_TEXT - ) - - request = language.ModerateTextRequest( - document=document, - ) - # Detects the sentiment of the text - response = nature_language_client.moderate_text( - request=request, retry=retry.retry_policy - ) - print(f"get response: {response}") - # Parse response and sum the confidences of moderation, the categories are from https://cloud.google.com/natural-language/docs/moderating-text - largest_confidence = 0.0 - excluding_names = ["Health", "Politics", "Finance", "Legal"] - for category in response.moderation_categories: - if category.name in excluding_names: - continue - if category.confidence > largest_confidence: - largest_confidence = category.confidence - - print(f"largest confidence is: {largest_confidence}") - return int(largest_confidence * 100) - except Exception as e: - logging.error(e) - raise e - - -def is_content_inappropriate(text, nlp_filter_level): - try: - return sum_moderation_confidences(text) > (100 - int(nlp_filter_level)) - except Exception as e: - logging.error(e) - raise e diff --git a/rag_chain.py b/rag_chain.py deleted file mode 100644 index f290021ff..000000000 --- a/rag_chain.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import RunnableParallel, RunnableLambda -from langchain_core.runnables.history import RunnableWithMessageHistory - -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_google_cloud_sql_pg import PostgresChatMessageHistory - - -from application.cloud_sql.cloud_sql import ( - CHAT_HISTORY_TABLE_NAME, - create_sync_postgres_engine, -) -from application.rag_langchain.huggingface_inference_model import ( - HuggingFaceCustomChatModel, -) -from application.vector_storages import CloudSQLVectorStore - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - -QUESTION = "input" -HISTORY = "chat_history" -CONTEXT = "context" - -SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings - -template_str = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. -Improve upon your previous answers using History, a list of messages. -Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. -Stick to the facts by basing your answers off of the Context provided. -Be brief in answering. -\n\n -Context: {context} -""" - -prompt = ChatPromptTemplate.from_messages( - [ - ("system", template_str), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] -) - -engine = create_sync_postgres_engine() - - -def get_chat_history(session_id: str) -> PostgresChatMessageHistory: - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - - logging.info( - f"Retrieving history for session {session_id} with {len(history.messages)}" - ) - return history - - -def clear_chat_history(session_id: str): - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - history.clear() - - -def create_chain() -> RunnableWithMessageHistory: - model = HuggingFaceCustomChatModel() - - langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CloudSQLVectorStore(langchain_embed, engine) - - retriever = vector_store.as_retriever() - - setup_and_retrieval = RunnableParallel( - { - "context": retriever, - QUESTION: RunnableLambda(lambda d: d[QUESTION]), - HISTORY: RunnableLambda(lambda d: d[HISTORY]), - } - ) - - chain = setup_and_retrieval | prompt | model - chain_with_history = RunnableWithMessageHistory( - chain, - get_chat_history, - input_messages_key=QUESTION, - history_messages_key=HISTORY, - output_messages_key="output", - ) - return chain_with_history - - -def take_chat_turn( - chain: RunnableWithMessageHistory, session_id: str, query_text: str -) -> str: - config = {"configurable": {"session_id": session_id}} - result = chain.invoke({"input": query_text}, config=config) - return result diff --git a/script.js b/script.js deleted file mode 100644 index 6fec15f99..000000000 --- a/script.js +++ /dev/null @@ -1,292 +0,0 @@ -/* Copyright 2024 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -function onReady() { - autoResizeTextarea(); - populateDropdowns(); - updateNLPValue(); - loadPreviousMessages(); - - document.getElementById("prompt").addEventListener("keydown", (e) => { - if (e.key === "Enter" && !e.shiftKey) { - e.preventDefault(); - e.target.form.requestSubmit(); - } - }); - - // Handle the chat form submission - document.getElementById("form").addEventListener("submit", function (e) { - e.preventDefault(); - - var promptInput = document.getElementById("prompt"); - var prompt = promptInput.value; - if (prompt === "") { - return; - } - promptInput.value = ""; - - var chatEl = document.getElementById("chat"); - var promptEl = Object.assign(document.createElement("p"), { - classList: ["prompt"], - }); - promptEl.textContent = prompt; - chatEl.appendChild(promptEl); - - var responseEl = Object.assign(document.createElement("p"), { - classList: ["response"], - }); - chatEl.appendChild(responseEl); - chatEl.scrollTop = chatEl.scrollHeight; // Scroll to bottom - enableForm(false); - - // Collect filter data - let data = { - prompt: prompt, - }; - - if (document.getElementById("toggle-nlp-filter-section").checked) { - data.nlpFilterLevel = document.getElementById("nlp-range").value; - } - - if (document.getElementById("toggle-dlp-filter-section").checked) { - data.inspectTemplate = document.getElementById( - "inspect-template-dropdown" - ).value; - data.deidentifyTemplate = document.getElementById( - "deidentify-template-dropdown" - ).value; - } - var body = JSON.stringify(data); - - // Send data to the server - fetch("/prompt", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: body, - }) - .then((response) => { - if (!response.ok) { - return response.json().then((errorData) => { - throw new Error(errorData.errorMessage); - }); - } - return response.json(); - }) - .then((data) => { - var content = data.response.text; - if (data.response.warnings && data.response.warnings.length > 0) { - responseEl.classList.replace("response", "warning"); - content += "\n\nWarning: " + data.response.warnings.join("\n") + "\n"; - } - responseEl.textContent = content; - }) - .catch((err) => { - responseEl.classList.replace("response", "error"); - responseEl.textContent = err.message; - }) - .finally(() => enableForm(true)); - }); - - document - .getElementById("toggle-dlp-filter-section") - .addEventListener("change", function () { - fetchDLPEnabled(); - var inspectDropdown = document.getElementById( - "inspect-template-dropdown" - ); - var deidentifyDropdown = document.getElementById( - "deidentify-template-dropdown" - ); - - // Check the Inspect Template Dropdown - if (inspectDropdown.options.length <= 0) { - inspectDropdown.style.display = "none"; // Hide Dropdown - document.getElementById("inspect-template-msg").style.display = "block"; // Show Message - } else { - inspectDropdown.style.display = "block"; // Show Dropdown - document.getElementById("inspect-template-msg").style.display = "none"; // Hide Message - } - - // Check the De-identify Template Dropdown - if (deidentifyDropdown.options.length <= 0) { - deidentifyDropdown.style.display = "none"; // Hide Dropdown - document.getElementById("deidentify-template-msg").style.display = - "block"; // Show Message - } else { - deidentifyDropdown.style.display = "block"; // Show Dropdown - document.getElementById("deidentify-template-msg").style.display = - "none"; // Hide Message - } - }); - - document - .getElementById("toggle-nlp-filter-section") - .addEventListener("change", function () { - fetchNLPEnabled(); - }); -} -if (document.readyState != "loading") onReady(); -else document.addEventListener("DOMContentLoaded", onReady); - -function enableForm(enabled) { - var promptEl = document.getElementById("prompt"); - promptEl.toggleAttribute("disabled", !enabled); - if (enabled) setTimeout(() => promptEl.focus(), 0); - - var submitEl = document.getElementById("submit"); - submitEl.toggleAttribute("disabled", !enabled); - submitEl.textContent = enabled ? "Submit" : "..."; -} - -function autoResizeTextarea() { - var textarea = document.getElementById("prompt"); - textarea.addEventListener("input", function () { - this.style.height = "auto"; - this.style.height = this.scrollHeight + "px"; - }); -} - -// Function to handle the visibility of filter section -function toggleNlpFilterSection(nlpEnabled) { - var filterOptions = document.getElementById("nlp-filter-section"); - var nlpCheckbox = document.getElementById("toggle-nlp-filter-section"); - - if (nlpEnabled && nlpCheckbox.checked) { - filterOptions.style.display = "block"; - } else { - filterOptions.style.display = "none"; - } -} - -function updateNLPValue() { - const rangeInput = document.getElementById("nlp-range"); - const valueDisplay = document.getElementById("nlp-value"); - - // Function to update the slider's display value and color - const updateSliderAppearance = (value) => { - // Update the display text - valueDisplay.textContent = value; - - // Determine the color based on the value - let color; - if (value <= 25) { - color = "#4285F4"; // Blue - } else if (value <= 50) { - color = "#34A853"; // Green - } else if (value <= 75) { - color = "#FBBC05"; // Yellow - } else { - color = "#EA4335"; // Red - } - - // Apply the color to the slider through a gradient - // This gradient visually fills the track up to the thumb's current position - const percentage = - ((value - rangeInput.min) / (rangeInput.max - rangeInput.min)) * 100; - rangeInput.style.background = `linear-gradient(90deg, ${color} ${percentage}%, #ddd ${percentage}%)`; - rangeInput.style.setProperty("--thumb-color", color); - }; - - // Initialize the slider's appearance - updateSliderAppearance(rangeInput.value); - - // Update slider's appearance whenever its value changes - rangeInput.addEventListener("input", (event) => { - updateSliderAppearance(event.target.value); - }); -} - -function fetchNLPEnabled() { - fetch("/get_nlp_status") - .then((response) => response.json()) - .then((data) => { - var nlpEnabled = data.nlpEnabled; - - toggleNlpFilterSection(nlpEnabled); - }) - .catch((error) => console.error("Error fetching NLP status:", error)); -} - -// Function to handle the visibility of filter section -function toggleDLPFilterSection(dlpEnabled) { - var filterOptions = document.getElementById("dlp-filter-section"); - var dlpCheckbox = document.getElementById("toggle-dlp-filter-section"); - if (dlpEnabled && dlpCheckbox.checked) { - filterOptions.style.display = "block"; - } else { - filterOptions.style.display = "none"; - } -} - -function fetchDLPEnabled() { - fetch("/get_dlp_status") - .then((response) => response.json()) - .then((data) => { - var dlpEnabled = data.dlpEnabled; - - toggleDLPFilterSection(dlpEnabled); - }) - .catch((error) => console.error("Error fetching DLP status:", error)); -} - -// Function to populate dropdowns -function populateDropdowns() { - fetch("/get_inspect_templates") - .then((response) => response.json()) - .then((data) => { - const inspectDropdown = document.getElementById( - "inspect-template-dropdown" - ); - data.forEach((template) => { - let option = new Option(template, template); - inspectDropdown.add(option); - }); - }) - .catch((error) => console.error("Error loading inspect templates:", error)); - - fetch("/get_deidentify_templates") - .then((response) => response.json()) - .then((data) => { - const deidentifyDropdown = document.getElementById( - "deidentify-template-dropdown" - ); - data.forEach((template) => { - let option = new Option(template, template); - deidentifyDropdown.add(option); - }); - }) - .catch((error) => - console.error("Error loading deidentify templates:", error) - ); -} - -function loadPreviousMessages() { - fetch("/get_chat_history") - .then((response) => response.json()) - .then((data) => { - const { history_messages } = data; - var chatEl = document.getElementById("chat"); - history_messages.map(({ prompt, message }) => { - var promptEl = Object.assign(document.createElement("p"), { - classList: ["previous_message"], - }); - promptEl.textContent = `${prompt} : ${message}`; - chatEl.appendChild(promptEl); - }); - }) - .catch((error) => - console.error("Error getting previous chat messages:", error) - ); -} diff --git a/styles.css b/styles.css deleted file mode 100644 index d2bf11352..000000000 --- a/styles.css +++ /dev/null @@ -1,209 +0,0 @@ -/* Copyright 2024 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -/* styles.css */ -body { - font-family: Arial, sans-serif; - text-align: center; - - max-width: 2000px; - margin: 0 auto; -} - -.content-container { - display: flex; - max-width: 2000px; - margin: 0 auto; -} - -#chat-and-form { - width: 70%; -} - -#chat { - height: 80vh; - padding: 10px; - margin-top: 20px; - margin-bottom: 10px; - display: block; - - text-align: left; - overflow-y: scroll; - - border: 1px solid #222; - border-radius: 2px; -} - -#chat p { - white-space: pre-wrap; - - padding: 8px; - border-radius: 4px; -} - -#chat p:first-child { - margin-top: 0; -} - -p.prompt { - background-color: rgba(50, 255, 0, 0.1); -} - -p.prompt::before { - content: 'Prompt: '; - font-weight: bold; -} - -p.instruct{ - background-color: rgba(255, 255, 0, 0.2); -} - - -p.instruct::before { - content: 'Instructions: '; - font-weight: bold; -} - -p.previous_message { - background-color: rgba(73, 240, 240, 0.753); -} - -p.previous_message::before { - content: 'Previous message: '; - font-weight: bold; -} - -p.response { - background-color: rgba(0, 134, 255, 0.1); -} - -p.response::before { - content: 'Response: '; - font-weight: bold; -} - -p.warning { - background-color: rgba(255, 229, 100, 0.1); -} - -p.error::before { - content: 'Warning: '; - font-weight: bold; -} - -p.error { - background-color: rgba(255, 0, 0, 0.1); -} - -p.error::before { - content: 'Error: '; - font-weight: bold; -} - -form { - display: flex; - width: 100%; -} - -#prompt { - flex: 1; - padding: 10px; - margin-right: 10px; - border-radius: 2px; -} - -input[type="submit"] { - padding: 10px 20px; - background-color: rgb(76,175,80); - color: white; - border: none; - border-radius: 2px; -} - -input[type="submit"]:disabled { - background-color: rgba(0, 0, 0, 0.2); -} - -/* Filter Section Styles */ -#filter-section { - width: 30%; - margin-right: 2%; /* Space between filter section and the next section, adjust as needed */ - border: 1px solid #ccc; /* Optional: adds a border around the filter section */ - padding: 20px; /* Adds some space inside the filter section */ - box-shadow: 0 2px 4px rgba(0,0,0,0.1); /* Optional: adds a slight shadow for depth */ -} - -#dlp-filter-section, -#nlp-filter-section { - margin-top: 20px; -} - -select { - width: 100%; - padding: 10px; - margin-top: 10px; - border: 1px solid #ccc; /* Adds a border to the dropdowns */ - border-radius: 4px; /* Rounds the corners of the dropdowns */ -} - -label { - display: block; - margin-top: 15px; - font-weight: bold; -} - -input[type="checkbox"] { - margin-right: 10px; /* Space between checkbox and its label */ -} - -#nlp-range { - width: 100%; - margin-top: 10px; /* Adjusted for consistency */ -} - -input[type="range"] { - -webkit-appearance: none; - appearance: none; - width: 100%; - height: 8px; - background: #ddd; - outline: none; - opacity: 1; - -webkit-transition: .2s; - transition: opacity .2s; -} - -input[type="range"]::-webkit-slider-thumb { - -webkit-appearance: none; - appearance: none; - width: 25px; - height: 25px; - background: var(--thumb-color); - cursor: pointer; -} - -input[type="range"]::-moz-range-thumb { - width: 25px; - height: 25px; - background: var(--thumb-color); - cursor: pointer; -} - -input[type="range"]::-ms-thumb { - width: 25px; - height: 25px; - background: var(--thumb-color); - cursor: pointer; -} \ No newline at end of file From 0d853eac45cf0f07d4e5ed351822ca918d8d3379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Thu, 1 Aug 2024 14:53:52 +0000 Subject: [PATCH 11/46] Ignoring test rag, to review how the rag application is working --- cloudbuild.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 36b4b7600..902527c82 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -268,7 +268,7 @@ steps: kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- jupyter nbconvert --to script /data/rag-kaggle-ray-sql-interactive.ipynb kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- ipython /data/rag-kaggle-ray-sql-interactive.py - python3 ./applications/rag/tests/test_rag.py "http://127.0.0.1:8081/prompt" + # python3 ./applications/rag/tests/test_rag.py "http://127.0.0.1:8081/prompt" Ignoring while the test approach is reviewed echo "pass" > /workspace/rag_prompt_result.txt allowFailure: true From 386c4379d2f811074bed926f8cab44deb1d1afee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Thu, 1 Aug 2024 16:47:48 +0000 Subject: [PATCH 12/46] ignoring unit test to review cloud build process --- cloudbuild.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 902527c82..77f76ca3a 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -254,7 +254,7 @@ steps: sleep 5s cd /workspace/applications/rag/tests - python3 test_frontend.py "127.0.0.1:8081" + # python3 test_frontend.py "127.0.0.1:8081" echo "pass" > /workspace/rag_frontend_result.txt cd /workspace/ From be1839dc6dbad7d96ac10fbbbe606b3344c96c93 Mon Sep 17 00:00:00 2001 From: German Stiven Grandas Aguirre Date: Tue, 6 Aug 2024 12:01:04 -0500 Subject: [PATCH 13/46] refactoring cloud sql connection helper --- .../application/cloud_sql/__init__.py | 15 ---- .../application/cloud_sql/cloud_sql.py | 85 ------------------- .../application/rag_langchain/rag_chain.py | 79 +++++++++-------- .../container/application/utils/__init__.py | 4 +- .../application/utils/cloud_sql_utils.py | 78 +++++++++++++++++ .../application/vector_storages/cloud_sql.py | 3 +- 6 files changed, 123 insertions(+), 141 deletions(-) delete mode 100644 applications/rag/frontend/container/application/cloud_sql/__init__.py delete mode 100644 applications/rag/frontend/container/application/cloud_sql/cloud_sql.py create mode 100644 applications/rag/frontend/container/application/utils/cloud_sql_utils.py diff --git a/applications/rag/frontend/container/application/cloud_sql/__init__.py b/applications/rag/frontend/container/application/cloud_sql/__init__.py deleted file mode 100644 index 11f30faf0..000000000 --- a/applications/rag/frontend/container/application/cloud_sql/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is required to make Python treat the subfolder as a package diff --git a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py b/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py deleted file mode 100644 index 1c0da565b..000000000 --- a/applications/rag/frontend/container/application/cloud_sql/cloud_sql.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import uuid -import logging - -from typing import List, Optional, Iterable, Any - -from langchain_core.vectorstores import VectorStore -from langchain_core.embeddings import Embeddings -from langchain_core.documents import Document -from langchain.text_splitter import RecursiveCharacterTextSplitter - -from langchain_google_cloud_sql_pg import PostgresVectorStore - - -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") -CHUNK_SIZE = 1000 -CHUNK_OVERLAP = 10 - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - - -class CloudSQLVectorStore(VectorStore): - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ): - raise NotImplementedError - - def __init__(self, embedding_provider, engine): - self.vector_store = PostgresVectorStore.create_sync( - engine=engine, - embedding_service=embedding_provider, - table_name=VECTOR_EMBEDDINGS_TABLE_NAME, - ) - self.splitter = RecursiveCharacterTextSplitter( - chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len - ) - self.embeddings_service = embedding_provider - - # TODO implement - def add_texts( - self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any - ) -> List[str]: - try: - splits = self.splitter.split_documents(texts) - ids = [str(uuid.uuid4()) for _ in range(len(splits))] - self.vector_store.add_documents(splits, ids) - except Exception as e: - logging.info(f"Error: {e}") - raise e - - # TODO implement similarity search with cosine similarity threshold - - def similarity_search( - self, query: dict, k: int = 4, **kwargs: Any - ) -> List[Document]: - try: - - query_input = query.get("input") - query_vector = self.embeddings_service.embed_query(query_input) - docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) - return docs - - except Exception as err: - raise Exception(f"General error: {err}") diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index f290021ff..b8cf7b458 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import logging from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder @@ -22,9 +22,8 @@ from langchain_google_cloud_sql_pg import PostgresChatMessageHistory -from application.cloud_sql.cloud_sql import ( - CHAT_HISTORY_TABLE_NAME, - create_sync_postgres_engine, +from application.utils import ( + create_sync_postgres_engine ) from application.rag_langchain.huggingface_inference_model import ( HuggingFaceCustomChatModel, @@ -35,17 +34,16 @@ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) +CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") + QUESTION = "input" HISTORY = "chat_history" CONTEXT = "context" SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings -template_str = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. -Improve upon your previous answers using History, a list of messages. -Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. -Stick to the facts by basing your answers off of the Context provided. -Be brief in answering. +template_str = """Answer the question given by the user in no more than 2 sentences. +Use the provided context to improve upon your previous answers. Stick to the facts and be brief. Avoid conversational format. \n\n Context: {context} """ @@ -80,35 +78,42 @@ def clear_chat_history(session_id: str): def create_chain() -> RunnableWithMessageHistory: - model = HuggingFaceCustomChatModel() - - langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CloudSQLVectorStore(langchain_embed, engine) - - retriever = vector_store.as_retriever() - - setup_and_retrieval = RunnableParallel( - { - "context": retriever, - QUESTION: RunnableLambda(lambda d: d[QUESTION]), - HISTORY: RunnableLambda(lambda d: d[HISTORY]), - } - ) - - chain = setup_and_retrieval | prompt | model - chain_with_history = RunnableWithMessageHistory( - chain, - get_chat_history, - input_messages_key=QUESTION, - history_messages_key=HISTORY, - output_messages_key="output", - ) - return chain_with_history - + try: + model = HuggingFaceCustomChatModel() + + langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) + vector_store = CloudSQLVectorStore(langchain_embed, engine) + + retriever = vector_store.as_retriever() + + setup_and_retrieval = RunnableParallel( + { + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]), + } + ) + + chain = setup_and_retrieval | prompt | model + chain_with_history = RunnableWithMessageHistory( + chain, + get_chat_history, + input_messages_key=QUESTION, + history_messages_key=HISTORY, + output_messages_key="output", + ) + return chain_with_history + except Exception as e: + logging.info(e) + raise e def take_chat_turn( chain: RunnableWithMessageHistory, session_id: str, query_text: str ) -> str: - config = {"configurable": {"session_id": session_id}} - result = chain.invoke({"input": query_text}, config=config) - return result + try: + config = {"configurable": {"session_id": session_id}} + result = chain.invoke({"input": query_text}, config=config) + return result + except Exception as e: + logging.info(e) + raise e \ No newline at end of file diff --git a/applications/rag/frontend/container/application/utils/__init__.py b/applications/rag/frontend/container/application/utils/__init__.py index 8034550de..18ef0b7ac 100644 --- a/applications/rag/frontend/container/application/utils/__init__.py +++ b/applications/rag/frontend/container/application/utils/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from .huggingface_tgi_helper import post_request +from .cloud_sql_utils import create_sync_postgres_engine - -__all__ = ["post_request"] +__all__ = ["post_request", "create_sync_postgres_engine"] diff --git a/applications/rag/frontend/container/application/utils/cloud_sql_utils.py b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py new file mode 100644 index 000000000..d02484ef8 --- /dev/null +++ b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py @@ -0,0 +1,78 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging + +from google.cloud.sql.connector import IPTypes + +from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +ENVIRONMENT = os.environ.get("ENVIRONMENT") + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID") +GCP_CLOUD_SQL_REGION = os.environ.get("CLOUDSQL_INSTANCE_REGION") +GCP_CLOUD_SQL_INSTANCE = os.environ.get("CLOUDSQL_INSTANCE") + +DB_NAME = os.environ.get("DB_NAME", "pgvector-database") +VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") +CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") + +VECTOR_DIMENSION = os.environ.get("VECTOR_DIMENSION", 384) + +try: + db_username_file = open("/etc/secret-volume/username", "r") + DB_USER = db_username_file.read() + db_username_file.close() + + db_password_file = open("/etc/secret-volume/password", "r") + DB_PASS = db_password_file.read() + db_password_file.close() +except: + DB_USER = os.environ.get("DB_USERNAME", "postgres") + DB_PASS = os.environ.get("DB_PASS", "postgres") + + +def create_sync_postgres_engine(): + engine = PostgresEngine.from_instance( + project_id=GCP_PROJECT_ID, + region=GCP_CLOUD_SQL_REGION, + instance=GCP_CLOUD_SQL_INSTANCE, + database=DB_NAME, + user=DB_USER, + password=DB_PASS, + ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE, + ) + try: + engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) + engine.init_vectorstore_table( + VECTOR_EMBEDDINGS_TABLE_NAME, + vector_size=VECTOR_DIMENSION, + overwrite_existing=False, + ) + except Exception as e: + logging.info(f"Error: {e}") + + return engine \ No newline at end of file diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py index 11e08ed28..3fd6f2cde 100644 --- a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -78,9 +78,8 @@ def similarity_search( query_input = query.get("input") query_vector = self.embeddings_service.embed_query(query_input) - docs = self.vector_store.similarity_search_by_vector(query_vector, k=4) + docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) return docs except Exception as err: raise Exception(f"General error: {err}") - From 35f67e446892b0972b810613e30cec8d18bca5b2 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 19:41:09 +0200 Subject: [PATCH 14/46] Change TPU Metrics Source for Autoscaling (#770) first commit --- .../hpa.jetstream.yaml.tftpl | 6 +++++- .../templates/prometheus-adapter/values.yaml.tftpl | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/jetstream-maxtext-deployment/templates/custom-metrics-stackdriver-adapter/hpa.jetstream.yaml.tftpl b/modules/jetstream-maxtext-deployment/templates/custom-metrics-stackdriver-adapter/hpa.jetstream.yaml.tftpl index b70218558..414cc2432 100644 --- a/modules/jetstream-maxtext-deployment/templates/custom-metrics-stackdriver-adapter/hpa.jetstream.yaml.tftpl +++ b/modules/jetstream-maxtext-deployment/templates/custom-metrics-stackdriver-adapter/hpa.jetstream.yaml.tftpl @@ -24,7 +24,11 @@ spec: - type: External external: metric: - name: kubernetes.io|node|accelerator|${rule.target_query} + name: prometheus.googleapis.com|${rule.target_query}|gauge + selector: + matchLabels: + metric.labels.container: jetstream-http + metric.labels.exported_namespace: default target: type: AverageValue averageValue: ${rule.average_value_target} diff --git a/modules/jetstream-maxtext-deployment/templates/prometheus-adapter/values.yaml.tftpl b/modules/jetstream-maxtext-deployment/templates/prometheus-adapter/values.yaml.tftpl index a07058dee..b1091fe9f 100644 --- a/modules/jetstream-maxtext-deployment/templates/prometheus-adapter/values.yaml.tftpl +++ b/modules/jetstream-maxtext-deployment/templates/prometheus-adapter/values.yaml.tftpl @@ -29,10 +29,10 @@ rules: matches: "" as: "jetstream_slots_used_percentage" metricsQuery: avg(<<.Series>>{<<.LabelMatchers>>,cluster="${cluster_name}"}) - - seriesQuery: 'kubernetes_io:node_accelerator_memory_used' + - seriesQuery: 'memory_used' resources: template: <<.Resource>> name: matches: "" as: "memory_used_percentage" - metricsQuery: avg(kubernetes_io:node_accelerator_memory_used{cluster_name="${cluster_name}"}) / avg(kubernetes_io:node_accelerator_memory_total{cluster_name="${cluster_name}"}) \ No newline at end of file + metricsQuery: avg(memory_used{cluster="${cluster_name}",exported_namespace="default",container="jetstream-http"}) / avg(memory_total{cluster="${cluster_name}",exported_namespace="default",container="jetstream-http"}) \ No newline at end of file From 00220530819a0a7159fe1c3d772e7eb55bef1017 Mon Sep 17 00:00:00 2001 From: genlu2011 Date: Thu, 15 Aug 2024 10:20:27 -0700 Subject: [PATCH 15/46] Refactor: move workload identity service account out of kuberay-operator (#769) * Refactor: create module for workload identity service account Change-Id: I29e985e77a1ff2d5f4a8d9493c1e65907c89c100 * fix: add todo Change-Id: I3357e8f9dd16c7958dff0f0cf0f990fce980f474 --------- Co-authored-by: Gen Lu --- applications/rag/main.tf | 64 ++++++++++++++++----------- applications/ray/main.tf | 35 ++++++++++----- modules/kuberay-operator/kuberay.tf | 27 ----------- modules/kuberay-operator/variables.tf | 12 ----- 4 files changed, 63 insertions(+), 75 deletions(-) diff --git a/applications/rag/main.tf b/applications/rag/main.tf index 6ff27fe8f..0bb597d06 100644 --- a/applications/rag/main.tf +++ b/applications/rag/main.tf @@ -153,15 +153,13 @@ module "namespace" { } module "kuberay-operator" { - source = "../../modules/kuberay-operator" - providers = { helm = helm.rag, kubernetes = kubernetes.rag } - name = "kuberay-operator" - project_id = var.project_id - create_namespace = true - namespace = local.kubernetes_namespace - google_service_account = local.ray_service_account - create_service_account = var.create_ray_service_account - autopilot_cluster = local.enable_autopilot + source = "../../modules/kuberay-operator" + providers = { helm = helm.rag, kubernetes = kubernetes.rag } + name = "kuberay-operator" + project_id = var.project_id + create_namespace = true + namespace = local.kubernetes_namespace + autopilot_cluster = local.enable_autopilot } module "gcs" { @@ -225,6 +223,32 @@ module "kuberay-logging" { depends_on = [module.namespace] } +module "kuberay-workload-identity" { + providers = { kubernetes = kubernetes.rag } + source = "terraform-google-modules/kubernetes-engine/google//modules/workload-identity" + version = "30.0.0" # Pinning to a previous version as current version (30.1.0) showed inconsitent behaviour with workload identity service accounts + use_existing_gcp_sa = !var.create_ray_service_account + name = local.ray_service_account + namespace = local.kubernetes_namespace + project_id = var.project_id + roles = ["roles/cloudsql.client", "roles/monitoring.viewer"] + automount_service_account_token = true + depends_on = [module.namespace] +} + +module "kuberay-monitoring" { + source = "../../modules/kuberay-monitoring" + providers = { helm = helm.rag, kubernetes = kubernetes.rag } + project_id = var.project_id + autopilot_cluster = local.enable_autopilot + namespace = local.kubernetes_namespace + create_namespace = true + enable_grafana_on_ray_dashboard = var.enable_grafana_on_ray_dashboard + k8s_service_account = local.ray_service_account + //TODO(genlu): remove the module.kuberay-operator after migrated using ray addon. + depends_on = [module.namespace, module.kuberay-operator, module.kuberay-workload-identity] +} + module "kuberay-cluster" { source = "../../modules/kuberay-cluster" providers = { helm = helm.rag, kubernetes = kubernetes.rag } @@ -233,16 +257,17 @@ module "kuberay-cluster" { enable_gpu = true gcs_bucket = var.gcs_bucket autopilot_cluster = local.enable_autopilot - db_secret_name = module.cloudsql.db_secret_name cloudsql_instance_name = local.cloudsql_instance db_region = local.cloudsql_instance_region google_service_account = local.ray_service_account - grafana_host = module.kuberay-monitoring.grafana_uri disable_network_policy = var.disable_ray_cluster_network_policy - depends_on = [module.kuberay-operator] use_custom_image = true additional_labels = var.additional_labels + # Implicit dependency + db_secret_name = module.cloudsql.db_secret_name + grafana_host = module.kuberay-monitoring.grafana_uri + # IAP Auth parameters add_auth = var.ray_dashboard_add_auth create_brand = var.create_brand @@ -256,19 +281,8 @@ module "kuberay-cluster" { k8s_backend_service_port = var.ray_dashboard_k8s_backend_service_port domain = var.ray_dashboard_domain members_allowlist = var.ray_dashboard_members_allowlist != "" ? split(",", var.ray_dashboard_members_allowlist) : [] -} - -module "kuberay-monitoring" { - source = "../../modules/kuberay-monitoring" - providers = { helm = helm.rag, kubernetes = kubernetes.rag } - project_id = var.project_id - autopilot_cluster = local.enable_autopilot - namespace = local.kubernetes_namespace - create_namespace = true - enable_grafana_on_ray_dashboard = var.enable_grafana_on_ray_dashboard - k8s_service_account = local.ray_service_account - # TODO(umeshkumhar): remove kuberay-operator depends, figure out service account dependency - depends_on = [module.namespace, module.kuberay-operator] + //TODO(genlu): remove the module.kuberay-operator after migrated using ray addon. + depends_on = [module.gcs, module.kuberay-operator, module.kuberay-workload-identity] } module "inference-server" { diff --git a/applications/ray/main.tf b/applications/ray/main.tf index 207807532..8f4a5ecef 100644 --- a/applications/ray/main.tf +++ b/applications/ray/main.tf @@ -134,16 +134,27 @@ module "namespace" { namespace = local.kubernetes_namespace } +module "kuberay-workload-identity" { + providers = { kubernetes = kubernetes.ray } + source = "terraform-google-modules/kubernetes-engine/google//modules/workload-identity" + version = "30.0.0" # Pinning to a previous version as current version (30.1.0) showed inconsitent behaviour with workload identity service accounts + use_existing_gcp_sa = !var.create_service_account + name = local.workload_identity_service_account + namespace = local.kubernetes_namespace + project_id = var.project_id + roles = ["roles/cloudsql.client", "roles/monitoring.viewer"] + automount_service_account_token = true + depends_on = [module.namespace] +} + module "kuberay-operator" { - source = "../../modules/kuberay-operator" - providers = { helm = helm.ray, kubernetes = kubernetes.ray } - name = "kuberay-operator" - create_namespace = true - namespace = local.kubernetes_namespace - project_id = var.project_id - autopilot_cluster = local.enable_autopilot - google_service_account = local.workload_identity_service_account - create_service_account = var.create_service_account + source = "../../modules/kuberay-operator" + providers = { helm = helm.ray, kubernetes = kubernetes.ray } + name = "kuberay-operator" + create_namespace = true + namespace = local.kubernetes_namespace + project_id = var.project_id + autopilot_cluster = local.enable_autopilot } module "kuberay-logging" { @@ -164,7 +175,8 @@ module "kuberay-monitoring" { create_namespace = true enable_grafana_on_ray_dashboard = var.enable_grafana_on_ray_dashboard k8s_service_account = local.workload_identity_service_account - depends_on = [module.kuberay-operator] + //TODO(genlu): remove the module.kuberay-operator after migrated using ray addon. + depends_on = [module.kuberay-workload-identity, module.kuberay-operator] } module "gcs" { @@ -204,7 +216,8 @@ module "kuberay-cluster" { k8s_backend_service_port = var.ray_dashboard_k8s_backend_service_port domain = var.ray_dashboard_domain members_allowlist = var.ray_dashboard_members_allowlist != "" ? split(",", var.ray_dashboard_members_allowlist) : [] - depends_on = [module.gcs, module.kuberay-operator] + //TODO(genlu): remove the module.kuberay-operator after migrated using ray addon. + depends_on = [module.gcs, module.kuberay-operator, module.kuberay-workload-identity] } diff --git a/modules/kuberay-operator/kuberay.tf b/modules/kuberay-operator/kuberay.tf index 2315a3149..e754bdf89 100644 --- a/modules/kuberay-operator/kuberay.tf +++ b/modules/kuberay-operator/kuberay.tf @@ -23,33 +23,6 @@ resource "helm_release" "kuberay-operator" { create_namespace = var.create_namespace } -module "kuberay-workload-identity" { - source = "terraform-google-modules/kubernetes-engine/google//modules/workload-identity" - version = "30.0.0" # Pinning to a previous version as current version (30.1.0) showed inconsitent behaviour with workload identity service accounts - use_existing_gcp_sa = !var.create_service_account - name = var.google_service_account - namespace = var.namespace - project_id = var.project_id - roles = ["roles/cloudsql.client", "roles/monitoring.viewer"] - - automount_service_account_token = true - - depends_on = [helm_release.kuberay-operator] -} - -resource "kubernetes_secret_v1" "service_account_token" { - metadata { - name = "kuberay-sa-token" - namespace = var.namespace - annotations = { - "kubernetes.io/service-account.name" = var.google_service_account - } - } - type = "kubernetes.io/service-account-token" - - depends_on = [module.kuberay-workload-identity] -} - # Grant access to batchv1/Jobs to kuberay-operator since the kuberay-operator role is missing some permissions. # See https://github.com/ray-project/kuberay/issues/1706 for more details. # TODO: remove this role binding once the kuberay-operator helm chart is upgraded to v1.1 diff --git a/modules/kuberay-operator/variables.tf b/modules/kuberay-operator/variables.tf index 7b984b4f7..a977c27de 100644 --- a/modules/kuberay-operator/variables.tf +++ b/modules/kuberay-operator/variables.tf @@ -34,15 +34,3 @@ variable "create_namespace" { variable "autopilot_cluster" { type = bool } - -variable "google_service_account" { - type = string - description = "Google service account name" - default = "kuberay-gcp-sa" -} - -variable "create_service_account" { - type = bool - description = "Creates a google service account & k8s service account & configures workload identity" - default = true -} From 48f655b8d833788d0890e54b7ae808c68cfe5031 Mon Sep 17 00:00:00 2001 From: German Grandas Date: Tue, 20 Aug 2024 15:13:28 -0500 Subject: [PATCH 16/46] updating branch --- applications/rag/main.tf | 18 +- applications/rag/podmonitoring.yaml | 11 + charts/gmp-engine/Chart.yaml | 28 +++ .../charts/gmp-frontend}/Chart.yaml | 2 +- .../gmp-frontend}/templates/deployment.yaml | 4 +- .../gmp-frontend}/templates/service.yaml | 0 .../charts/gmp-frontend/values.yaml | 17 ++ .../gmp-engine/templates/podmonitoring.yaml | 0 charts/gmp-engine/values.yaml | 10 + modules/inference-service/README | 2 + modules/inference-service/main.tf | 187 +++++++++++++++ modules/inference-service/outputs.tf | 28 +++ modules/inference-service/variables.tf | 31 +++ modules/inference-service/versions.tf | 27 +++ .../charts/gmp-engine/values.yaml | 30 --- modules/kuberay-monitoring/gmpvalues.yaml | 9 + modules/kuberay-monitoring/main.tf | 18 +- tutorials-and-examples/hf-tgi/README.md | 18 +- tutorials-and-examples/hf-tgi/main.tf | 224 ++++-------------- tutorials-and-examples/hf-tgi/outputs.tf | 6 +- .../hf-tgi/podmonitoring.yaml | 6 + tutorials-and-examples/hf-tgi/variables.tf | 20 +- tutorials-and-examples/hf-tgi/versions.tf | 4 + 23 files changed, 467 insertions(+), 233 deletions(-) create mode 100644 applications/rag/podmonitoring.yaml create mode 100644 charts/gmp-engine/Chart.yaml rename {modules/kuberay-monitoring/charts/gmp-engine => charts/gmp-engine/charts/gmp-frontend}/Chart.yaml (98%) rename {modules/kuberay-monitoring/charts/gmp-engine => charts/gmp-engine/charts/gmp-frontend}/templates/deployment.yaml (89%) rename {modules/kuberay-monitoring/charts/gmp-engine => charts/gmp-engine/charts/gmp-frontend}/templates/service.yaml (100%) create mode 100644 charts/gmp-engine/charts/gmp-frontend/values.yaml rename {modules/kuberay-monitoring/charts => charts}/gmp-engine/templates/podmonitoring.yaml (100%) create mode 100644 charts/gmp-engine/values.yaml create mode 100644 modules/inference-service/README create mode 100644 modules/inference-service/main.tf create mode 100644 modules/inference-service/outputs.tf create mode 100644 modules/inference-service/variables.tf create mode 100644 modules/inference-service/versions.tf delete mode 100644 modules/kuberay-monitoring/charts/gmp-engine/values.yaml create mode 100644 modules/kuberay-monitoring/gmpvalues.yaml create mode 100644 tutorials-and-examples/hf-tgi/podmonitoring.yaml diff --git a/applications/rag/main.tf b/applications/rag/main.tf index 0bb597d06..45d319bd6 100644 --- a/applications/rag/main.tf +++ b/applications/rag/main.tf @@ -286,7 +286,7 @@ module "kuberay-cluster" { } module "inference-server" { - source = "../../tutorials-and-examples/hf-tgi" + source = "../../modules/inference-service" providers = { kubernetes = kubernetes.rag } namespace = local.kubernetes_namespace additional_labels = var.additional_labels @@ -324,4 +324,18 @@ module "frontend" { domain = var.frontend_domain members_allowlist = var.frontend_members_allowlist != "" ? split(",", var.frontend_members_allowlist) : [] depends_on = [module.namespace] -} \ No newline at end of file +} + +resource "helm_release" "gmp-apps" { + name = "gmp-apps" + provider = helm.rag + chart = "../../charts/gmp-engine/" + namespace = local.kubernetes_namespace + # Timeout is increased to guarantee sufficient scale-up time for Autopilot nodes. + timeout = 1200 + depends_on = [module.inference-server, module.frontend] + values = [ + "${file("${path.module}/podmonitoring.yaml")}" + ] +} + diff --git a/applications/rag/podmonitoring.yaml b/applications/rag/podmonitoring.yaml new file mode 100644 index 000000000..d9e7bf761 --- /dev/null +++ b/applications/rag/podmonitoring.yaml @@ -0,0 +1,11 @@ +podMonitoring: +- name: mistral-7b-instruct + selector: + app: mistral-7b-instruct + port: metrics + interval: 30s +- name: rag-frontend + selector: + app: rag-frontend + port: metrics + interval: 30s diff --git a/charts/gmp-engine/Chart.yaml b/charts/gmp-engine/Chart.yaml new file mode 100644 index 000000000..2d2ea0411 --- /dev/null +++ b/charts/gmp-engine/Chart.yaml @@ -0,0 +1,28 @@ +apiVersion: v2 +name: gmp-engine +description: A Helm chart for Kubernetes + +# A chart can be either an 'application' or a 'library' chart. +# +# Application charts are a collection of templates that can be packaged into versioned archives +# to be deployed. +# +# Library charts provide useful utilities or functions for the chart developer. They're included as +# a dependency of application charts to inject those utilities and functions into the rendering +# pipeline. Library charts do not define any templates and therefore cannot be deployed. +type: application + +# This is the chart version. This version number should be incremented each time you make changes +# to the chart and its templates, including the app version. +# Versions are expected to follow Semantic Versioning (https://semver.org/) +version: 0.1.0 + +# This is the version number of the application being deployed. This version number should be +# incremented each time you make changes to the application. Versions are not expected to +# follow Semantic Versioning. They should reflect the version the application is using. +# It is recommended to use it with quotes. +appVersion: "1.0.0" + +dependencies: +- name: gmp-frontend + condition: gmp-frontend.enabled diff --git a/modules/kuberay-monitoring/charts/gmp-engine/Chart.yaml b/charts/gmp-engine/charts/gmp-frontend/Chart.yaml similarity index 98% rename from modules/kuberay-monitoring/charts/gmp-engine/Chart.yaml rename to charts/gmp-engine/charts/gmp-frontend/Chart.yaml index 9658d5358..f4442a9f9 100644 --- a/modules/kuberay-monitoring/charts/gmp-engine/Chart.yaml +++ b/charts/gmp-engine/charts/gmp-frontend/Chart.yaml @@ -1,5 +1,5 @@ apiVersion: v2 -name: gmp-engine +name: gmp-frontend description: A Helm chart for Kubernetes # A chart can be either an 'application' or a 'library' chart. diff --git a/modules/kuberay-monitoring/charts/gmp-engine/templates/deployment.yaml b/charts/gmp-engine/charts/gmp-frontend/templates/deployment.yaml similarity index 89% rename from modules/kuberay-monitoring/charts/gmp-engine/templates/deployment.yaml rename to charts/gmp-engine/charts/gmp-frontend/templates/deployment.yaml index a354907f5..83e455afb 100644 --- a/modules/kuberay-monitoring/charts/gmp-engine/templates/deployment.yaml +++ b/charts/gmp-engine/charts/gmp-frontend/templates/deployment.yaml @@ -28,14 +28,14 @@ spec: labels: app: {{ .Values.name }} spec: - serviceAccountName: {{ .Values.serviceAccount }} + serviceAccountName: {{ required "serviceAccount is required!" .Values.serviceAccount }} containers: - name: {{ .Values.name }} image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" imagePullPolicy: {{ .Values.image.pullPolicy }} args: - "--web.listen-address=:9090" - - "--query.project-id={{ .Values.projectID }}" + - "--query.project-id={{ required "projectID is required!" .Values.projectID }}" ports: - name: web containerPort: 9090 diff --git a/modules/kuberay-monitoring/charts/gmp-engine/templates/service.yaml b/charts/gmp-engine/charts/gmp-frontend/templates/service.yaml similarity index 100% rename from modules/kuberay-monitoring/charts/gmp-engine/templates/service.yaml rename to charts/gmp-engine/charts/gmp-frontend/templates/service.yaml diff --git a/charts/gmp-engine/charts/gmp-frontend/values.yaml b/charts/gmp-engine/charts/gmp-frontend/values.yaml new file mode 100644 index 000000000..2aee7cbb8 --- /dev/null +++ b/charts/gmp-engine/charts/gmp-frontend/values.yaml @@ -0,0 +1,17 @@ +# Default values for gmp-frontend. +# This is a YAML-formatted file. +# Declare variables to be passed into your templates. + +name: "gmp-frontend" +projectID: "" +serviceAccount: "" + +image: + repository: gke.gcr.io/prometheus-engine/frontend + pullPolicy: IfNotPresent + tag: "v0.5.0-gke.0" + +replicaCount: 2 + +cpu: "1m" +memory: "5Mi" diff --git a/modules/kuberay-monitoring/charts/gmp-engine/templates/podmonitoring.yaml b/charts/gmp-engine/templates/podmonitoring.yaml similarity index 100% rename from modules/kuberay-monitoring/charts/gmp-engine/templates/podmonitoring.yaml rename to charts/gmp-engine/templates/podmonitoring.yaml diff --git a/charts/gmp-engine/values.yaml b/charts/gmp-engine/values.yaml new file mode 100644 index 000000000..74441006a --- /dev/null +++ b/charts/gmp-engine/values.yaml @@ -0,0 +1,10 @@ +# Default values for iap_jupyter. +# This is a YAML-formatted file. +# Declare variables to be passed into your templates. + +podMonitoring: [] + +gmp-frontend: + enabled: false + projectID: "" + serviceAccount: "" diff --git a/modules/inference-service/README b/modules/inference-service/README new file mode 100644 index 000000000..74d59bb83 --- /dev/null +++ b/modules/inference-service/README @@ -0,0 +1,2 @@ +# Inference Service +This module is currently designed specifically for the Mistral-7B-Instruct-v0.1 model. Future developments will expand the module to support the creation of customized models more broadly. diff --git a/modules/inference-service/main.tf b/modules/inference-service/main.tf new file mode 100644 index 000000000..91e558369 --- /dev/null +++ b/modules/inference-service/main.tf @@ -0,0 +1,187 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +locals { + additional_labels = length(var.additional_labels) == 0 ? {} : tomap({ + for item in split(",", var.additional_labels) : + split("=", item)[0] => split("=", item)[1] + }) +} + +resource "kubernetes_service" "inference_service" { + metadata { + name = "mistral-7b-instruct-service" + labels = { + app = "mistral-7b-instruct" + } + namespace = var.namespace + annotations = { + "cloud.google.com/load-balancer-type" = "Internal" + "cloud.google.com/neg" = "{\"ingress\":true}" + } + } + spec { + selector = { + app = "mistral-7b-instruct" + } + session_affinity = "ClientIP" + port { + protocol = "TCP" + port = 80 + target_port = 8080 + } + + type = "LoadBalancer" + } +} + +resource "kubernetes_deployment" "inference_deployment" { + timeouts { + create = "30m" + } + metadata { + name = "mistral-7b-instruct" + namespace = var.namespace + labels = merge({ + app = "mistral-7b-instruct" + }, local.additional_labels) + } + + spec { + # It takes more than 10m for the deployment to be ready on Autopilot cluster + # Set the progress deadline to 30m to avoid the deployment controller + # considering the deployment to be failed + progress_deadline_seconds = 1800 + replicas = 1 + + selector { + match_labels = merge({ + app = "mistral-7b-instruct" + }, local.additional_labels) + } + + template { + metadata { + labels = merge({ + app = "mistral-7b-instruct" + }, local.additional_labels) + } + + spec { + init_container { + name = "download-model" + image = "google/cloud-sdk:473.0.0-alpine" + command = ["gsutil", "cp", "-r", "gs://vertex-model-garden-public-us/mistralai/Mistral-7B-Instruct-v0.1/", "/model-data/"] + volume_mount { + mount_path = "/model-data" + name = "model-storage" + } + } + container { + image = "ghcr.io/huggingface/text-generation-inference:1.1.0" + name = "mistral-7b-instruct" + + port { + name = "metrics" + container_port = 8080 + protocol = "TCP" + } + + args = ["--model-id", "$(MODEL_ID)"] + + env { + name = "MODEL_ID" + value = "/model/Mistral-7B-Instruct-v0.1" + } + + env { + name = "NUM_SHARD" + value = "2" + } + + env { + name = "PORT" + value = "8080" + } + + resources { + limits = { + "nvidia.com/gpu" = "2" + } + requests = { + # Sufficient storage to fit the Mistral-7B-Instruct-v0.1 model + "ephemeral-storage" = "20Gi" + "nvidia.com/gpu" = "2" + } + } + + volume_mount { + mount_path = "/dev/shm" + name = "dshm" + } + + volume_mount { + mount_path = "/data" + name = "data" + } + + volume_mount { + mount_path = "/model" + name = "model-storage" + read_only = "true" + } + + #liveness_probe { + #http_get { + #path = "/" + #port = 8080 + + #http_header { + #name = "X-Custom-Header" + #value = "Awesome" + #} + #} + + #initial_delay_seconds = 3 + #period_seconds = 3 + #} + } + + volume { + name = "dshm" + empty_dir { + medium = "Memory" + } + } + + volume { + name = "data" + empty_dir {} + } + + volume { + name = "model-storage" + empty_dir {} + } + + node_selector = merge({ + "cloud.google.com/gke-accelerator" = "nvidia-l4" + }, var.autopilot_cluster ? { + "cloud.google.com/gke-ephemeral-storage-local-ssd" = "true" + "cloud.google.com/compute-class" = "Accelerator" + } : {}) + } + } + } +} diff --git a/modules/inference-service/outputs.tf b/modules/inference-service/outputs.tf new file mode 100644 index 000000000..7078bac0d --- /dev/null +++ b/modules/inference-service/outputs.tf @@ -0,0 +1,28 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +output "inference_service_name" { + description = "Name of model inference service" + value = kubernetes_service.inference_service.metadata[0].name +} + +output "inference_service_namespace" { + description = "Namespace of model inference service" + value = kubernetes_service.inference_service.metadata[0].namespace +} + +output "inference_service_endpoint" { + description = "Endpoint of model inference service" + value = kubernetes_service.inference_service.status != null ? (kubernetes_service.inference_service.status[0].load_balancer != null ? "${kubernetes_service.inference_service.status[0].load_balancer[0].ingress[0].ip}" : "") : "" +} diff --git a/modules/inference-service/variables.tf b/modules/inference-service/variables.tf new file mode 100644 index 000000000..1a826a789 --- /dev/null +++ b/modules/inference-service/variables.tf @@ -0,0 +1,31 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +variable "namespace" { + type = string + description = "Kubernetes namespace where resources are deployed" + default = "default" +} + +variable "additional_labels" { + // string is used instead of map(string) since blueprint metadata does not support maps. + type = string + description = "Additional labels to add to Kubernetes resources." + default = "" +} + +variable "autopilot_cluster" { + type = bool + default = false +} diff --git a/modules/inference-service/versions.tf b/modules/inference-service/versions.tf new file mode 100644 index 000000000..3a6dc225f --- /dev/null +++ b/modules/inference-service/versions.tf @@ -0,0 +1,27 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +terraform { + required_providers { + google = { + source = "hashicorp/google" + } + google-beta = { + source = "hashicorp/google-beta" + } + kubernetes = { + source = "hashicorp/kubernetes" + } + } +} diff --git a/modules/kuberay-monitoring/charts/gmp-engine/values.yaml b/modules/kuberay-monitoring/charts/gmp-engine/values.yaml deleted file mode 100644 index de4935494..000000000 --- a/modules/kuberay-monitoring/charts/gmp-engine/values.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Default values for iap_jupyter. -# This is a YAML-formatted file. -# Declare variables to be passed into your templates. - -name: "gmp-frontend" -projectID: "gcp-project-id" -serviceAccount: "default" - -image: - repository: gke.gcr.io/prometheus-engine/frontend - pullPolicy: IfNotPresent - tag: "v0.5.0-gke.0" - -replicaCount: 2 - -cpu: "1m" -memory: "5Mi" - -podMonitoring: - - name: ray-monitoring - selector: - ray.io/is-ray-node: "yes" - port: metrics - interval: 30s - - name: mistral-7b-instruct - selector: - app: mistral-7b-instruct - port: metrics - interval: 30s - diff --git a/modules/kuberay-monitoring/gmpvalues.yaml b/modules/kuberay-monitoring/gmpvalues.yaml new file mode 100644 index 000000000..7ff597b63 --- /dev/null +++ b/modules/kuberay-monitoring/gmpvalues.yaml @@ -0,0 +1,9 @@ +podMonitoring: +- name: ray-monitoring + selector: + ray.io/is-ray-node: "yes" + port: metrics + interval: 30s + +gmp-frontend: + enabled: true diff --git a/modules/kuberay-monitoring/main.tf b/modules/kuberay-monitoring/main.tf index d6ccec488..91d31f268 100644 --- a/modules/kuberay-monitoring/main.tf +++ b/modules/kuberay-monitoring/main.tf @@ -12,29 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Temporary workaround to ensure the GMP webhook is installed before applying PodMonitorings. +# Temporary workaround to ensure the GMP webhook is installed before applying PodMonitorings. +# After migrated to use ray add-on, this can be removed. resource "time_sleep" "wait_for_gmp_operator" { create_duration = "60s" } # google managed prometheus engine -resource "helm_release" "gmp-engine" { - name = "gmp-engine" - chart = "${path.module}/charts/gmp-engine/" +resource "helm_release" "gmp-ray-monitoring" { + name = "gmp-ray-monitoring" + chart = "${path.module}/../../charts/gmp-engine/" namespace = var.namespace create_namespace = var.create_namespace # Timeout is increased to guarantee sufficient scale-up time for Autopilot nodes. timeout = 1200 + values = [ + "${file("${path.module}/gmpvalues.yaml")}" + ] set { - name = "projectID" + name = "gmp-frontend.projectID" value = var.project_id } - set { - name = "serviceAccount" + name = "gmp-frontend.serviceAccount" value = var.k8s_service_account } - depends_on = [time_sleep.wait_for_gmp_operator] } diff --git a/tutorials-and-examples/hf-tgi/README.md b/tutorials-and-examples/hf-tgi/README.md index 909af994c..a51f0ab69 100644 --- a/tutorials-and-examples/hf-tgi/README.md +++ b/tutorials-and-examples/hf-tgi/README.md @@ -25,18 +25,13 @@ gcloud container node-pools create g2-standard-24 --cluster l4-demo \ --num-nodes=1 --min-nodes=1 --max-nodes=2 \ --node-locations $REGION-a,$REGION-b --region $REGION ``` -4. Apply job yaml: `kubectl apply -f mistral-deploy.yaml` +4. Provision the job and enable gathering metrics: `terrafrom apply` 5. Make sure app started ok: `kubectl logs -l app=mistral-7b-instruct` -6. Set up managed metrics collection to monarch `kubectl apply -f podmonitoring.yaml` -7. \[optional\] set up target status so that kubectl can display pod monitoring objects (kubectl describe PodMonitoring [object name]). -``` -kubectl apply -f targetstatus.yaml -``` -8. Set up port forward +6. Set up port forward ``` kubectl port-forward deployment/mistral-7b-instruct 8080:8080 & ``` -9. Try a few prompts: +7. Try a few prompts: ``` export USER_PROMPT="How to deploy a container on K8s?" ``` @@ -50,4 +45,9 @@ curl 127.0.0.1:8080/generate -X POST \ } EOF ``` -10. Look at `/metrics` endpoint of the service. Go to cloud monitoring and search for one of those metrics. For example, `tgi_request_count` or `tgi_batch_inference_count`. Those metrics should show up if you search for them in PromQL. +8. Look at `/metrics` endpoint of the service. Go to cloud monitoring and search for one of those metrics. For example, `tgi_request_count` or `tgi_batch_inference_count`. Those metrics should show up if you search for them in PromQL. + +9. Clean up the cluster +``` +gcloud container clusters delete l4-demo --location ${REGION} +``` \ No newline at end of file diff --git a/tutorials-and-examples/hf-tgi/main.tf b/tutorials-and-examples/hf-tgi/main.tf index acc999275..55e0757d9 100644 --- a/tutorials-and-examples/hf-tgi/main.tf +++ b/tutorials-and-examples/hf-tgi/main.tf @@ -12,190 +12,60 @@ # See the License for the specific language governing permissions and # limitations under the License. -locals { - additional_labels = tomap({ - for item in split(",", var.additional_labels) : - split("=", item)[0] => split("=", item)[1] - }) +provider "google" { + project = var.project_id } -resource "kubernetes_service" "inference_service" { - metadata { - name = "mistral-7b-instruct-service" - labels = { - app = "mistral-7b-instruct" - } - namespace = var.namespace - annotations = { - "cloud.google.com/load-balancer-type" = "Internal" - "cloud.google.com/neg" = "{\"ingress\":true}" - } - } - spec { - selector = { - app = "mistral-7b-instruct" - } - session_affinity = "ClientIP" - port { - protocol = "TCP" - port = 80 - target_port = 8080 - } - - type = "LoadBalancer" - } +provider "google-beta" { + project = var.project_id } -resource "kubernetes_deployment" "inference_deployment" { - timeouts { - create = "30m" - } - metadata { - name = "mistral-7b-instruct" - namespace = var.namespace - labels = merge({ - app = "mistral-7b-instruct" - }, local.additional_labels) - } - - spec { - # It takes more than 10m for the deployment to be ready on Autopilot cluster - # Set the progress deadline to 30m to avoid the deployment controller - # considering the deployment to be failed - progress_deadline_seconds = 1800 - replicas = 1 - - selector { - match_labels = merge({ - app = "mistral-7b-instruct" - }, local.additional_labels) - } - - template { - metadata { - labels = merge({ - app = "mistral-7b-instruct" - }, local.additional_labels) - } - - spec { - init_container { - name = "download-model" - image = "google/cloud-sdk:473.0.0-alpine" - command = ["gsutil", "cp", "-r", "gs://vertex-model-garden-public-us/mistralai/Mistral-7B-Instruct-v0.1/", "/model-data/"] - volume_mount { - mount_path = "/model-data" - name = "model-storage" - } - } - container { - image = "ghcr.io/huggingface/text-generation-inference:1.1.0" - name = "mistral-7b-instruct" - - port { - name = "metrics" - container_port = 8080 - protocol = "TCP" - } - - args = ["--model-id", "$(MODEL_ID)"] - - env { - name = "MODEL_ID" - value = "/model/Mistral-7B-Instruct-v0.1" - } - - env { - # Extends the max size of the prompt we can send to the service, - # so that we can augment prompts and add chat history without causing errors. - name = "MAX_INPUT_LENGTH" - value = 3072 - } - - env { - # Extends the overall context window (including length of prompt & response combined) - # Both this limit and MAX_INPUT_LENGTH need to be increased to enable RAG and chat history. - name = "MAX_TOTAL_TOKENS" - value = 4096 - } - - env { - name = "NUM_SHARD" - value = "2" - } +data "google_client_config" "default" {} - env { - name = "PORT" - value = "8080" - } - - resources { - limits = { - "nvidia.com/gpu" = "2" - } - requests = { - # Sufficient storage to fit the Mistral-7B-Instruct-v0.1 model - "ephemeral-storage" = "20Gi" - "nvidia.com/gpu" = "2" - } - } - - volume_mount { - mount_path = "/dev/shm" - name = "dshm" - } - - volume_mount { - mount_path = "/data" - name = "data" - } - - volume_mount { - mount_path = "/model" - name = "model-storage" - read_only = "true" - } - - #liveness_probe { - #http_get { - #path = "/" - #port = 8080 - - #http_header { - #name = "X-Custom-Header" - #value = "Awesome" - #} - #} +data "google_project" "project" { + project_id = var.project_id +} - #initial_delay_seconds = 3 - #period_seconds = 3 - #} - } +data "google_container_cluster" "my_cluster" { + name = var.cluster_name + location = var.location + project = var.project_id +} - volume { - name = "dshm" - empty_dir { - medium = "Memory" - } - } +locals { + ca_certificate = base64decode( + data.google_container_cluster.my_cluster.master_auth[0].cluster_ca_certificate, + ) + host = "https://${data.google_container_cluster.my_cluster.endpoint}" +} +provider "kubernetes" { + host = local.host + token = data.google_client_config.default.access_token + cluster_ca_certificate = local.ca_certificate +} - volume { - name = "data" - empty_dir {} - } +provider "helm" { + kubernetes { + host = local.host + token = data.google_client_config.default.access_token + cluster_ca_certificate = local.ca_certificate + } +} - volume { - name = "model-storage" - empty_dir {} - } +module "inference-server" { + source = "../../modules/inference-service" + namespace = var.namespace + additional_labels = var.additional_labels + autopilot_cluster = var.autopilot_cluster +} - node_selector = merge({ - "cloud.google.com/gke-accelerator" = "nvidia-l4" - }, var.autopilot_cluster ? { - "cloud.google.com/gke-ephemeral-storage-local-ssd" = "true" - "cloud.google.com/compute-class" = "Accelerator" - } : {}) - } - } - } -} \ No newline at end of file +resource "helm_release" "gmp-engine" { + name = "gmp-engine" + chart = "${path.module}/../../charts/gmp-engine/" + namespace = var.namespace + # Timeout is increased to guarantee sufficient scale-up time for Autopilot nodes. + timeout = 1200 + values = [ + "${file("${path.module}/podmonitoring.yaml")}" + ] +} diff --git a/tutorials-and-examples/hf-tgi/outputs.tf b/tutorials-and-examples/hf-tgi/outputs.tf index 7078bac0d..2de1448ec 100644 --- a/tutorials-and-examples/hf-tgi/outputs.tf +++ b/tutorials-and-examples/hf-tgi/outputs.tf @@ -14,15 +14,15 @@ output "inference_service_name" { description = "Name of model inference service" - value = kubernetes_service.inference_service.metadata[0].name + value = module.inference-server.inference_service_name } output "inference_service_namespace" { description = "Namespace of model inference service" - value = kubernetes_service.inference_service.metadata[0].namespace + value = module.inference-server.inference_service_namespace } output "inference_service_endpoint" { description = "Endpoint of model inference service" - value = kubernetes_service.inference_service.status != null ? (kubernetes_service.inference_service.status[0].load_balancer != null ? "${kubernetes_service.inference_service.status[0].load_balancer[0].ingress[0].ip}" : "") : "" + value = module.inference-server.inference_service_endpoint } diff --git a/tutorials-and-examples/hf-tgi/podmonitoring.yaml b/tutorials-and-examples/hf-tgi/podmonitoring.yaml new file mode 100644 index 000000000..807fbe706 --- /dev/null +++ b/tutorials-and-examples/hf-tgi/podmonitoring.yaml @@ -0,0 +1,6 @@ +podMonitoring: +- name: mistral-7b-instruct + selector: + app: mistral-7b-instruct + port: metrics + interval: 30s diff --git a/tutorials-and-examples/hf-tgi/variables.tf b/tutorials-and-examples/hf-tgi/variables.tf index 76c755c9d..16965451c 100644 --- a/tutorials-and-examples/hf-tgi/variables.tf +++ b/tutorials-and-examples/hf-tgi/variables.tf @@ -12,6 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +variable "project_id" { + type = string + description = "GCP project ID" +} + +variable "cluster_name" { + type = string + description = "GKE cluster name" + default = "l4-demo" +} + +variable "location" { + type = string + description = "Location of GKE cluster" + default = "us-central1" +} + variable "namespace" { type = string description = "Kubernetes namespace where resources are deployed" @@ -26,5 +43,6 @@ variable "additional_labels" { } variable "autopilot_cluster" { - type = bool + type = bool + default = false } diff --git a/tutorials-and-examples/hf-tgi/versions.tf b/tutorials-and-examples/hf-tgi/versions.tf index cdf57c280..b20d44bc9 100644 --- a/tutorials-and-examples/hf-tgi/versions.tf +++ b/tutorials-and-examples/hf-tgi/versions.tf @@ -17,6 +17,10 @@ terraform { google = { source = "hashicorp/google" } + helm = { + source = "hashicorp/helm" + version = "~> 2.14.0" + } kubernetes = { source = "hashicorp/kubernetes" } From a9895d6150b756e731232cefc2d1c6685624296c Mon Sep 17 00:00:00 2001 From: German Grandas Date: Tue, 20 Aug 2024 15:16:04 -0500 Subject: [PATCH 17/46] fixing conflicts with remote branch --- .../container/rag_langchain/rag_chain.py | 136 ++++++++++++++++++ applications/rag/frontend/main.tf | 15 +- applications/rag/main.tf | 14 -- 3 files changed, 149 insertions(+), 16 deletions(-) create mode 100644 applications/rag/frontend/container/rag_langchain/rag_chain.py diff --git a/applications/rag/frontend/container/rag_langchain/rag_chain.py b/applications/rag/frontend/container/rag_langchain/rag_chain.py new file mode 100644 index 000000000..807c875c8 --- /dev/null +++ b/applications/rag/frontend/container/rag_langchain/rag_chain.py @@ -0,0 +1,136 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import (Dict) +from cloud_sql.cloud_sql import CHAT_HISTORY_TABLE_NAME, init_connection_pool, create_sync_postgres_engine, CustomVectorStore +from google.cloud.sql.connector import Connector +from langchain_community.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnableParallel, RunnableLambda +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_google_cloud_sql_pg import PostgresChatMessageHistory + +QUESTION = "question" +HISTORY = "history" +CONTEXT = "context" + +INFERENCE_ENDPOINT=os.environ.get('INFERENCE_ENDPOINT', '127.0.0.1:8081') +SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings + + +# TODO use a chat model instead of an LLM in the chain. Convert the prompt to a chat prompt template +# prompt = ChatPromptTemplate.from_messages( +# [ +# ("system", """You help everyone by answering questions, and improve your answers from previous answers in history. +# You stick to the facts by basing your answers off of the context provided:"""), +# MessagesPlaceholder(variable_name="history"), +# MessagesPlaceholder(variable_name="context"), +# ("human", "{question}"), +# ] +# ) +template = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. +Improve upon your previous answers using History, a list of messages. +Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. +Stick to the facts by basing your answers off of the Context provided. +Be brief in answering. +History: {""" + HISTORY + "}\n\nContext: {" + CONTEXT + "}\n\nQuestion: {" + QUESTION + "}\n" + +prompt = PromptTemplate(template=template, input_variables=[HISTORY, CONTEXT, QUESTION]) + +engine = create_sync_postgres_engine() +# TODO: Dict is not safe for multiprocessing. Introduce a cache using Flask-caching or libcache +# The in-memory SimpleCache implementations for each of these libraries is not safe either. +# Consider redis or memcached (e.g., Memorystore) +# chat_history_map: Dict[str, PostgresChatMessageHistory] = {} + +def get_chat_history(session_id: str) -> PostgresChatMessageHistory: + history = PostgresChatMessageHistory.create_sync( + engine, + session_id=session_id, + table_name = CHAT_HISTORY_TABLE_NAME + ) + + print(f"Retrieving history for session {session_id} with {len(history.messages)}") + return history + +def clear_chat_history(session_id: str): + history = PostgresChatMessageHistory.create_sync( + engine, + session_id=session_id, + table_name = CHAT_HISTORY_TABLE_NAME + ) + history.clear() + + +#TODO: limit number of tokens in prompt to MAX_INPUT_LENGTH +# (as specified in hugging face TGI input parameter) + +def create_chain() -> RunnableWithMessageHistory: + # TODO HuggingFaceTextGenInference class is deprecated. + # The warning is: + # The class `langchain_community.llms.huggingface_text_gen_inference.HuggingFaceTextGenInference` + # was deprecated in langchain-community 0.0.21 and will be removed in 0.2.0. Use HuggingFaceEndpoint instead + # The replacement is HuggingFace Endoint, which requires a huggingface + # hub API token. Either need to add the token to the environment, or need to find a method to call TGI + # without the token. + # Example usage of HuggingFaceEndpoint: + # llm = HuggingFaceEndpoint( + # endpoint_url=f'http://{INFERENCE_ENDPOINT}/', + # max_new_tokens=512, + # top_k=10, + # top_p=0.95, + # typical_p=0.95, + # temperature=0.01, + # repetition_penalty=1.03, + # huggingfacehub_api_token="my-api-key" + # ) + # TODO: Give guidance on what these parameters should be and describe why these values were chosen. + model = HuggingFaceTextGenInference( + inference_server_url=f'http://{INFERENCE_ENDPOINT}/', + max_new_tokens=512, + top_k=10, + top_p=0.95, + typical_p=0.95, + temperature=0.01, + repetition_penalty=1.03, + ) + + langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) + vector_store = CustomVectorStore(langchain_embed, init_connection_pool(Connector())) + retriever = vector_store.as_retriever() + + setup_and_retrieval = RunnableParallel( + { + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]) + } + ) + chain = setup_and_retrieval | prompt | model + chain_with_history = RunnableWithMessageHistory( + chain, + get_chat_history, + input_messages_key=QUESTION, + history_messages_key=HISTORY, + output_messages_key="output" + ) + return chain_with_history + +def take_chat_turn(chain: RunnableWithMessageHistory, session_id: str, query_text: str) -> str: + #TODO limit the number of history messages + config = {"configurable": {"session_id": session_id}} + result = chain.invoke({"question": query_text}, config) + return str(result) \ No newline at end of file diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 4b7f73254..06f753d49 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -123,8 +123,19 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { } env { - name = "PROJECT_ID" - value = "projects/${var.project_id}" + name = "PROJECT_ID" + #value = "projects/${var.project_id}" + value = var.project_id + } + + env { + name = "REGION" + value = var.region + } + + env { + name = "INSTANCE" + value = var.cloudsql_instance } env { diff --git a/applications/rag/main.tf b/applications/rag/main.tf index 45d319bd6..485389ae5 100644 --- a/applications/rag/main.tf +++ b/applications/rag/main.tf @@ -325,17 +325,3 @@ module "frontend" { members_allowlist = var.frontend_members_allowlist != "" ? split(",", var.frontend_members_allowlist) : [] depends_on = [module.namespace] } - -resource "helm_release" "gmp-apps" { - name = "gmp-apps" - provider = helm.rag - chart = "../../charts/gmp-engine/" - namespace = local.kubernetes_namespace - # Timeout is increased to guarantee sufficient scale-up time for Autopilot nodes. - timeout = 1200 - depends_on = [module.inference-server, module.frontend] - values = [ - "${file("${path.module}/podmonitoring.yaml")}" - ] -} - From cd95c982dd5c3aa5030ee97da417271d51b36531 Mon Sep 17 00:00:00 2001 From: German Grandas Date: Tue, 20 Aug 2024 15:16:45 -0500 Subject: [PATCH 18/46] fixing conflicts with remote branch --- applications/rag/frontend/main.tf | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 06f753d49..f27acb2f1 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -109,8 +109,10 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { spec { service_account_name = var.google_service_account container { - image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:335b60a0775abecd7bfcdde4bd051196d692949952aa3afb76fc934fc8d38842" - name = "rag-frontend" + image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" + # Built from local code. Revert before submitting. + # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:108bb16ee2278255c80524fce125ef349c494cb5bc4ca77dbde5048b8f9448c1" + name = "rag-frontend" port { container_port = 8080 From bc8d745701ea492a43de59924cff8502f4844c61 Mon Sep 17 00:00:00 2001 From: German Grandas Date: Tue, 20 Aug 2024 15:18:12 -0500 Subject: [PATCH 19/46] fixing conflicts with remote branch --- .../application/rag_langchain/rag_chain.py | 238 +++++++++--------- .../application/utils/cloud_sql_utils.py | 2 +- .../application/vector_storages/cloud_sql.py | 4 + 3 files changed, 124 insertions(+), 120 deletions(-) diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index b8cf7b458..41e4629b8 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -1,119 +1,119 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import logging - -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import RunnableParallel, RunnableLambda -from langchain_core.runnables.history import RunnableWithMessageHistory - -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_google_cloud_sql_pg import PostgresChatMessageHistory - - -from application.utils import ( - create_sync_postgres_engine -) -from application.rag_langchain.huggingface_inference_model import ( - HuggingFaceCustomChatModel, -) -from application.vector_storages import CloudSQLVectorStore - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - -CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") - -QUESTION = "input" -HISTORY = "chat_history" -CONTEXT = "context" - -SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings - -template_str = """Answer the question given by the user in no more than 2 sentences. -Use the provided context to improve upon your previous answers. Stick to the facts and be brief. Avoid conversational format. -\n\n -Context: {context} -""" - -prompt = ChatPromptTemplate.from_messages( - [ - ("system", template_str), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] -) - -engine = create_sync_postgres_engine() - - -def get_chat_history(session_id: str) -> PostgresChatMessageHistory: - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - - logging.info( - f"Retrieving history for session {session_id} with {len(history.messages)}" - ) - return history - - -def clear_chat_history(session_id: str): - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - history.clear() - - -def create_chain() -> RunnableWithMessageHistory: - try: - model = HuggingFaceCustomChatModel() - - langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CloudSQLVectorStore(langchain_embed, engine) - - retriever = vector_store.as_retriever() - - setup_and_retrieval = RunnableParallel( - { - "context": retriever, - QUESTION: RunnableLambda(lambda d: d[QUESTION]), - HISTORY: RunnableLambda(lambda d: d[HISTORY]), - } - ) - - chain = setup_and_retrieval | prompt | model - chain_with_history = RunnableWithMessageHistory( - chain, - get_chat_history, - input_messages_key=QUESTION, - history_messages_key=HISTORY, - output_messages_key="output", - ) - return chain_with_history - except Exception as e: - logging.info(e) - raise e - -def take_chat_turn( - chain: RunnableWithMessageHistory, session_id: str, query_text: str -) -> str: - try: - config = {"configurable": {"session_id": session_id}} - result = chain.invoke({"input": query_text}, config=config) - return result - except Exception as e: - logging.info(e) - raise e \ No newline at end of file +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import logging + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableParallel, RunnableLambda +from langchain_core.runnables.history import RunnableWithMessageHistory + +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_google_cloud_sql_pg import PostgresChatMessageHistory + + +from application.utils import ( + create_sync_postgres_engine +) +from application.rag_langchain.huggingface_inference_model import ( + HuggingFaceCustomChatModel, +) +from application.vector_storages import CloudSQLVectorStore + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") + +QUESTION = "input" +HISTORY = "chat_history" +CONTEXT = "context" + +SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings + +template_str = """Answer the question given by the user in no more than 2 sentences. +Use the provided context to improve upon your previous answers. Stick to the facts and be brief. Avoid conversational format. +\n\n +Context: {context} +""" + +prompt = ChatPromptTemplate.from_messages( + [ + ("system", template_str), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) + +engine = create_sync_postgres_engine() + + +def get_chat_history(session_id: str) -> PostgresChatMessageHistory: + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + + logging.info( + f"Retrieving history for session {session_id} with {len(history.messages)}" + ) + return history + + +def clear_chat_history(session_id: str): + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + history.clear() + + +def create_chain() -> RunnableWithMessageHistory: + try: + model = HuggingFaceCustomChatModel() + + langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) + vector_store = CloudSQLVectorStore(langchain_embed, engine) + + retriever = vector_store.as_retriever() + + setup_and_retrieval = RunnableParallel( + { + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]), + } + ) + + chain = setup_and_retrieval | prompt | model + chain_with_history = RunnableWithMessageHistory( + chain, + get_chat_history, + input_messages_key=QUESTION, + history_messages_key=HISTORY, + output_messages_key="output", + ) + return chain_with_history + except Exception as e: + logging.info(e) + raise e + +def take_chat_turn( + chain: RunnableWithMessageHistory, session_id: str, query_text: str +) -> str: + try: + config = {"configurable": {"session_id": session_id}} + result = chain.invoke({"input": query_text}, config=config) + return result + except Exception as e: + logging.info(e) + raise e diff --git a/applications/rag/frontend/container/application/utils/cloud_sql_utils.py b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py index d02484ef8..5377aaeee 100644 --- a/applications/rag/frontend/container/application/utils/cloud_sql_utils.py +++ b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py @@ -75,4 +75,4 @@ def create_sync_postgres_engine(): except Exception as e: logging.info(f"Error: {e}") - return engine \ No newline at end of file + return engine diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py index 3fd6f2cde..dab9baedd 100644 --- a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -78,7 +78,11 @@ def similarity_search( query_input = query.get("input") query_vector = self.embeddings_service.embed_query(query_input) +<<<<<<< HEAD docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) +======= + docs = self.vector_store.similarity_search_by_vector(query_vector, k=4) +>>>>>>> 2d652e30 (Rag langchain chat history (#755)) return docs except Exception as err: From e9beeef68537976e46d414ab03d06628a241d76b Mon Sep 17 00:00:00 2001 From: German Grandas Date: Tue, 20 Aug 2024 15:19:13 -0500 Subject: [PATCH 20/46] fixing conflicts applying rebase --- .../container/application/__init__.py | 2 +- .../container/application/models/__init__.py | 17 -- .../application/models/vector_embeddings.py | 31 --- .../application/rag_langchain/rag_chain.py | 238 +++++++++--------- .../application/utils/cloud_sql_utils.py | 2 +- .../application/vector_storages/cloud_sql.py | 30 +-- .../vector_storages/custom_vector_storage.py | 92 ------- applications/rag/frontend/main.tf | 15 +- ...-metrics_system_auth-delegator.yaml.tftpl} | 0 9 files changed, 148 insertions(+), 279 deletions(-) delete mode 100644 applications/rag/frontend/container/application/models/__init__.py delete mode 100644 applications/rag/frontend/container/application/models/vector_embeddings.py delete mode 100644 applications/rag/frontend/container/application/vector_storages/custom_vector_storage.py rename modules/custom-metrics-stackdriver-adapter/templates/{clusterrolebinding_custom-metrics:system:auth-delegator.yaml.tftpl => clusterrolebinding_custom-metrics_system_auth-delegator.yaml.tftpl} (100%) diff --git a/applications/rag/frontend/container/application/__init__.py b/applications/rag/frontend/container/application/__init__.py index 64c71bccc..da1a54e23 100644 --- a/applications/rag/frontend/container/application/__init__.py +++ b/applications/rag/frontend/container/application/__init__.py @@ -19,6 +19,6 @@ def create_app(): app = Flask(__name__, static_folder='static', template_folder='templates') app.jinja_env.trim_blocks = True app.jinja_env.lstrip_blocks = True - app.config['SECRET_KEY'] = os.environ.get("SECRET_KEY") + app.config['SECRET_KEY'] = os.environ.get("APPLICATION_SECRET_KEY") return app diff --git a/applications/rag/frontend/container/application/models/__init__.py b/applications/rag/frontend/container/application/models/__init__.py deleted file mode 100644 index 8d1d56862..000000000 --- a/applications/rag/frontend/container/application/models/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .vector_embeddings import VectorEmbeddings - -__all__ = ["VectorEmbeddings"] diff --git a/applications/rag/frontend/container/application/models/vector_embeddings.py b/applications/rag/frontend/container/application/models/vector_embeddings.py deleted file mode 100644 index fd42ea5b2..000000000 --- a/applications/rag/frontend/container/application/models/vector_embeddings.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from sqlalchemy import Column, String, Text -from sqlalchemy.orm import mapped_column, declarative_base -from pgvector.sqlalchemy import Vector - -Base = declarative_base() - -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") - - -class VectorEmbeddings(Base): - __tablename__ = VECTOR_EMBEDDINGS_TABLE_NAME - - id = Column(String(255), primary_key=True) - text = Column(Text) - text_embedding = mapped_column(Vector(384)) diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index 41e4629b8..901477365 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -1,119 +1,119 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import logging - -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import RunnableParallel, RunnableLambda -from langchain_core.runnables.history import RunnableWithMessageHistory - -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_google_cloud_sql_pg import PostgresChatMessageHistory - - -from application.utils import ( - create_sync_postgres_engine -) -from application.rag_langchain.huggingface_inference_model import ( - HuggingFaceCustomChatModel, -) -from application.vector_storages import CloudSQLVectorStore - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - -CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") - -QUESTION = "input" -HISTORY = "chat_history" -CONTEXT = "context" - -SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings - -template_str = """Answer the question given by the user in no more than 2 sentences. -Use the provided context to improve upon your previous answers. Stick to the facts and be brief. Avoid conversational format. -\n\n -Context: {context} -""" - -prompt = ChatPromptTemplate.from_messages( - [ - ("system", template_str), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] -) - -engine = create_sync_postgres_engine() - - -def get_chat_history(session_id: str) -> PostgresChatMessageHistory: - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - - logging.info( - f"Retrieving history for session {session_id} with {len(history.messages)}" - ) - return history - - -def clear_chat_history(session_id: str): - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - history.clear() - - -def create_chain() -> RunnableWithMessageHistory: - try: - model = HuggingFaceCustomChatModel() - - langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CloudSQLVectorStore(langchain_embed, engine) - - retriever = vector_store.as_retriever() - - setup_and_retrieval = RunnableParallel( - { - "context": retriever, - QUESTION: RunnableLambda(lambda d: d[QUESTION]), - HISTORY: RunnableLambda(lambda d: d[HISTORY]), - } - ) - - chain = setup_and_retrieval | prompt | model - chain_with_history = RunnableWithMessageHistory( - chain, - get_chat_history, - input_messages_key=QUESTION, - history_messages_key=HISTORY, - output_messages_key="output", - ) - return chain_with_history - except Exception as e: - logging.info(e) - raise e - -def take_chat_turn( - chain: RunnableWithMessageHistory, session_id: str, query_text: str -) -> str: - try: - config = {"configurable": {"session_id": session_id}} - result = chain.invoke({"input": query_text}, config=config) - return result - except Exception as e: - logging.info(e) - raise e +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import logging + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableParallel, RunnableLambda +from langchain_core.runnables.history import RunnableWithMessageHistory + +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_google_cloud_sql_pg import PostgresChatMessageHistory + + +from application.utils import ( + create_sync_postgres_engine +) +from application.rag_langchain.huggingface_inference_model import ( + HuggingFaceCustomChatModel, +) +from application.vector_storages import CloudSQLVectorStore + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") + +QUESTION = "input" +HISTORY = "chat_history" +CONTEXT = "context" + +SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings + +template_str = """Answer the question given by the user in no more than 2 sentences. +Use the provided context to improve upon your previous answers. Stick to the facts and be brief. Avoid conversational format. +\n\n +Context: {context} +""" + +prompt = ChatPromptTemplate.from_messages( + [ + ("system", template_str), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) + +engine = create_sync_postgres_engine() + + +def get_chat_history(session_id: str) -> PostgresChatMessageHistory: + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + + logging.info( + f"Retrieving history for session {session_id} with {len(history.messages)}" + ) + return history + + +def clear_chat_history(session_id: str): + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + history.clear() + + +def create_chain() -> RunnableWithMessageHistory: + try: + model = HuggingFaceCustomChatModel() + + langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) + vector_store = CloudSQLVectorStore(langchain_embed, engine) + + retriever = vector_store.as_retriever() + + setup_and_retrieval = RunnableParallel( + { + "context": retriever, + QUESTION: RunnableLambda(lambda d: d[QUESTION]), + HISTORY: RunnableLambda(lambda d: d[HISTORY]), + } + ) + + chain = setup_and_retrieval | prompt | model + chain_with_history = RunnableWithMessageHistory( + chain, + get_chat_history, + input_messages_key=QUESTION, + history_messages_key=HISTORY, + output_messages_key="output", + ) + return chain_with_history + except Exception as e: + logging.info(e) + raise e + +def take_chat_turn( + chain: RunnableWithMessageHistory, session_id: str, query_text: str +) -> str: + try: + config = {"configurable": {"session_id": session_id}} + result = chain.invoke({"input": query_text}, config=config) + return result + except Exception as e: + logging.info(e) + raise e diff --git a/applications/rag/frontend/container/application/utils/cloud_sql_utils.py b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py index 5377aaeee..44a2ca061 100644 --- a/applications/rag/frontend/container/application/utils/cloud_sql_utils.py +++ b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py @@ -38,7 +38,7 @@ DB_NAME = os.environ.get("DB_NAME", "pgvector-database") VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") -CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") +CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "chat_history_store") VECTOR_DIMENSION = os.environ.get("VECTOR_DIMENSION", 384) diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py index dab9baedd..6070ca817 100644 --- a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -36,15 +36,6 @@ class CloudSQLVectorStore(VectorStore): - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ): - raise NotImplementedError def __init__(self, embedding_provider, engine): self.vector_store = PostgresVectorStore.create_sync( @@ -57,7 +48,16 @@ def __init__(self, embedding_provider, engine): ) self.embeddings_service = embedding_provider - # TODO implement + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ): + raise NotImplementedError + def add_texts( self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any ) -> List[str]: @@ -67,16 +67,12 @@ def add_texts( self.vector_store.add_documents(splits, ids) except Exception as e: logging.info(f"Error: {e}") - raise e - - # TODO implement similarity search with cosine similarity threshold + raise Exception(f"Error adding texts: {err}") def similarity_search( - self, query: dict, k: int = 4, **kwargs: Any + self, query_input: dict, k: int = 4, **kwargs: Any ) -> List[Document]: try: - - query_input = query.get("input") query_vector = self.embeddings_service.embed_query(query_input) <<<<<<< HEAD docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) @@ -86,4 +82,4 @@ def similarity_search( return docs except Exception as err: - raise Exception(f"General error: {err}") + raise Exception(f"Error on similarity search: {err}") diff --git a/applications/rag/frontend/container/application/vector_storages/custom_vector_storage.py b/applications/rag/frontend/container/application/vector_storages/custom_vector_storage.py deleted file mode 100644 index 01a41b8c7..000000000 --- a/applications/rag/frontend/container/application/vector_storages/custom_vector_storage.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -from typing import (List, Optional, Iterable, Any) - -import pg8000 -import sqlalchemy -from sqlalchemy.engine import Engine - -from langchain_core.vectorstores import VectorStore -from langchain_core.embeddings import Embeddings -from langchain_core.documents import Document -from langchain.text_splitter import CharacterTextSplitter - -VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get('TABLE_NAME', '') -INSTANCE_CONNECTION_NAME = os.environ.get('INSTANCE_CONNECTION_NAME', '') - -class CustomVectorStore(VectorStore): - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ): - raise NotImplementedError - - def __init__(self, embedding: Embeddings, engine: Engine): - self.embedding = embedding - self.engine = engine - self.text_splitter = CharacterTextSplitter( - separator="\n\n", - chunk_size=1024, - chunk_overlap=200, - ) - @property - def embeddings(self) -> Embeddings: - return self.embedding - - - # TODO implement - def add_texts(self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any) -> List[str]: - with self.engine.connect() as conn: - try: - for raw_text in texts: - texts = self.text_splitter.split_text(raw_text) - - embeddings = self.embedding.encode(texts).tolist() - embeddings = embeddings.tobytes() - query_request = "INSERT INTO documents (text, embedding) VALUES (%s, %s)", - conn.execute(sqlalchemy.text(query_request),(texts, embeddings)) - conn.commit() - - except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: - message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except sqlalchemy.exc.DatabaseError as err: - message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except Exception as err: - raise Exception(f"General error: {err}") - - #TODO implement similarity search with cosine similarity threshold - - def similarity_search(self, query: dict, k: int = 4, **kwargs: Any) -> List[Document]: - with self.engine.connect() as conn: - try: - q = query["question"] - # embed query & fetch matches - query_emb = self.embedding.embed_query(q) - emb_str = ",".join(map(str, query_emb)) - query_request = f"""SELECT id, text, 1 - ('[{emb_str}]' <=> text_embedding) AS cosine_similarity - FROM {VECTOR_EMBEDDINGS_TABLE_NAME} - ORDER BY cosine_similarity DESC LIMIT {k};""" - query_results = conn.execute(sqlalchemy.text(query_request)).fetchall() - print(f"GOT {len(query_results)} results") - conn.commit() - - if not query_results: - message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} returned empty result" - raise ValueError(message) - except sqlalchemy.exc.DBAPIError or pg8000.exceptions.DatabaseError as err: - message = f"Table {VECTOR_EMBEDDINGS_TABLE_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except sqlalchemy.exc.DatabaseError as err: - message = f"Database {INSTANCE_CONNECTION_NAME} does not exist: {err}" - raise sqlalchemy.exc.DataError(message) - except Exception as err: - raise Exception(f"General error: {err}") - - #convert query results into List[Document] - texts = [result[1] for result in query_results] - return [Document(page_content=text) for text in texts] \ No newline at end of file diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index f27acb2f1..7888c6bc2 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -23,6 +23,14 @@ locals { }) } +resource "random_string" "application_secret_key" { + length = var.project_id + lower = true + numeric = true + special = false + upper = false +} + # IAP Section: Creates the GKE components module "iap_auth" { count = var.add_auth ? 1 : 0 @@ -161,10 +169,15 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { } env { - name = "TABLE_NAME" + name = "EMBEDDINGS_TABLE_NAME" value = var.dataset_embeddings_table_name } + env { + name = "APPLICATION_SECRET_KEY" + value = random_string.application_secret_key.result + } + resources { limits = { cpu = "3" diff --git a/modules/custom-metrics-stackdriver-adapter/templates/clusterrolebinding_custom-metrics:system:auth-delegator.yaml.tftpl b/modules/custom-metrics-stackdriver-adapter/templates/clusterrolebinding_custom-metrics_system_auth-delegator.yaml.tftpl similarity index 100% rename from modules/custom-metrics-stackdriver-adapter/templates/clusterrolebinding_custom-metrics:system:auth-delegator.yaml.tftpl rename to modules/custom-metrics-stackdriver-adapter/templates/clusterrolebinding_custom-metrics_system_auth-delegator.yaml.tftpl From eb9ab02a15266e4cc1efabb774d57ec7aaf56d92 Mon Sep 17 00:00:00 2001 From: German Grandas Date: Tue, 20 Aug 2024 15:31:34 -0500 Subject: [PATCH 21/46] Updating files based on reviewer comments --- .../application/rag_langchain/rag_chain.py | 23 +++++++++++-------- .../application/vector_storages/cloud_sql.py | 7 ++---- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index 901477365..36cc10859 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -60,15 +60,18 @@ def get_chat_history(session_id: str) -> PostgresChatMessageHistory: - history = PostgresChatMessageHistory.create_sync( - engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME - ) - - logging.info( - f"Retrieving history for session {session_id} with {len(history.messages)}" - ) - return history + try: + history = PostgresChatMessageHistory.create_sync( + engine, session_id=session_id, table_name=CHAT_HISTORY_TABLE_NAME + ) + logging.info( + f"Retrieving history for session {session_id} with {len(history.messages)}" + ) + return history + except Exception as e: + logging.error(e) + return None def clear_chat_history(session_id: str): history = PostgresChatMessageHistory.create_sync( @@ -104,7 +107,7 @@ def create_chain() -> RunnableWithMessageHistory: ) return chain_with_history except Exception as e: - logging.info(e) + logging.error(e) raise e def take_chat_turn( @@ -115,5 +118,5 @@ def take_chat_turn( result = chain.invoke({"input": query_text}, config=config) return result except Exception as e: - logging.info(e) + logging.error(e) raise e diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py index 6070ca817..b37724212 100644 --- a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -70,15 +70,12 @@ def add_texts( raise Exception(f"Error adding texts: {err}") def similarity_search( - self, query_input: dict, k: int = 4, **kwargs: Any + self, query: dict, k: int = 4, **kwargs: Any ) -> List[Document]: try: + query_input = query.get("input") query_vector = self.embeddings_service.embed_query(query_input) -<<<<<<< HEAD docs = self.vector_store.similarity_search_by_vector(query_vector, k=k) -======= - docs = self.vector_store.similarity_search_by_vector(query_vector, k=4) ->>>>>>> 2d652e30 (Rag langchain chat history (#755)) return docs except Exception as err: From dff8d942c0b3a539073ad54f5569041199ef04f9 Mon Sep 17 00:00:00 2001 From: German Grandas Date: Tue, 20 Aug 2024 15:41:39 -0500 Subject: [PATCH 22/46] reverting change on cloudbuild.yaml file --- cloudbuild.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 53b405e8e..0a688ae71 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -255,7 +255,7 @@ steps: sleep 5s cd /workspace/applications/rag/tests - # python3 test_frontend.py "127.0.0.1:8081" + python3 test_frontend.py "127.0.0.1:8081" echo "pass" > /workspace/rag_frontend_result.txt cd /workspace/ From 138920f4e80ed038d42abe93c801bde307594e59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 26 Aug 2024 11:22:01 -0500 Subject: [PATCH 23/46] Reverting comment of line --- cloudbuild.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 0a688ae71..d3c22f362 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -269,7 +269,7 @@ steps: kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- jupyter nbconvert --to script /data/rag-kaggle-ray-sql-interactive.ipynb kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- ipython /data/rag-kaggle-ray-sql-interactive.py - # python3 ./applications/rag/tests/test_rag.py "http://127.0.0.1:8081/prompt" Ignoring while the test approach is reviewed + python3 ./applications/rag/tests/test_rag.py "http://127.0.0.1:8081/prompt" echo "pass" > /workspace/rag_prompt_result.txt allowFailure: true From c8e5d3522869184ce7ec1016c43bb5e7c5d8627b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 26 Aug 2024 12:57:42 -0500 Subject: [PATCH 24/46] Updating length of variable --- applications/rag/frontend/main.tf | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 7888c6bc2..5bc4ad27f 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -24,7 +24,7 @@ locals { } resource "random_string" "application_secret_key" { - length = var.project_id + length = 8 lower = true numeric = true special = false @@ -226,4 +226,4 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { } } } -} \ No newline at end of file +} From 42618181d77e4931f7f85b54b912bc397b94cea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 4 Sep 2024 10:13:14 -0500 Subject: [PATCH 25/46] Updating rag frontend image. --- applications/rag/frontend/main.tf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 5bc4ad27f..d20597952 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -117,7 +117,7 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { spec { service_account_name = var.google_service_account container { - image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" + image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:ec0e7b1ce6d0f9570957dd7fb3dcf0a16259cba915570846b356a17d6e377c59" # Built from local code. Revert before submitting. # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:108bb16ee2278255c80524fce125ef349c494cb5bc4ca77dbde5048b8f9448c1" name = "rag-frontend" From a8258f1090dce282feddb29f41d8172493e32b18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 9 Sep 2024 14:18:51 +0000 Subject: [PATCH 26/46] updating rag frontend images with the latest changes --- applications/rag/frontend/main.tf | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 5bc4ad27f..18f7c759c 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -117,7 +117,8 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { spec { service_account_name = var.google_service_account container { - image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" + # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" + image = "us-docker.pkg.dev/globant-gke-ai-resources/gke-ai-text-to-text/gke-rag-frontend@sha256:e56c59747b1ecc192458a3fdd6c74ad6a2099eeabb61e3fd1eb5fc30a147ba1d" # Built from local code. Revert before submitting. # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:108bb16ee2278255c80524fce125ef349c494cb5bc4ca77dbde5048b8f9448c1" name = "rag-frontend" @@ -133,28 +134,18 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { } env { - name = "PROJECT_ID" + name = "GCP_PROJECT_ID" #value = "projects/${var.project_id}" value = var.project_id } env { - name = "REGION" + name = "CLOUDSQL_INSTANCE_REGION" value = var.region } env { - name = "INSTANCE" - value = var.cloudsql_instance - } - - env { - name = "REGION" - value = var.region - } - - env { - name = "INSTANCE" + name = "CLOUDSQL_INSTANCE" value = var.cloudsql_instance } From 4f025462e7e258d23edec5df469d8894a38f0205 Mon Sep 17 00:00:00 2001 From: German Grandas Date: Mon, 9 Sep 2024 10:55:00 -0500 Subject: [PATCH 27/46] Fixing issue with database connection --- .../rag/frontend/container/Dockerfile | 2 +- .../application/utils/cloud_sql_utils.py | 19 ++- .../application/vector_storages/cloud_sql.py | 5 +- .../container/rag_langchain/rag_chain.py | 136 ------------------ 4 files changed, 18 insertions(+), 144 deletions(-) delete mode 100644 applications/rag/frontend/container/rag_langchain/rag_chain.py diff --git a/applications/rag/frontend/container/Dockerfile b/applications/rag/frontend/container/Dockerfile index a9cd73be4..d7aa7b886 100644 --- a/applications/rag/frontend/container/Dockerfile +++ b/applications/rag/frontend/container/Dockerfile @@ -22,6 +22,6 @@ RUN pip install -r requirements.txt EXPOSE 8080 ENV FLASK_APP=/workspace/frontend/main.py -ENV PYTHONPATH=. +ENV PYTHONPATH=/workspace/frontend/ # Run the application with Gunicorn CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:8080", "main:app"] diff --git a/applications/rag/frontend/container/application/utils/cloud_sql_utils.py b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py index 44a2ca061..009bbcb94 100644 --- a/applications/rag/frontend/container/application/utils/cloud_sql_utils.py +++ b/applications/rag/frontend/container/application/utils/cloud_sql_utils.py @@ -22,13 +22,18 @@ import os import logging +import google.cloud.logging as gcloud_logging + from google.cloud.sql.connector import IPTypes -from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore +from langchain_google_cloud_sql_pg import PostgresEngine logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) +gcloud_logging_client = gcloud_logging.Client() +gcloud_logging_client.setup_logging() + ENVIRONMENT = os.environ.get("ENVIRONMENT") @@ -38,7 +43,7 @@ DB_NAME = os.environ.get("DB_NAME", "pgvector-database") VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get("EMBEDDINGS_TABLE_NAME", "") -CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "chat_history_store") +CHAT_HISTORY_TABLE_NAME = os.environ.get("CHAT_HISTORY_TABLE_NAME", "message_store") VECTOR_DIMENSION = os.environ.get("VECTOR_DIMENSION", 384) @@ -66,13 +71,17 @@ def create_sync_postgres_engine(): ip_type=IPTypes.PUBLIC if ENVIRONMENT == "development" else IPTypes.PRIVATE, ) try: - engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) engine.init_vectorstore_table( VECTOR_EMBEDDINGS_TABLE_NAME, vector_size=VECTOR_DIMENSION, overwrite_existing=False, ) - except Exception as e: - logging.info(f"Error: {e}") + except Exception as err: + logging.error(f"Error: {err}") + + try: + engine.init_chat_history_table(table_name=CHAT_HISTORY_TABLE_NAME) + except Exception as err: + logging.error(f"Error: {err}") return engine diff --git a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py index b37724212..603b2ec46 100644 --- a/applications/rag/frontend/container/application/vector_storages/cloud_sql.py +++ b/applications/rag/frontend/container/application/vector_storages/cloud_sql.py @@ -65,8 +65,8 @@ def add_texts( splits = self.splitter.split_documents(texts) ids = [str(uuid.uuid4()) for _ in range(len(splits))] self.vector_store.add_documents(splits, ids) - except Exception as e: - logging.info(f"Error: {e}") + except Exception as err: + logging.error(f"Error: {err}") raise Exception(f"Error adding texts: {err}") def similarity_search( @@ -79,4 +79,5 @@ def similarity_search( return docs except Exception as err: + logging.error(f"Something happened: {err}") raise Exception(f"Error on similarity search: {err}") diff --git a/applications/rag/frontend/container/rag_langchain/rag_chain.py b/applications/rag/frontend/container/rag_langchain/rag_chain.py deleted file mode 100644 index 807c875c8..000000000 --- a/applications/rag/frontend/container/rag_langchain/rag_chain.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import (Dict) -from cloud_sql.cloud_sql import CHAT_HISTORY_TABLE_NAME, init_connection_pool, create_sync_postgres_engine, CustomVectorStore -from google.cloud.sql.connector import Connector -from langchain_community.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_core.prompts import PromptTemplate -from langchain_core.runnables import RunnableParallel, RunnableLambda -from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_google_cloud_sql_pg import PostgresChatMessageHistory - -QUESTION = "question" -HISTORY = "history" -CONTEXT = "context" - -INFERENCE_ENDPOINT=os.environ.get('INFERENCE_ENDPOINT', '127.0.0.1:8081') -SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings - - -# TODO use a chat model instead of an LLM in the chain. Convert the prompt to a chat prompt template -# prompt = ChatPromptTemplate.from_messages( -# [ -# ("system", """You help everyone by answering questions, and improve your answers from previous answers in history. -# You stick to the facts by basing your answers off of the context provided:"""), -# MessagesPlaceholder(variable_name="history"), -# MessagesPlaceholder(variable_name="context"), -# ("human", "{question}"), -# ] -# ) -template = """Answer the Question given by the user. Keep the answer to no more than 2 sentences. -Improve upon your previous answers using History, a list of messages. -Messages of type HumanMessage were asked by the user, and messages of type AIMessage were your previous responses. -Stick to the facts by basing your answers off of the Context provided. -Be brief in answering. -History: {""" + HISTORY + "}\n\nContext: {" + CONTEXT + "}\n\nQuestion: {" + QUESTION + "}\n" - -prompt = PromptTemplate(template=template, input_variables=[HISTORY, CONTEXT, QUESTION]) - -engine = create_sync_postgres_engine() -# TODO: Dict is not safe for multiprocessing. Introduce a cache using Flask-caching or libcache -# The in-memory SimpleCache implementations for each of these libraries is not safe either. -# Consider redis or memcached (e.g., Memorystore) -# chat_history_map: Dict[str, PostgresChatMessageHistory] = {} - -def get_chat_history(session_id: str) -> PostgresChatMessageHistory: - history = PostgresChatMessageHistory.create_sync( - engine, - session_id=session_id, - table_name = CHAT_HISTORY_TABLE_NAME - ) - - print(f"Retrieving history for session {session_id} with {len(history.messages)}") - return history - -def clear_chat_history(session_id: str): - history = PostgresChatMessageHistory.create_sync( - engine, - session_id=session_id, - table_name = CHAT_HISTORY_TABLE_NAME - ) - history.clear() - - -#TODO: limit number of tokens in prompt to MAX_INPUT_LENGTH -# (as specified in hugging face TGI input parameter) - -def create_chain() -> RunnableWithMessageHistory: - # TODO HuggingFaceTextGenInference class is deprecated. - # The warning is: - # The class `langchain_community.llms.huggingface_text_gen_inference.HuggingFaceTextGenInference` - # was deprecated in langchain-community 0.0.21 and will be removed in 0.2.0. Use HuggingFaceEndpoint instead - # The replacement is HuggingFace Endoint, which requires a huggingface - # hub API token. Either need to add the token to the environment, or need to find a method to call TGI - # without the token. - # Example usage of HuggingFaceEndpoint: - # llm = HuggingFaceEndpoint( - # endpoint_url=f'http://{INFERENCE_ENDPOINT}/', - # max_new_tokens=512, - # top_k=10, - # top_p=0.95, - # typical_p=0.95, - # temperature=0.01, - # repetition_penalty=1.03, - # huggingfacehub_api_token="my-api-key" - # ) - # TODO: Give guidance on what these parameters should be and describe why these values were chosen. - model = HuggingFaceTextGenInference( - inference_server_url=f'http://{INFERENCE_ENDPOINT}/', - max_new_tokens=512, - top_k=10, - top_p=0.95, - typical_p=0.95, - temperature=0.01, - repetition_penalty=1.03, - ) - - langchain_embed = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL) - vector_store = CustomVectorStore(langchain_embed, init_connection_pool(Connector())) - retriever = vector_store.as_retriever() - - setup_and_retrieval = RunnableParallel( - { - "context": retriever, - QUESTION: RunnableLambda(lambda d: d[QUESTION]), - HISTORY: RunnableLambda(lambda d: d[HISTORY]) - } - ) - chain = setup_and_retrieval | prompt | model - chain_with_history = RunnableWithMessageHistory( - chain, - get_chat_history, - input_messages_key=QUESTION, - history_messages_key=HISTORY, - output_messages_key="output" - ) - return chain_with_history - -def take_chat_turn(chain: RunnableWithMessageHistory, session_id: str, query_text: str) -> str: - #TODO limit the number of history messages - config = {"configurable": {"session_id": session_id}} - result = chain.invoke({"question": query_text}, config) - return str(result) \ No newline at end of file From 88ee300c37e38e5225e3e7a6d3ed3aedc1ab96f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 9 Sep 2024 16:33:33 +0000 Subject: [PATCH 28/46] Updating Rag application test. --- applications/rag/tests/test_rag.py | 63 ++++++------------------------ 1 file changed, 12 insertions(+), 51 deletions(-) diff --git a/applications/rag/tests/test_rag.py b/applications/rag/tests/test_rag.py index d7da3a0e2..3f962f653 100644 --- a/applications/rag/tests/test_rag.py +++ b/applications/rag/tests/test_rag.py @@ -6,40 +6,26 @@ def test_prompts(prompt_url): testcases = [ { "prompt": "List the cast of Squid Game", - "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..", - "expected_substrings": ["Lee Jung-jae", "Park Hae-soo", "Wi Ha-jun", "Oh Young-soo", "Jung Ho-yeon", "Heo Sung-tae", "Kim Joo-ryoung", "Tripathi Anupam", "You Seong-joo", "Lee You-mi"], }, { "prompt": "When was Squid Game released?", - "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..", - "expected_substrings": ["September 17, 2021"], }, { "prompt": "What is the rating of Squid Game?", - "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..", - "expected_substrings": ["TV-MA"], }, { "prompt": "List the cast of Avatar: The Last Airbender", - "expected_context": "This is a TV Show in United States called Avatar: The Last Airbender added at May 15, 2020 whose director is and with cast: Zach Tyler, Mae Whitman, Jack De Sena, Dee Bradley Baker, Dante Basco, Jessie Flower, Mako Iwamatsu released at 2007. Its rating is: TV-Y7. Its duration is 3 Seasons. Its description is Siblings Katara and Sokka wake young Aang from a long hibernation and learn he's an Avatar, whose air-bending powers can defeat the evil Fire Nation..", - "expected_substrings": ["Zach Tyler", "Mae Whitman", "Jack De Sena", "Dee Bradley Baker", "Dante Basco", "Jessie Flower", "Mako Iwamatsu"], }, { "prompt": "When was Avatar: The Last Airbender added on Netflix?", - "expected_context": "This is a TV Show in United States called Avatar: The Last Airbender added at May 15, 2020 whose director is and with cast: Zach Tyler, Mae Whitman, Jack De Sena, Dee Bradley Baker, Dante Basco, Jessie Flower, Mako Iwamatsu released at 2007. Its rating is: TV-Y7. Its duration is 3 Seasons. Its description is Siblings Katara and Sokka wake young Aang from a long hibernation and learn he's an Avatar, whose air-bending powers can defeat the evil Fire Nation..", - "expected_substrings": ["May 15, 2020"], }, { "prompt": "What is the rating of Avatar: The Last Airbender?", - "expected_context": "This is a TV Show in United States called Avatar: The Last Airbender added at May 15, 2020 whose director is and with cast: Zach Tyler, Mae Whitman, Jack De Sena, Dee Bradley Baker, Dante Basco, Jessie Flower, Mako Iwamatsu released at 2007. Its rating is: TV-Y7. Its duration is 3 Seasons. Its description is Siblings Katara and Sokka wake young Aang from a long hibernation and learn he's an Avatar, whose air-bending powers can defeat the evil Fire Nation..", - "expected_substrings": ["TV-Y7"], }, ] for testcase in testcases: prompt = testcase["prompt"] - expected_context = testcase["expected_context"] - expected_substrings = testcase["expected_substrings"] print(f"Testing prompt: {prompt}") data = {"prompt": prompt} @@ -50,51 +36,37 @@ def test_prompts(prompt_url): response.raise_for_status() response = response.json() - context = response['response']['context'] - text = response['response']['text'] - user_prompt = response['response']['user_prompt'] + print(response) + text = response['response'].get('text') print(f"Reply: {text}") - assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" - assert context == expected_context, f"unexpected context: {context} != {expected_context}" - - for substring in expected_substrings: - assert substring in text, f"substring {substring} not in response:\n {text}" + assert response != None, f"Not response found: {response}" + assert text != None, f"Not text" def test_prompts_nlp(prompt_url): testcases = [ { "prompt": "List the cast of Squid Game", "nlpFilterLevel": "0", - "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..", - "expected_substrings": ["Lee Jung-jae", "Park Hae-soo", "Wi Ha-jun", "Oh Young-soo", "Jung Ho-yeon", "Heo Sung-tae", "Kim Joo-ryoung", "Tripathi Anupam", "You Seong-joo", "Lee You-mi"], }, { "prompt": "Which movie has a building that blows up?", "nlpFilterLevel": "0", - "expected_context": "This is a Movie in United States called Oklahoma City added at April 19, 2017 whose director is Barak Goodman and with cast: released at 2017. Its rating is: TV-14. Its duration is 102 min. Its description is Domestic terrorist Timothy McVeigh's bomb attack on an Oklahoma City federal building in 1995 is chronicled in this examination of the shocking event..", - "expected_substrings": ["The movie \"Oklahoma City\" released in 2017, directed by Barak Goodman, and starring a cast, depicts the bombing of an Oklahoma City federal building in 1995. The building that blows up in the movie is the Alfred P. Murrah Federal Building."], }, { "prompt": "Which movie has a building that blows up?", "nlpFilterLevel": "50", - "expected_context": "This is a Movie in United States called Oklahoma City added at April 19, 2017 whose director is Barak Goodman and with cast: released at 2017. Its rating is: TV-14. Its duration is 102 min. Its description is Domestic terrorist Timothy McVeigh's bomb attack on an Oklahoma City federal building in 1995 is chronicled in this examination of the shocking event..", - "expected_substrings": ["The response is deemed inappropriate for display."], }, { "prompt": "List the cast of Squid Game", "nlpFilterLevel": "100", - "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..", - "expected_substrings": ["The response is deemed inappropriate for display."], } ] for testcase in testcases: prompt = testcase["prompt"] nlpFilterLevel = testcase["nlpFilterLevel"] - expected_context = testcase["expected_context"] - expected_substrings = testcase["expected_substrings"] print(f"Testing prompt: {prompt}") data = {"prompt": prompt, "nlpFilterLevel": nlpFilterLevel} @@ -105,17 +77,14 @@ def test_prompts_nlp(prompt_url): response.raise_for_status() response = response.json() - context = response['response']['context'] + text = response['response']['text'] - user_prompt = response['response']['user_prompt'] - print(f"Reply: {text}") - assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" - assert context == expected_context, f"unexpected context: {context} != {expected_context}" + print(f"Reply: {text}") - for substring in expected_substrings: - assert substring in text, f"substring {substring} not in response:\n {text}" + assert response != None, f"Not response found: {response}" + assert text != None, f"Not text" def test_prompts_dlp(prompt_url): testcases = [ @@ -123,8 +92,6 @@ def test_prompts_dlp(prompt_url): "prompt": "who worked with Robert De Niro and name one film they collaborated?", "inspectTemplate": "projects/gke-ai-eco-dev/locations/global/inspectTemplates/DO-NOT-DELETE-e2e-test-inspect-template", "deidentifyTemplate": "projects/gke-ai-eco-dev/locations/global/deidentifyTemplates/DO-NOT-DELETE-e2e-test-de-identify-template", - "expected_context": "This is a Movie in United States called GoodFellas added at January 1, 2021 whose director is Martin Scorsese and with cast: Robert De Niro, Ray Liotta, Joe Pesci, Lorraine Bracco, Paul Sorvino, Frank Sivero, Tony Darrow, Mike Starr, Frank Vincent, Chuck Low released at 1990. Its rating is: R. Its duration is 145 min. Its description is Former mobster Henry Hill recounts his colorful yet violent rise and fall in a New York crime family – a high-rolling dream turned paranoid nightmare..", - "expected_substrings": ["[PERSON_NAME] has worked with many talented actors and directors throughout his career. One film he collaborated with [PERSON_NAME] is \"GoodFellas,\" which was released in 1990. In this movie, [PERSON_NAME] played the role of [PERSON_NAME], a former mobster who recounts his rise and fall in a New York crime family."], }, ] @@ -132,8 +99,6 @@ def test_prompts_dlp(prompt_url): prompt = testcase["prompt"] inspectTemplate = testcase["inspectTemplate"] deidentifyTemplate = testcase["deidentifyTemplate"] - expected_context = testcase["expected_context"] - expected_substrings = testcase["expected_substrings"] print(f"Testing prompt: {prompt}") data = {"prompt": prompt, "inspectTemplate": inspectTemplate, "deidentifyTemplate": deidentifyTemplate} @@ -144,19 +109,15 @@ def test_prompts_dlp(prompt_url): response.raise_for_status() response = response.json() - context = response['response']['context'] text = response['response']['text'] - user_prompt = response['response']['user_prompt'] - print(f"Reply: {text}") - assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}" - assert context == expected_context, f"unexpected context: {context} != {expected_context}" + print(f"Reply: {text}") - for substring in expected_substrings: - assert substring in text, f"substring {substring} not in response:\n {text}" + assert response != None, f"Not response found: {response}" + assert text != None, f"Not text" prompt_url = sys.argv[1] test_prompts(prompt_url) test_prompts_nlp(prompt_url) -test_prompts_dlp(prompt_url) +test_prompts_dlp(prompt_url) \ No newline at end of file From cf0a447d941233a9e5dc5a6a1efdf93103f97a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 9 Sep 2024 17:53:23 +0000 Subject: [PATCH 29/46] Adding exceptions to test --- applications/rag/tests/test_rag.py | 241 +++++++++++++++-------------- 1 file changed, 127 insertions(+), 114 deletions(-) diff --git a/applications/rag/tests/test_rag.py b/applications/rag/tests/test_rag.py index 3f962f653..83d13087d 100644 --- a/applications/rag/tests/test_rag.py +++ b/applications/rag/tests/test_rag.py @@ -3,121 +3,134 @@ import requests def test_prompts(prompt_url): - testcases = [ - { - "prompt": "List the cast of Squid Game", - }, - { - "prompt": "When was Squid Game released?", - }, - { - "prompt": "What is the rating of Squid Game?", - }, - { - "prompt": "List the cast of Avatar: The Last Airbender", - }, - { - "prompt": "When was Avatar: The Last Airbender added on Netflix?", - }, - { - "prompt": "What is the rating of Avatar: The Last Airbender?", - }, - ] - - for testcase in testcases: - prompt = testcase["prompt"] - - print(f"Testing prompt: {prompt}") - data = {"prompt": prompt} - json_payload = json.dumps(data) - - headers = {'Content-Type': 'application/json'} - response = requests.post(prompt_url, data=json_payload, headers=headers) - response.raise_for_status() - - response = response.json() - print(response) - text = response['response'].get('text') - - print(f"Reply: {text}") - - assert response != None, f"Not response found: {response}" - assert text != None, f"Not text" + try: + testcases = [ + { + "prompt": "List the cast of Squid Game", + }, + { + "prompt": "When was Squid Game released?", + }, + { + "prompt": "What is the rating of Squid Game?", + }, + { + "prompt": "List the cast of Avatar: The Last Airbender", + }, + { + "prompt": "When was Avatar: The Last Airbender added on Netflix?", + }, + { + "prompt": "What is the rating of Avatar: The Last Airbender?", + }, + ] + + for testcase in testcases: + prompt = testcase["prompt"] + + print(f"Testing prompt: {prompt}") + data = {"prompt": prompt} + json_payload = json.dumps(data) + + headers = {'Content-Type': 'application/json'} + response = requests.post(prompt_url, data=json_payload, headers=headers) + response.raise_for_status() + + response = response.json() + print(response) + text = response['response'].get('text') + + print(f"Reply: {text}") + + assert response != None, f"Not response found: {response}" + assert text != None, f"Not text" + except Exception as err: + print(err) + raise err def test_prompts_nlp(prompt_url): - testcases = [ - { - "prompt": "List the cast of Squid Game", - "nlpFilterLevel": "0", - }, - { - "prompt": "Which movie has a building that blows up?", - "nlpFilterLevel": "0", - }, - { - "prompt": "Which movie has a building that blows up?", - "nlpFilterLevel": "50", - }, - { - "prompt": "List the cast of Squid Game", - "nlpFilterLevel": "100", - } - ] - - for testcase in testcases: - prompt = testcase["prompt"] - nlpFilterLevel = testcase["nlpFilterLevel"] - - print(f"Testing prompt: {prompt}") - data = {"prompt": prompt, "nlpFilterLevel": nlpFilterLevel} - json_payload = json.dumps(data) - - headers = {'Content-Type': 'application/json'} - response = requests.post(prompt_url, data=json_payload, headers=headers) - response.raise_for_status() - - response = response.json() - - text = response['response']['text'] - - - print(f"Reply: {text}") - - assert response != None, f"Not response found: {response}" - assert text != None, f"Not text" + try: + testcases = [ + { + "prompt": "List the cast of Squid Game", + "nlpFilterLevel": "0", + }, + { + "prompt": "Which movie has a building that blows up?", + "nlpFilterLevel": "0", + }, + { + "prompt": "Which movie has a building that blows up?", + "nlpFilterLevel": "50", + }, + { + "prompt": "List the cast of Squid Game", + "nlpFilterLevel": "100", + } + ] + + for testcase in testcases: + prompt = testcase["prompt"] + nlpFilterLevel = testcase["nlpFilterLevel"] + + print(f"Testing prompt: {prompt}") + data = {"prompt": prompt, "nlpFilterLevel": nlpFilterLevel} + json_payload = json.dumps(data) + + headers = {'Content-Type': 'application/json'} + response = requests.post(prompt_url, data=json_payload, headers=headers) + response.raise_for_status() + + response = response.json() + + text = response['response']['text'] + + + print(f"Reply: {text}") + + assert response != None, f"Not response found: {response}" + assert text != None, f"Not text" + except Exception as err: + print(err) + raise err def test_prompts_dlp(prompt_url): - testcases = [ - { - "prompt": "who worked with Robert De Niro and name one film they collaborated?", - "inspectTemplate": "projects/gke-ai-eco-dev/locations/global/inspectTemplates/DO-NOT-DELETE-e2e-test-inspect-template", - "deidentifyTemplate": "projects/gke-ai-eco-dev/locations/global/deidentifyTemplates/DO-NOT-DELETE-e2e-test-de-identify-template", - }, - ] - - for testcase in testcases: - prompt = testcase["prompt"] - inspectTemplate = testcase["inspectTemplate"] - deidentifyTemplate = testcase["deidentifyTemplate"] - - print(f"Testing prompt: {prompt}") - data = {"prompt": prompt, "inspectTemplate": inspectTemplate, "deidentifyTemplate": deidentifyTemplate} - json_payload = json.dumps(data) - - headers = {'Content-Type': 'application/json'} - response = requests.post(prompt_url, data=json_payload, headers=headers) - response.raise_for_status() - - response = response.json() - text = response['response']['text'] - - - print(f"Reply: {text}") - - assert response != None, f"Not response found: {response}" - assert text != None, f"Not text" - -prompt_url = sys.argv[1] -test_prompts(prompt_url) -test_prompts_nlp(prompt_url) -test_prompts_dlp(prompt_url) \ No newline at end of file + try: + testcases = [ + { + "prompt": "who worked with Robert De Niro and name one film they collaborated?", + "inspectTemplate": "projects/gke-ai-eco-dev/locations/global/inspectTemplates/DO-NOT-DELETE-e2e-test-inspect-template", + "deidentifyTemplate": "projects/gke-ai-eco-dev/locations/global/deidentifyTemplates/DO-NOT-DELETE-e2e-test-de-identify-template", + }, + ] + + for testcase in testcases: + prompt = testcase["prompt"] + inspectTemplate = testcase["inspectTemplate"] + deidentifyTemplate = testcase["deidentifyTemplate"] + + print(f"Testing prompt: {prompt}") + data = {"prompt": prompt, "inspectTemplate": inspectTemplate, "deidentifyTemplate": deidentifyTemplate} + json_payload = json.dumps(data) + + headers = {'Content-Type': 'application/json'} + response = requests.post(prompt_url, data=json_payload, headers=headers) + response.raise_for_status() + + response = response.json() + text = response['response']['text'] + + + print(f"Reply: {text}") + + assert response != None, f"Not response found: {response}" + assert text != None, f"Not text" + except Exception as err: + print(err) + raise err + +if __name__ = "__main__": + prompt_url = sys.argv[1] + test_prompts(prompt_url) + test_prompts_nlp(prompt_url) + test_prompts_dlp(prompt_url) \ No newline at end of file From bf2f9900992d795856c462df729921c1119c06c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 9 Sep 2024 18:53:22 +0000 Subject: [PATCH 30/46] Fixing bug on unit test --- applications/rag/tests/test_rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/rag/tests/test_rag.py b/applications/rag/tests/test_rag.py index 83d13087d..b453fc968 100644 --- a/applications/rag/tests/test_rag.py +++ b/applications/rag/tests/test_rag.py @@ -44,7 +44,7 @@ def test_prompts(prompt_url): assert response != None, f"Not response found: {response}" assert text != None, f"Not text" - except Exception as err: + except Exception as err: print(err) raise err From 74b6e9dd99f8b670db57c9e5eecd4bb98bd69e44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Mon, 9 Sep 2024 19:49:15 +0000 Subject: [PATCH 31/46] fixing unit test --- applications/rag/tests/test_rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/rag/tests/test_rag.py b/applications/rag/tests/test_rag.py index b453fc968..b7418fbdb 100644 --- a/applications/rag/tests/test_rag.py +++ b/applications/rag/tests/test_rag.py @@ -129,7 +129,7 @@ def test_prompts_dlp(prompt_url): print(err) raise err -if __name__ = "__main__": +if __name__ == "__main__": prompt_url = sys.argv[1] test_prompts(prompt_url) test_prompts_nlp(prompt_url) From e94cab0466c72a428a9c3b332f2b46b727ab8581 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Tue, 10 Sep 2024 19:50:48 +0000 Subject: [PATCH 32/46] updating notebook to use the PostgresVectorStore instead of the custom vector store --- .../rag-kaggle-ray-sql-interactive.ipynb | 197 +++++++++--------- 1 file changed, 93 insertions(+), 104 deletions(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index 2b80e437e..8a2334516 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -32,6 +32,16 @@ "!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "c421c932", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install langchain-google-cloud-sql-pg" + ] + }, { "cell_type": "markdown", "id": "c7ff518d-f4d2-481b-b408-2c2507565611", @@ -52,50 +62,58 @@ "import os\n", "import uuid\n", "import ray\n", - "from langchain.document_loaders import ArxivLoader\n", - "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", - "from sentence_transformers import SentenceTransformer\n", + "\n", "from typing import List\n", "import torch\n", "from datasets import load_dataset_builder, load_dataset, Dataset\n", "from huggingface_hub import snapshot_download\n", - "from google.cloud.sql.connector import Connector, IPTypes\n", - "import sqlalchemy\n", - "\n", - "# initialize parameters\n", - "\n", - "INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n", - "print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n", - "DB_NAME = \"pgvector-database\"\n", - "\n", - "db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", - "DB_USER = db_username_file.read()\n", - "db_username_file.close()\n", - "\n", - "db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", - "DB_PASS = db_password_file.read()\n", - "db_password_file.close()\n", + "from sentence_transformers import SentenceTransformer\n", + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "\n", - "# initialize Connector object\n", - "connector = Connector()\n", + "from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n", + "from google.cloud.sql.connector import IPTypes\n", "\n", - "# function to return the database connection object\n", - "def getconn():\n", - " conn = connector.connect(\n", - " INSTANCE_CONNECTION_NAME,\n", - " \"pg8000\",\n", + "# initialize parameters\n", + "GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\")\n", + "GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\")\n", + "GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\")\n", + "\n", + "DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n", + "VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n", + "CHAT_HISTORY_TABLE_NAME = os.environ.get(\"CHAT_HISTORY_TABLE_NAME\", \"message_store\")\n", + "\n", + "VECTOR_DIMENSION = os.environ.get(\"VECTOR_DIMENSION\", 384)\n", + "\n", + "try:\n", + " db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", + " DB_USER = db_username_file.read()\n", + " db_username_file.close()\n", + "\n", + " db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", + " DB_PASS = db_password_file.read()\n", + " db_password_file.close()\n", + "except:\n", + " DB_USER = os.environ.get(\"DB_USERNAME\", \"postgres\")\n", + " DB_PASS = os.environ.get(\"DB_PASS\", \"postgres\")\n", + "\n", + "engine = PostgresEngine.from_instance(\n", + " project_id=GCP_PROJECT_ID,\n", + " region=GCP_CLOUD_SQL_REGION,\n", + " instance=GCP_CLOUD_SQL_INSTANCE,\n", + " database=DB_NAME,\n", " user=DB_USER,\n", " password=DB_PASS,\n", - " db=DB_NAME,\n", - " ip_type=IPTypes.PRIVATE\n", + " ip_type=IPTypes.PRIVATE,\n", + ")\n", + "\n", + "try:\n", + " engine.init_vectorstore_table(\n", + " VECTOR_EMBEDDINGS_TABLE_NAME,\n", + " vector_size=VECTOR_DIMENSION,\n", + " overwrite_existing=True,\n", " )\n", - " return conn\n", - "\n", - "# create connection pool with 'creator' argument to our connection object function\n", - "pool = sqlalchemy.create_engine(\n", - " \"postgresql+pg8000://\",\n", - " creator=getconn,\n", - ")" + "except Exception as err:\n", + " print(f\"Error: {err}\")" ] }, { @@ -158,9 +176,10 @@ "id": "f7304035-21a4-4017-bce9-aba7e9f81c90", "metadata": {}, "source": [ - "## Generating Vector Embeddings\n", + "## Generating Documents splits\n", "\n", - "We are ready to begin. Let's first create some code for generating the vector embeddings:" + "We are ready to begin. Let's first create some code for generating the dataset splits:\n", + "\n" ] }, { @@ -170,16 +189,8 @@ "metadata": {}, "outputs": [], "source": [ - "class Embed:\n", + "class Splitter:\n", " def __init__(self):\n", - " print(\"torch cuda version\", torch.version.cuda)\n", - " device=\"cpu\"\n", - " if torch.cuda.is_available():\n", - " print(\"device cuda found\")\n", - " device=\"cuda\"\n", - "\n", - " print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n", - " self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n", " self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n", "\n", " def __call__(self, text_batch: List[str]):\n", @@ -191,12 +202,7 @@ " # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n", " chunks.extend(splits)\n", "\n", - " embeddings = self.transformer.encode(\n", - " chunks,\n", - " batch_size=BATCH_SIZE\n", - " ).tolist()\n", - " print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n", - " return {'results':list(zip(chunks, embeddings))}" + " return {'results':chunks}" ] }, { @@ -227,6 +233,7 @@ " \"datasets==2.18.0\",\n", " \"torch==2.0.1\",\n", " \"huggingface_hub==0.21.3\",\n", + " \"langchain-google-cloud-sql-pg\"\n", " ]\n", " }\n", ")" @@ -262,8 +269,8 @@ "print(ds_batch.schema)\n", "\n", "# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n", - "ds_embed = ds_batch.map_batches(\n", - " Embed,\n", + "ds_splitted = ds_batch.map_batches(\n", + " Splitter,\n", " compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n", " batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n", " num_gpus=1, # 1 GPU for each actor.\n", @@ -287,17 +294,17 @@ "outputs": [], "source": [ "@ray.remote\n", - "def ray_data_task(ds_embed):\n", + "def ray_data_task(ds_splitted):\n", " results = []\n", - " for row in ds_embed.iter_rows():\n", - " data_text = row[\"results\"][0][:65535]\n", - " data_emb = row[\"results\"][1]\n", + " for row in ds_splitted.iter_rows():\n", + " data_text = row[\"results\"]\n", + " data_id = str(uuid.uuid4()) \n", "\n", - " results.append((data_text, data_emb))\n", + " results.append((data_id, data_text))\n", " \n", " return results\n", " \n", - "results = ray.get(ray_data_task.remote(ds_embed))" + "results = ray.get(ray_data_task.remote(ds_splitted))" ] }, { @@ -317,36 +324,25 @@ "metadata": {}, "outputs": [], "source": [ - "from sqlalchemy.ext.declarative import declarative_base\n", - "from sqlalchemy import Column, String, Text, text\n", - "from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column\n", - "from pgvector.sqlalchemy import Vector\n", - "\n", - "\n", - "Base = declarative_base()\n", - "DBSession = scoped_session(sessionmaker())\n", - "\n", - "class TextEmbedding(Base):\n", - " __tablename__ = TABLE_NAME\n", - " id = Column(String(255), primary_key=True)\n", - " text = Column(Text)\n", - " text_embedding = mapped_column(Vector(384))\n", - "\n", - "with pool.connect() as conn:\n", - " conn.execute(text(\"CREATE EXTENSION IF NOT EXISTS vector\"))\n", - " conn.commit() \n", + "print(\"torch cuda version\", torch.version.cuda)\n", + "device=\"cpu\"\n", + "if torch.cuda.is_available():\n", + " print(\"device cuda found\")\n", + " device=\"cuda\"\n", " \n", - "DBSession.configure(bind=pool, autoflush=False, expire_on_commit=False)\n", - "Base.metadata.drop_all(pool)\n", - "Base.metadata.create_all(pool)\n", - "\n", - "rows = []\n", - "for r in results:\n", - " id = uuid.uuid4() \n", - " rows.append(TextEmbedding(id=id, text=r[0], text_embedding=r[1]))\n", - "\n", - "DBSession.bulk_save_objects(rows)\n", - "DBSession.commit()" + "print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n", + "\n", + "embeddings_service = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n", + "vector_store = PostgresVectorStore.create_sync(\n", + " engine=engine,\n", + " embedding_service=embeddings_service,\n", + " table_name=VECTOR_EMBEDDINGS_TABLE_NAME,\n", + ")\n", + "\n", + "for result in results:\n", + " id = result[0]\n", + " splits = result[1]\n", + " vector_store.add_texts(splits, id)" ] }, { @@ -364,21 +360,14 @@ "metadata": {}, "outputs": [], "source": [ - "with pool.connect() as db_conn:\n", - " # verify results\n", - " transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n", - " query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n", - " query_emb = transformer.encode(query_text).tolist()\n", - " query_request = \"SELECT id, text, text_embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", - " query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n", - " db_conn.commit()\n", - " \n", - " print(\"print query_results, the 1st one is the hit\")\n", - " for row in query_results:\n", - " print(row)\n", - "\n", - "# cleanup connector object\n", - "connector.close()" + "query = \"List the cast of squid game\"\n", + "query_vector = embeddings_service.embed_query(query)\n", + "docs = vector_store.similarity_search_by_vector(query_vector, k=4)\n", + "\n", + "for i, document in enumerate(docs):\n", + " print(f\"Result #{i+1}\")\n", + " print(document.page_content)\n", + " print(\"-\" * 100)" ] } ], From 329417b4ae4b38e1af0069d91309b8fbcbc3f85f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Tue, 10 Sep 2024 19:59:55 +0000 Subject: [PATCH 33/46] fixing issue with notebook --- .../rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index 8a2334516..75e974904 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -297,7 +297,7 @@ "def ray_data_task(ds_splitted):\n", " results = []\n", " for row in ds_splitted.iter_rows():\n", - " data_text = row[\"results\"]\n", + " data_text = row[\"results\"].page_content\n", " data_id = str(uuid.uuid4()) \n", "\n", " results.append((data_id, data_text))\n", From f1bf05ac66956e4ed190e6266e582cd7c11d4d21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 11 Sep 2024 13:30:30 +0000 Subject: [PATCH 34/46] Fixing issue with missing environment varibles on notebook --- .../rag-kaggle-ray-sql-interactive.ipynb | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index 75e974904..2105487c4 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -74,9 +74,13 @@ "from google.cloud.sql.connector import IPTypes\n", "\n", "# initialize parameters\n", - "GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\")\n", - "GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\")\n", - "GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\")\n", + "INSTANCE_CONNECTION_NAME = os.environ.get(\"CLOUDSQL_INSTANCE_CONNECTION_NAME\")\n", + "print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n", + "cloud_variables = INSTANCE_CONNECTION_NAME.split(\":\")\n", + "\n", + "GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\", cloud_variables[0])\n", + "GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\", cloud_variables[1])\n", + "GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\", cloud_variables[2])\n", "\n", "DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n", "VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n", From 88fe07dd1414a307faff4741b6070761e27cbfac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 11 Sep 2024 14:10:49 +0000 Subject: [PATCH 35/46] Refactoring example notebooks to handle new cloudsql vector store --- .../rag-kaggle-ray-sql-interactive.ipynb | 46 +--- .../rag-kaggle-ray-sql-latest.ipynb | 206 +++++++----------- 2 files changed, 82 insertions(+), 170 deletions(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index 2105487c4..9903177b6 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -61,14 +61,10 @@ "source": [ "import os\n", "import uuid\n", - "import ray\n", "\n", - "from typing import List\n", - "import torch\n", - "from datasets import load_dataset_builder, load_dataset, Dataset\n", - "from huggingface_hub import snapshot_download\n", - "from sentence_transformers import SentenceTransformer\n", + "import ray\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", + "from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n", "\n", "from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n", "from google.cloud.sql.connector import IPTypes\n", @@ -135,11 +131,7 @@ "metadata": {}, "outputs": [], "source": [ - "SHARED_DATA_BASEPATH='/data/rag/st'\n", "SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n", - "SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n", - "SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n", - "SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n", "\n", "# the dataset has been pre-dowloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.\n", "SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n", @@ -153,28 +145,6 @@ "ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function" ] }, - { - "cell_type": "markdown", - "id": "3dc5bc85-dc3b-4622-99a2-f9fc269e753b", - "metadata": {}, - "source": [ - "Now we will download the sentence transformer model to our GCS bucket:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b7a676be-56c6-4c76-8041-9ad05361dd3b", - "metadata": {}, - "outputs": [], - "source": [ - "# prepare the persistent shared directory to store artifacts needed for the ray workers\n", - "os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n", - "\n", - "# One time download of the sentence transformer model to a shared persistent storage available to the ray workers\n", - "snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)" - ] - }, { "cell_type": "markdown", "id": "f7304035-21a4-4017-bce9-aba7e9f81c90", @@ -197,13 +167,11 @@ " def __init__(self):\n", " self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n", "\n", - " def __call__(self, text_batch: List[str]):\n", + " def __call__(self, text_batch):\n", " text = text_batch[\"item\"]\n", - " # print(\"type(text)=\", type(text), \"type(text_batch)=\", type(text_batch))\n", " chunks = []\n", " for data in text:\n", " splits = self.splitter.split_text(data)\n", - " # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n", " chunks.extend(splits)\n", "\n", " return {'results':chunks}" @@ -272,7 +240,7 @@ "}])\n", "print(ds_batch.schema)\n", "\n", - "# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n", + "# Distributed map batches to create chunks out of each row.\n", "ds_splitted = ds_batch.map_batches(\n", " Splitter,\n", " compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n", @@ -301,7 +269,7 @@ "def ray_data_task(ds_splitted):\n", " results = []\n", " for row in ds_splitted.iter_rows():\n", - " data_text = row[\"results\"].page_content\n", + " data_text = row[\"results\"]\n", " data_id = str(uuid.uuid4()) \n", "\n", " results.append((data_id, data_text))\n", @@ -334,9 +302,7 @@ " print(\"device cuda found\")\n", " device=\"cuda\"\n", " \n", - "print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n", - "\n", - "embeddings_service = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n", + "embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL, model_kwargs=dict(device=device))\n", "vector_store = PostgresVectorStore.create_sync(\n", " engine=engine,\n", " embedding_service=embeddings_service,\n", diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb index 726014d6d..8987c11b6 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb @@ -30,7 +30,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install ray[default]==2.9.3 kaggle==1.6.6" + "!pip install ray[default]==2.9.3 kaggle==1.6.6 langchain-google-cloud-sql-pg" ] }, { @@ -73,57 +73,62 @@ "\n", "import os\n", "import uuid\n", + "\n", "import ray\n", - "from langchain.document_loaders import ArxivLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", - "from sentence_transformers import SentenceTransformer\n", - "from typing import List\n", - "import torch\n", - "from datasets import load_dataset_builder, load_dataset, Dataset\n", - "from huggingface_hub import snapshot_download\n", - "from google.cloud.sql.connector import Connector, IPTypes\n", - "import sqlalchemy\n", + "from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n", + "\n", + "from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n", + "from google.cloud.sql.connector import IPTypes\n", "\n", "# initialize parameters\n", - "INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n", + "INSTANCE_CONNECTION_NAME = os.environ.get(\"CLOUDSQL_INSTANCE_CONNECTION_NAME\")\n", "print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n", - "DB_NAME = \"pgvector-database\"\n", - "\n", - "db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", - "DB_USER = db_username_file.read()\n", - "db_username_file.close()\n", - "\n", - "db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", - "DB_PASS = db_password_file.read()\n", - "db_password_file.close()\n", - "\n", - "# initialize Connector object\n", - "connector = Connector()\n", - "\n", - "# function to return the database connection object\n", - "def getconn():\n", - " conn = connector.connect(\n", - " INSTANCE_CONNECTION_NAME,\n", - " \"pg8000\",\n", + "cloud_variables = INSTANCE_CONNECTION_NAME.split(\":\")\n", + "\n", + "GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\", cloud_variables[0])\n", + "GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\", cloud_variables[1])\n", + "GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\", cloud_variables[2])\n", + "\n", + "DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n", + "VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n", + "CHAT_HISTORY_TABLE_NAME = os.environ.get(\"CHAT_HISTORY_TABLE_NAME\", \"message_store\")\n", + "\n", + "VECTOR_DIMENSION = os.environ.get(\"VECTOR_DIMENSION\", 384)\n", + "\n", + "try:\n", + " db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", + " DB_USER = db_username_file.read()\n", + " db_username_file.close()\n", + "\n", + " db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", + " DB_PASS = db_password_file.read()\n", + " db_password_file.close()\n", + "except:\n", + " DB_USER = os.environ.get(\"DB_USERNAME\", \"postgres\")\n", + " DB_PASS = os.environ.get(\"DB_PASS\", \"postgres\")\n", + "\n", + "engine = PostgresEngine.from_instance(\n", + " project_id=GCP_PROJECT_ID,\n", + " region=GCP_CLOUD_SQL_REGION,\n", + " instance=GCP_CLOUD_SQL_INSTANCE,\n", + " database=DB_NAME,\n", " user=DB_USER,\n", " password=DB_PASS,\n", - " db=DB_NAME,\n", - " ip_type=IPTypes.PRIVATE\n", + " ip_type=IPTypes.PRIVATE,\n", + ")\n", + "\n", + "try:\n", + " engine.init_vectorstore_table(\n", + " VECTOR_EMBEDDINGS_TABLE_NAME,\n", + " vector_size=VECTOR_DIMENSION,\n", + " overwrite_existing=True,\n", " )\n", - " return conn\n", + "except Exception as err:\n", + " print(f\"Error: {err}\")\n", "\n", - "# create connection pool with 'creator' argument to our connection object function\n", - "pool = sqlalchemy.create_engine(\n", - " \"postgresql+pg8000://\",\n", - " creator=getconn,\n", - ")\n", "\n", - "SHARED_DATA_BASEPATH='/data/rag/st'\n", "SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n", - "SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n", - "SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n", - "SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n", - "\n", "# the dataset has been pre-dowloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.\n", "SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n", "REVIEWS_FILE_NAME=\"netflix_titles.csv\"\n", @@ -135,40 +140,18 @@ "DIMENSION = 384 # Embeddings size\n", "ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function\n", "\n", - "class Embed:\n", + "class Splitter:\n", " def __init__(self):\n", - " print(\"torch cuda version\", torch.version.cuda)\n", - " device=\"cpu\"\n", - " if torch.cuda.is_available():\n", - " print(\"device cuda found\")\n", - " device=\"cuda\"\n", - "\n", - " print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n", - " self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n", " self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n", "\n", - " def __call__(self, text_batch: List[str]):\n", + " def __call__(self, text_batch):\n", " text = text_batch[\"item\"]\n", - " # print(\"type(text)=\", type(text), \"type(text_batch)=\", type(text_batch))\n", " chunks = []\n", " for data in text:\n", " splits = self.splitter.split_text(data)\n", - " # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n", " chunks.extend(splits)\n", "\n", - " embeddings = self.transformer.encode(\n", - " chunks,\n", - " batch_size=BATCH_SIZE\n", - " ).tolist()\n", - " print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n", - " return {'results':list(zip(chunks, embeddings))}\n", - "\n", - "\n", - "# prepare the persistent shared directory to store artifacts needed for the ray workers\n", - "os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n", - "\n", - "# One time download of the sentence transformer model to a shared persistent storage available to the ray workers\n", - "snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)\n", + " return {'results':chunks}\n", "\n", "# Process the dataset first, wrap the csv file contents into a Ray dataset\n", "ray_ds = ray.data.read_csv(SHARED_DATASET_BASE_PATH + REVIEWS_FILE_NAME)\n", @@ -184,81 +167,44 @@ "}])\n", "print(ds_batch.schema)\n", "\n", - "# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n", - "ds_embed = ds_batch.map_batches(\n", - " Embed,\n", + "# Distributed map batches to create chunks out of each row.\n", + "ds_splitted = ds_batch.map_batches(\n", + " Splitter,\n", " compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n", " batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n", " num_gpus=1, # 1 GPU for each actor.\n", " # num_cpus=1,\n", ")\n", "\n", - "# Use this block for debug purpose to inspect the embeddings and raw text\n", - "# print(\"Embeddings ray dataset\", ds_embed.schema)\n", - "# for output in ds_embed.iter_rows():\n", - "# # restrict the text string to be less than 65535\n", - "# data_text = output[\"results\"][0][:65535]\n", - "# # vector data pass in needs to be a string \n", - "# data_emb = \",\".join(map(str, output[\"results\"][1]))\n", - "# data_emb = \"[\" + data_emb + \"]\"\n", - "# print (\"raw text:\", data_text, \", emdeddings:\", data_emb)\n", - "\n", - "# print(\"Embeddings ray dataset\", ds_embed.schema)\n", - "\n", - "data_text = \"\"\n", - "data_emb = \"\"\n", - "\n", - "with pool.connect() as db_conn:\n", - " db_conn.execute(\n", - " sqlalchemy.text(\n", - " \"CREATE EXTENSION IF NOT EXISTS vector;\"\n", - " )\n", - " )\n", - " db_conn.commit()\n", + "print(\"torch cuda version\", torch.version.cuda)\n", + "device=\"cpu\"\n", + "if torch.cuda.is_available():\n", + " print(\"device cuda found\")\n", + " device=\"cuda\"\n", + " \n", + "embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL, model_kwargs=dict(device=device))\n", + "vector_store = PostgresVectorStore.create_sync(\n", + " engine=engine,\n", + " embedding_service=embeddings_service,\n", + " table_name=VECTOR_EMBEDDINGS_TABLE_NAME,\n", + ")\n", "\n", - " create_table_query = \"CREATE TABLE IF NOT EXISTS \" + TABLE_NAME + \" ( id VARCHAR(255) NOT NULL, text TEXT NOT NULL, text_embedding vector(384) NOT NULL, PRIMARY KEY (id));\"\n", - " db_conn.execute(\n", - " sqlalchemy.text(create_table_query)\n", - " )\n", - " # commit transaction (SQLAlchemy v2.X.X is commit as you go)\n", - " db_conn.commit()\n", - " print(\"Created table=\", TABLE_NAME)\n", - " \n", - " query_text = \"INSERT INTO \" + TABLE_NAME + \" (id, text, text_embedding) VALUES (:id, :text, :text_embedding)\"\n", - " insert_stmt = sqlalchemy.text(query_text)\n", - " for output in ds_embed.iter_rows():\n", - " # print (\"type of embeddings\", type(output[\"results\"][1]), \"len embeddings\", len(output[\"results\"][1]))\n", - " # restrict the text string to be less than 65535\n", - " data_text = output[\"results\"][0][:65535]\n", - " # vector data pass in needs to be a string \n", - " data_emb = \",\".join(map(str, output[\"results\"][1]))\n", - " data_emb = \"[\" + data_emb + \"]\"\n", - " # print(\"text_embedding is \", data_emb)\n", + "for output in ds_splitted.iter_rows():\n", " id = uuid.uuid4()\n", - " db_conn.execute(insert_stmt, parameters={\"id\": id, \"text\": data_text, \"text_embedding\": data_emb})\n", + " splits = output[\"results\"]\n", + " vector_store.add_texts(splits, id)\n", "\n", - " # batch commit transactions\n", - " db_conn.commit()\n", "\n", - " # query and fetch table\n", - " query_text = \"SELECT * FROM \" + TABLE_NAME\n", - " results = db_conn.execute(sqlalchemy.text(query_text)).fetchall()\n", - " # for row in results:\n", - " # print(row)\n", + "#Validate results\n", + "query = \"List the cast of squid game\"\n", + "query_vector = embeddings_service.embed_query(query)\n", + "docs = vector_store.similarity_search_by_vector(query_vector, k=4)\n", "\n", - " # verify results\n", - " transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n", - " query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n", - " query_emb = transformer.encode(query_text).tolist()\n", - " query_request = \"SELECT id, text, text_embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", - " query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n", - " db_conn.commit()\n", - " print(\"print query_results, the 1st one is the hit\")\n", - " for row in query_results:\n", - " print(row)\n", - "\n", - "# cleanup connector object\n", - "connector.close()\n", + "for i, document in enumerate(docs):\n", + " print(f\"Result #{i+1}\")\n", + " print(document.page_content)\n", + " print(\"-\" * 100)\n", + " \n", "print (\"end job\")" ] }, From e38a101193db7c6763689cf7b084c4889e09a495 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 11 Sep 2024 15:02:58 +0000 Subject: [PATCH 36/46] Adding missing package to notebook --- .../rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index 9903177b6..d4148d85a 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -63,6 +63,7 @@ "import uuid\n", "\n", "import ray\n", + "import torch\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n", "\n", From 799c8db10710d165edf9d8d0eb917e1f3ca989c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Thu, 12 Sep 2024 14:09:24 +0000 Subject: [PATCH 37/46] Creating a notebook for testing rag with a sample of the data --- applications/rag/example_notebooks/ingest_database.ipynb | 1 + modules/jupyter/jupyter_image/notebook_image/cloudbuild.yaml | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 applications/rag/example_notebooks/ingest_database.ipynb diff --git a/applications/rag/example_notebooks/ingest_database.ipynb b/applications/rag/example_notebooks/ingest_database.ipynb new file mode 100644 index 000000000..04371a7b6 --- /dev/null +++ b/applications/rag/example_notebooks/ingest_database.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n","os.environ['KAGGLE_USERNAME'] = \"\"\n","os.environ['KAGGLE_KEY'] = \"\"\n","\n","# Download the zip file to local storage and then extract the desired contents directly to the GKE GCS CSI mounted bucket. The bucket is mounted at the \"/persist-data\" path in the jupyter pod.\n","!kaggle datasets download -d shivamb/netflix-shows -p ~/data --force\n","!mkdir /data/netflix-shows -p\n","!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install langchain-google-cloud-sql-pg"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n","import uuid\n","\n","from langchain_community.document_loaders.csv_loader import CSVLoader\n","from langchain.text_splitter import RecursiveCharacterTextSplitter\n","from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n","\n","from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n","from google.cloud.sql.connector import IPTypes\n","\n","# initialize parameters\n","INSTANCE_CONNECTION_NAME = os.environ.get(\"CLOUDSQL_INSTANCE_CONNECTION_NAME\", \"\")\n","print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n","cloud_variables = INSTANCE_CONNECTION_NAME.split(\":\")\n","\n","GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\", cloud_variables[0])\n","GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\", cloud_variables[1])\n","GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\", cloud_variables[2])\n","\n","DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n","VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n","CHAT_HISTORY_TABLE_NAME = os.environ.get(\"CHAT_HISTORY_TABLE_NAME\", \"message_store\")\n","\n","VECTOR_DIMENSION = os.environ.get(\"VECTOR_DIMENSION\", 384)\n","SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' \n","\n","SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n","REVIEWS_FILE_NAME=\"netflix_titles.csv\"\n","\n","BATCH_SIZE = 100\n","CHUNK_SIZE = 1000\n","CHUNK_OVERLAP = 10\n","TABLE_NAME = 'netflix_reviews_db'\n","\n","try:\n"," db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n"," DB_USER = db_username_file.read()\n"," db_username_file.close()\n","\n"," db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n"," DB_PASS = db_password_file.read()\n"," db_password_file.close()\n","except:\n"," DB_USER = os.environ.get(\"DB_USERNAME\", \"postgres\")\n"," DB_PASS = os.environ.get(\"DB_PASS\", \"postgres\")\n","\n","engine = PostgresEngine.from_instance(\n"," project_id=GCP_PROJECT_ID,\n"," region=GCP_CLOUD_SQL_REGION,\n"," instance=GCP_CLOUD_SQL_INSTANCE,\n"," database=DB_NAME,\n"," user=DB_USER,\n"," password=DB_PASS,\n"," ip_type=IPTypes.PRIVATE,\n",")\n","\n","try:\n"," engine.init_vectorstore_table(\n"," VECTOR_EMBEDDINGS_TABLE_NAME,\n"," vector_size=VECTOR_DIMENSION,\n"," overwrite_existing=True,\n"," )\n","except Exception as err:\n"," print(f\"Error: {err}\")\n","\n","\n","embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL)\n","vector_store = PostgresVectorStore.create_sync(\n"," engine=engine,\n"," embedding_service=embeddings_service,\n"," table_name=VECTOR_EMBEDDINGS_TABLE_NAME,\n",")\n","\n","splitter = RecursiveCharacterTextSplitter(\n"," chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len\n",")\n","\n","loader = CSVLoader(file_path=f\"{SHARED_DATASET_BASE_PATH}/{REVIEWS_FILE_NAME}\")\n","documents = loader.load()\n","\n","documents = documents[:1000] #Taking a sample for test purposes \n","\n","splits = splitter.split_documents(documents)\n","ids = [str(uuid.uuid4()) for i in range(len(splits))]\n","vector_store.add_documents(splits, ids)"]}],"metadata":{"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":2} diff --git a/modules/jupyter/jupyter_image/notebook_image/cloudbuild.yaml b/modules/jupyter/jupyter_image/notebook_image/cloudbuild.yaml index d1a4acfa2..83cc43798 100644 --- a/modules/jupyter/jupyter_image/notebook_image/cloudbuild.yaml +++ b/modules/jupyter/jupyter_image/notebook_image/cloudbuild.yaml @@ -17,6 +17,6 @@ steps: - name: 'gcr.io/cloud-builders/docker' args: [ 'pull', 'docker.io/jupyter/tensorflow-notebook:python-3.10' ] - name: 'gcr.io/cloud-builders/docker' - args: [ 'build', '-t', '/', '.' ] + args: [ 'build', '-t', 'us-docker.pkg.dev/globant-gke-ai-resources/gke-ai-text-to-text/gke-jupyterhub-image', '.' ] images: -- '/' \ No newline at end of file +- 'us-docker.pkg.dev/globant-gke-ai-resources/gke-ai-text-to-text/gke-jupyterhub-image' \ No newline at end of file From 14ff2032bb685fe6d4d7478d5806a53b2643732a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Thu, 12 Sep 2024 15:29:42 +0000 Subject: [PATCH 38/46] updating notebook to test rag --- cloudbuild.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 5d1c45bdc..4eb079ade 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -263,15 +263,15 @@ steps: echo "pass" > /workspace/rag_frontend_result.txt cd /workspace/ - sed -i "s//$$KAGGLE_USERNAME/g" ./applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb - sed -i "s//$$KAGGLE_KEY/g" ./applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb - gsutil cp ./applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb gs://gke-aieco-rag-$SHORT_SHA-$_BUILD_ID/ + sed -i "s//$$KAGGLE_USERNAME/g" ./applications/rag/example_notebooks/ingest_database.ipynb + sed -i "s//$$KAGGLE_KEY/g" ./applications/rag/example_notebooks/ingest_database.ipynb + gsutil cp ./applications/rag/example_notebooks/ingest_database.ipynb gs://gke-aieco-rag-$SHORT_SHA-$_BUILD_ID/ kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID $(kubectl get pod -l app=jupyterhub,component=hub -n rag-$SHORT_SHA-$_BUILD_ID -o jsonpath="{.items[0].metadata.name}") -- jupyterhub token admin --log-level=CRITICAL | xargs python3 ./applications/rag/notebook_starter.py # Wait for jupyterhub to trigger notebook pod startup sleep 5s kubectl wait --for=condition=Ready pod/jupyter-admin -n rag-$SHORT_SHA-$_BUILD_ID --timeout=500s - kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- jupyter nbconvert --to script /data/rag-kaggle-ray-sql-interactive.ipynb - kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- ipython /data/rag-kaggle-ray-sql-interactive.py + kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- jupyter nbconvert --to script /data/ingest_database.ipynb + kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- ipython /data/ingest_database.py python3 ./applications/rag/tests/test_rag.py "http://127.0.0.1:8081/prompt" echo "pass" > /workspace/rag_prompt_result.txt From c01cff60bdb3b54ad87dbb9ef8ed8e1059316455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Thu, 12 Sep 2024 17:55:10 +0000 Subject: [PATCH 39/46] Reverting changes on files, updating database model on notebook --- .../rag-kaggle-ray-sql-interactive.ipynb | 232 ++++++++++-------- cloudbuild.yaml | 12 +- 2 files changed, 142 insertions(+), 102 deletions(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index d4148d85a..dd717e8d5 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -32,16 +32,6 @@ "!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "c421c932", - "metadata": {}, - "outputs": [], - "source": [ - "!pip install langchain-google-cloud-sql-pg" - ] - }, { "cell_type": "markdown", "id": "c7ff518d-f4d2-481b-b408-2c2507565611", @@ -61,60 +51,51 @@ "source": [ "import os\n", "import uuid\n", - "\n", "import ray\n", - "import torch\n", + "from langchain.document_loaders import ArxivLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", - "from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n", - "\n", - "from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n", - "from google.cloud.sql.connector import IPTypes\n", + "from sentence_transformers import SentenceTransformer\n", + "from typing import List\n", + "import torch\n", + "from datasets import load_dataset_builder, load_dataset, Dataset\n", + "from huggingface_hub import snapshot_download\n", + "from google.cloud.sql.connector import Connector, IPTypes\n", + "import sqlalchemy\n", "\n", "# initialize parameters\n", - "INSTANCE_CONNECTION_NAME = os.environ.get(\"CLOUDSQL_INSTANCE_CONNECTION_NAME\")\n", + "\n", + "INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n", "print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n", - "cloud_variables = INSTANCE_CONNECTION_NAME.split(\":\")\n", - "\n", - "GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\", cloud_variables[0])\n", - "GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\", cloud_variables[1])\n", - "GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\", cloud_variables[2])\n", - "\n", - "DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n", - "VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n", - "CHAT_HISTORY_TABLE_NAME = os.environ.get(\"CHAT_HISTORY_TABLE_NAME\", \"message_store\")\n", - "\n", - "VECTOR_DIMENSION = os.environ.get(\"VECTOR_DIMENSION\", 384)\n", - "\n", - "try:\n", - " db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", - " DB_USER = db_username_file.read()\n", - " db_username_file.close()\n", - "\n", - " db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", - " DB_PASS = db_password_file.read()\n", - " db_password_file.close()\n", - "except:\n", - " DB_USER = os.environ.get(\"DB_USERNAME\", \"postgres\")\n", - " DB_PASS = os.environ.get(\"DB_PASS\", \"postgres\")\n", - "\n", - "engine = PostgresEngine.from_instance(\n", - " project_id=GCP_PROJECT_ID,\n", - " region=GCP_CLOUD_SQL_REGION,\n", - " instance=GCP_CLOUD_SQL_INSTANCE,\n", - " database=DB_NAME,\n", + "DB_NAME = \"pgvector-database\"\n", + "\n", + "db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", + "DB_USER = db_username_file.read()\n", + "db_username_file.close()\n", + "\n", + "db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", + "DB_PASS = db_password_file.read()\n", + "db_password_file.close()\n", + "\n", + "# initialize Connector object\n", + "connector = Connector()\n", + "\n", + "# function to return the database connection object\n", + "def getconn():\n", + " conn = connector.connect(\n", + " INSTANCE_CONNECTION_NAME,\n", + " \"pg8000\",\n", " user=DB_USER,\n", " password=DB_PASS,\n", - " ip_type=IPTypes.PRIVATE,\n", - ")\n", - "\n", - "try:\n", - " engine.init_vectorstore_table(\n", - " VECTOR_EMBEDDINGS_TABLE_NAME,\n", - " vector_size=VECTOR_DIMENSION,\n", - " overwrite_existing=True,\n", + " db=DB_NAME,\n", + " ip_type=IPTypes.PRIVATE\n", " )\n", - "except Exception as err:\n", - " print(f\"Error: {err}\")" + " return conn\n", + "\n", + "# create connection pool with 'creator' argument to our connection object function\n", + "pool = sqlalchemy.create_engine(\n", + " \"postgresql+pg8000://\",\n", + " creator=getconn,\n", + ")" ] }, { @@ -132,7 +113,11 @@ "metadata": {}, "outputs": [], "source": [ + "SHARED_DATA_BASEPATH='/data/rag/st'\n", "SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n", + "SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n", + "SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n", + "SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n", "\n", "# the dataset has been pre-dowloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.\n", "SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n", @@ -146,15 +131,36 @@ "ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function" ] }, + { + "cell_type": "markdown", + "id": "3dc5bc85-dc3b-4622-99a2-f9fc269e753b", + "metadata": {}, + "source": [ + "Now we will download the sentence transformer model to our GCS bucket:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7a676be-56c6-4c76-8041-9ad05361dd3b", + "metadata": {}, + "outputs": [], + "source": [ + "# prepare the persistent shared directory to store artifacts needed for the ray workers\n", + "os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n", + "\n", + "# One time download of the sentence transformer model to a shared persistent storage available to the ray workers\n", + "snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)" + ] + }, { "cell_type": "markdown", "id": "f7304035-21a4-4017-bce9-aba7e9f81c90", "metadata": {}, "source": [ - "## Generating Documents splits\n", + "## Generating Vector Embeddings\n", "\n", - "We are ready to begin. Let's first create some code for generating the dataset splits:\n", - "\n" + "We are ready to begin. Let's first create some code for generating the vector embeddings:" ] }, { @@ -164,18 +170,33 @@ "metadata": {}, "outputs": [], "source": [ - "class Splitter:\n", + "class Embed:\n", " def __init__(self):\n", + " print(\"torch cuda version\", torch.version.cuda)\n", + " device=\"cpu\"\n", + " if torch.cuda.is_available():\n", + " print(\"device cuda found\")\n", + " device=\"cuda\"\n", + "\n", + " print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n", + " self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n", " self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n", "\n", - " def __call__(self, text_batch):\n", + " def __call__(self, text_batch: List[str]):\n", " text = text_batch[\"item\"]\n", + " # print(\"type(text)=\", type(text), \"type(text_batch)=\", type(text_batch))\n", " chunks = []\n", " for data in text:\n", " splits = self.splitter.split_text(data)\n", + " # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n", " chunks.extend(splits)\n", "\n", - " return {'results':chunks}" + " embeddings = self.transformer.encode(\n", + " chunks,\n", + " batch_size=BATCH_SIZE\n", + " ).tolist()\n", + " print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n", + " return {'results':list(zip(chunks, embeddings))}" ] }, { @@ -206,7 +227,6 @@ " \"datasets==2.18.0\",\n", " \"torch==2.0.1\",\n", " \"huggingface_hub==0.21.3\",\n", - " \"langchain-google-cloud-sql-pg\"\n", " ]\n", " }\n", ")" @@ -241,9 +261,9 @@ "}])\n", "print(ds_batch.schema)\n", "\n", - "# Distributed map batches to create chunks out of each row.\n", - "ds_splitted = ds_batch.map_batches(\n", - " Splitter,\n", + "# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n", + "ds_embed = ds_batch.map_batches(\n", + " Embed,\n", " compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n", " batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n", " num_gpus=1, # 1 GPU for each actor.\n", @@ -267,17 +287,17 @@ "outputs": [], "source": [ "@ray.remote\n", - "def ray_data_task(ds_splitted):\n", + "def ray_data_task(ds_embed):\n", " results = []\n", - " for row in ds_splitted.iter_rows():\n", - " data_text = row[\"results\"]\n", - " data_id = str(uuid.uuid4()) \n", + " for row in ds_embed.iter_rows():\n", + " data_text = row[\"results\"][0][:65535]\n", + " data_emb = row[\"results\"][1]\n", "\n", - " results.append((data_id, data_text))\n", + " results.append((data_text, data_emb))\n", " \n", " return results\n", " \n", - "results = ray.get(ray_data_task.remote(ds_splitted))" + "results = ray.get(ray_data_task.remote(ds_embed))" ] }, { @@ -297,23 +317,36 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"torch cuda version\", torch.version.cuda)\n", - "device=\"cpu\"\n", - "if torch.cuda.is_available():\n", - " print(\"device cuda found\")\n", - " device=\"cuda\"\n", + "from sqlalchemy.ext.declarative import declarative_base\n", + "from sqlalchemy import Column, String, Text, text\n", + "from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column\n", + "from pgvector.sqlalchemy import Vector\n", + "\n", + "\n", + "Base = declarative_base()\n", + "DBSession = scoped_session(sessionmaker())\n", + "\n", + "class TextEmbedding(Base):\n", + " __tablename__ = TABLE_NAME\n", + " langchain_id = Column(String(255), primary_key=True)\n", + " page_content = Column(Text)\n", + " embedding = mapped_column(Vector(384))\n", + "\n", + "with pool.connect() as conn:\n", + " conn.execute(text(\"CREATE EXTENSION IF NOT EXISTS vector\"))\n", + " conn.commit() \n", " \n", - "embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL, model_kwargs=dict(device=device))\n", - "vector_store = PostgresVectorStore.create_sync(\n", - " engine=engine,\n", - " embedding_service=embeddings_service,\n", - " table_name=VECTOR_EMBEDDINGS_TABLE_NAME,\n", - ")\n", - "\n", - "for result in results:\n", - " id = result[0]\n", - " splits = result[1]\n", - " vector_store.add_texts(splits, id)" + "DBSession.configure(bind=pool, autoflush=False, expire_on_commit=False)\n", + "Base.metadata.drop_all(pool)\n", + "Base.metadata.create_all(pool)\n", + "\n", + "rows = []\n", + "for r in results:\n", + " id = uuid.uuid4() \n", + " rows.append(TextEmbedding(langchain_id=id, page_content=r[0], embedding=r[1]))\n", + "\n", + "DBSession.bulk_save_objects(rows)\n", + "DBSession.commit()" ] }, { @@ -331,14 +364,21 @@ "metadata": {}, "outputs": [], "source": [ - "query = \"List the cast of squid game\"\n", - "query_vector = embeddings_service.embed_query(query)\n", - "docs = vector_store.similarity_search_by_vector(query_vector, k=4)\n", - "\n", - "for i, document in enumerate(docs):\n", - " print(f\"Result #{i+1}\")\n", - " print(document.page_content)\n", - " print(\"-\" * 100)" + "with pool.connect() as db_conn:\n", + " # verify results\n", + " transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n", + " query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n", + " query_emb = transformer.encode(query_text).tolist()\n", + " query_request = \"SELECT langchain_id, page_content, embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", + " query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n", + " db_conn.commit()\n", + " \n", + " print(\"print query_results, the 1st one is the hit\")\n", + " for row in query_results:\n", + " print(row)\n", + "\n", + "# cleanup connector object\n", + "connector.close()" ] } ], diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 4eb079ade..abd349a9a 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -263,15 +263,15 @@ steps: echo "pass" > /workspace/rag_frontend_result.txt cd /workspace/ - sed -i "s//$$KAGGLE_USERNAME/g" ./applications/rag/example_notebooks/ingest_database.ipynb - sed -i "s//$$KAGGLE_KEY/g" ./applications/rag/example_notebooks/ingest_database.ipynb - gsutil cp ./applications/rag/example_notebooks/ingest_database.ipynb gs://gke-aieco-rag-$SHORT_SHA-$_BUILD_ID/ + sed -i "s//$$KAGGLE_USERNAME/g" ./applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb + sed -i "s//$$KAGGLE_KEY/g" ./applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb + gsutil cp ./applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb gs://gke-aieco-rag-$SHORT_SHA-$_BUILD_ID/ kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID $(kubectl get pod -l app=jupyterhub,component=hub -n rag-$SHORT_SHA-$_BUILD_ID -o jsonpath="{.items[0].metadata.name}") -- jupyterhub token admin --log-level=CRITICAL | xargs python3 ./applications/rag/notebook_starter.py # Wait for jupyterhub to trigger notebook pod startup sleep 5s kubectl wait --for=condition=Ready pod/jupyter-admin -n rag-$SHORT_SHA-$_BUILD_ID --timeout=500s - kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- jupyter nbconvert --to script /data/ingest_database.ipynb - kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- ipython /data/ingest_database.py + kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- jupyter nbconvert --to script /data/rag-kaggle-ray-sql-interactive.ipynb + kubectl exec -it -n rag-$SHORT_SHA-$_BUILD_ID jupyter-admin -c notebook -- ipython /data/rag-kaggle-ray-sql-interactive.py python3 ./applications/rag/tests/test_rag.py "http://127.0.0.1:8081/prompt" echo "pass" > /workspace/rag_prompt_result.txt @@ -399,4 +399,4 @@ availableSecrets: - versionName: projects/gke-ai-eco-dev/secrets/cloudbuild-kaggle-username/versions/latest env: 'KAGGLE_USERNAME' - versionName: projects/gke-ai-eco-dev/secrets/cloudbuild-kaggle-key/versions/latest - env: 'KAGGLE_KEY' + env: 'KAGGLE_KEY' \ No newline at end of file From 0a0478260c6ef43fb21b171e71850d661838e4df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Thu, 12 Sep 2024 18:48:36 +0000 Subject: [PATCH 40/46] Fixing name with column on notebook query --- .../rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index dd717e8d5..d24d31454 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -369,7 +369,7 @@ " transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n", " query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n", " query_emb = transformer.encode(query_text).tolist()\n", - " query_request = \"SELECT langchain_id, page_content, embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", + " query_request = \"SELECT langchain_id, page_content, embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", " query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n", " db_conn.commit()\n", " \n", From c489d73a3311eee5a24ba50a36e93f0bcae022eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Tue, 17 Sep 2024 09:02:39 -0500 Subject: [PATCH 41/46] Delete applications/rag/example_notebooks/ingest_database.ipynb --- applications/rag/example_notebooks/ingest_database.ipynb | 1 - 1 file changed, 1 deletion(-) delete mode 100644 applications/rag/example_notebooks/ingest_database.ipynb diff --git a/applications/rag/example_notebooks/ingest_database.ipynb b/applications/rag/example_notebooks/ingest_database.ipynb deleted file mode 100644 index 04371a7b6..000000000 --- a/applications/rag/example_notebooks/ingest_database.ipynb +++ /dev/null @@ -1 +0,0 @@ -{"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n","os.environ['KAGGLE_USERNAME'] = \"\"\n","os.environ['KAGGLE_KEY'] = \"\"\n","\n","# Download the zip file to local storage and then extract the desired contents directly to the GKE GCS CSI mounted bucket. The bucket is mounted at the \"/persist-data\" path in the jupyter pod.\n","!kaggle datasets download -d shivamb/netflix-shows -p ~/data --force\n","!mkdir /data/netflix-shows -p\n","!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install langchain-google-cloud-sql-pg"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n","import uuid\n","\n","from langchain_community.document_loaders.csv_loader import CSVLoader\n","from langchain.text_splitter import RecursiveCharacterTextSplitter\n","from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n","\n","from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n","from google.cloud.sql.connector import IPTypes\n","\n","# initialize parameters\n","INSTANCE_CONNECTION_NAME = os.environ.get(\"CLOUDSQL_INSTANCE_CONNECTION_NAME\", \"\")\n","print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n","cloud_variables = INSTANCE_CONNECTION_NAME.split(\":\")\n","\n","GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\", cloud_variables[0])\n","GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\", cloud_variables[1])\n","GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\", cloud_variables[2])\n","\n","DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n","VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n","CHAT_HISTORY_TABLE_NAME = os.environ.get(\"CHAT_HISTORY_TABLE_NAME\", \"message_store\")\n","\n","VECTOR_DIMENSION = os.environ.get(\"VECTOR_DIMENSION\", 384)\n","SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' \n","\n","SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n","REVIEWS_FILE_NAME=\"netflix_titles.csv\"\n","\n","BATCH_SIZE = 100\n","CHUNK_SIZE = 1000\n","CHUNK_OVERLAP = 10\n","TABLE_NAME = 'netflix_reviews_db'\n","\n","try:\n"," db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n"," DB_USER = db_username_file.read()\n"," db_username_file.close()\n","\n"," db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n"," DB_PASS = db_password_file.read()\n"," db_password_file.close()\n","except:\n"," DB_USER = os.environ.get(\"DB_USERNAME\", \"postgres\")\n"," DB_PASS = os.environ.get(\"DB_PASS\", \"postgres\")\n","\n","engine = PostgresEngine.from_instance(\n"," project_id=GCP_PROJECT_ID,\n"," region=GCP_CLOUD_SQL_REGION,\n"," instance=GCP_CLOUD_SQL_INSTANCE,\n"," database=DB_NAME,\n"," user=DB_USER,\n"," password=DB_PASS,\n"," ip_type=IPTypes.PRIVATE,\n",")\n","\n","try:\n"," engine.init_vectorstore_table(\n"," VECTOR_EMBEDDINGS_TABLE_NAME,\n"," vector_size=VECTOR_DIMENSION,\n"," overwrite_existing=True,\n"," )\n","except Exception as err:\n"," print(f\"Error: {err}\")\n","\n","\n","embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL)\n","vector_store = PostgresVectorStore.create_sync(\n"," engine=engine,\n"," embedding_service=embeddings_service,\n"," table_name=VECTOR_EMBEDDINGS_TABLE_NAME,\n",")\n","\n","splitter = RecursiveCharacterTextSplitter(\n"," chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len\n",")\n","\n","loader = CSVLoader(file_path=f\"{SHARED_DATASET_BASE_PATH}/{REVIEWS_FILE_NAME}\")\n","documents = loader.load()\n","\n","documents = documents[:1000] #Taking a sample for test purposes \n","\n","splits = splitter.split_documents(documents)\n","ids = [str(uuid.uuid4()) for i in range(len(splits))]\n","vector_store.add_documents(splits, ids)"]}],"metadata":{"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":2} From aa44fcbec0fcef2713fdad32b5166ee310dc1458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 25 Sep 2024 14:41:21 +0000 Subject: [PATCH 42/46] updating Embedding model with missing column --- .../rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index d24d31454..08448c2ef 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -318,7 +318,7 @@ "outputs": [], "source": [ "from sqlalchemy.ext.declarative import declarative_base\n", - "from sqlalchemy import Column, String, Text, text\n", + "from sqlalchemy import Column, String, Text, text, JSON\n", "from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column\n", "from pgvector.sqlalchemy import Vector\n", "\n", @@ -331,6 +331,7 @@ " langchain_id = Column(String(255), primary_key=True)\n", " page_content = Column(Text)\n", " embedding = mapped_column(Vector(384))\n", + " langchain_metadata = Column(JSON) \n", "\n", "with pool.connect() as conn:\n", " conn.execute(text(\"CREATE EXTENSION IF NOT EXISTS vector\"))\n", From 0ef245f4363a858e8050ceaa87e524b6273666c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 2 Oct 2024 18:30:40 +0000 Subject: [PATCH 43/46] Updating packages, improving chain prompt --- .../container/application/rag_langchain/rag_chain.py | 11 ++++++----- applications/rag/frontend/container/requirements.txt | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py index 36cc10859..1c30a05d8 100644 --- a/applications/rag/frontend/container/application/rag_langchain/rag_chain.py +++ b/applications/rag/frontend/container/application/rag_langchain/rag_chain.py @@ -18,7 +18,7 @@ from langchain_core.runnables import RunnableParallel, RunnableLambda from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_huggingface.embeddings import HuggingFaceEmbeddings from langchain_google_cloud_sql_pg import PostgresChatMessageHistory @@ -42,10 +42,11 @@ SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small" # Transformer to use for converting text chunks to vector embeddings -template_str = """Answer the question given by the user in no more than 2 sentences. -Use the provided context to improve upon your previous answers. Stick to the facts and be brief. Avoid conversational format. -\n\n -Context: {context} +template_str = """Provide a concise answer to the user's question in 1-2 sentences, +focusing strictly on factual information from the given context. +Prioritize accuracy, avoid unnecessary details, and eliminate conversational language. +Stick to the content of the context for your response.\n +Context: {context} """ prompt = ChatPromptTemplate.from_messages( diff --git a/applications/rag/frontend/container/requirements.txt b/applications/rag/frontend/container/requirements.txt index 39773fecc..c64edad78 100644 --- a/applications/rag/frontend/container/requirements.txt +++ b/applications/rag/frontend/container/requirements.txt @@ -15,8 +15,8 @@ Flask==3.0.0 gunicorn==22.0.0 Werkzeug==3.0.3 -langchain==0.1.9 -sentence-transformers==2.5.1 +langchain +sentence-transformers google-cloud-dlp==3.12.2 google-cloud-storage==2.9.0 google-cloud-pubsub==2.17.0 @@ -30,6 +30,6 @@ google-cloud==0.34.0 google-cloud-logging==3.9.0 google-api-python-client==2.114.0 pymysql==1.1.1 -cloud-sql-python-connector[pg8000]==1.7.0 -langchain-google-cloud-sql-pg==0.4.0 -langchain-community==0.0.31 \ No newline at end of file +cloud-sql-python-connector[pg8000] +langchain-google-cloud-sql-pg +langchain-huggingface \ No newline at end of file From 8558a464f12f21cf681a5e8802be58757ad2ae5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Wed, 2 Oct 2024 18:54:41 +0000 Subject: [PATCH 44/46] updating rag frontend sha --- applications/rag/frontend/main.tf | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 18f7c759c..65544646b 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -118,9 +118,8 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { service_account_name = var.google_service_account container { # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" - image = "us-docker.pkg.dev/globant-gke-ai-resources/gke-ai-text-to-text/gke-rag-frontend@sha256:e56c59747b1ecc192458a3fdd6c74ad6a2099eeabb61e3fd1eb5fc30a147ba1d" - # Built from local code. Revert before submitting. - # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:108bb16ee2278255c80524fce125ef349c494cb5bc4ca77dbde5048b8f9448c1" + image = "us-docker.pkg.dev/globant-gke-ai-resources/gke-ai-text-to-text/gke-rag-frontend@sha256:6dfc6612ad43d0d5f26e83b1c7c3fc6d8fafe869c21e031c971284cecd9a6bda" + name = "rag-frontend" port { From 44b3e7217d2fe5a39a58d059792c56c6a3060cb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Tue, 8 Oct 2024 17:33:58 +0000 Subject: [PATCH 45/46] updating column name --- .../example_notebooks/rag-kaggle-ray-sql-interactive.ipynb | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb index 2d6f64b3b..ee031241a 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -7,6 +7,7 @@ "source": [ "# RAG-on-GKE Application\n", "\n", + "\n", "This is a Python notebook for generating the vector embeddings used by the RAG on GKE application. For full information, please checkout the GitHub documentation [here](https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/applications/rag/README.md).\n", "\n", "\n", @@ -328,7 +329,7 @@ "class TextEmbedding(Base):\n", " __tablename__ = TABLE_NAME\n", " langchain_id = Column(String(255), primary_key=True)\n", - " page_content = Column(Text)\n", + " content = Column(Text)\n", " embedding = mapped_column(Vector(384))\n", " langchain_metadata = Column(JSON) \n", "\n", @@ -343,7 +344,7 @@ "rows = []\n", "for r in results:\n", " id = uuid.uuid4() \n", - " rows.append(TextEmbedding(langchain_id=id, page_content=r[0], embedding=r[1]))\n", + " rows.append(TextEmbedding(langchain_id=id, content=r[0], embedding=r[1]))\n", "\n", "DBSession.bulk_save_objects(rows)\n", "DBSession.commit()" @@ -369,7 +370,7 @@ " transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n", " query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n", " query_emb = transformer.encode(query_text).tolist()\n", - " query_request = \"SELECT langchain_id, page_content, embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", + " query_request = \"SELECT langchain_id, content, embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", " query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n", " db_conn.commit()\n", " \n", From ab06c079d8ece2b247c0fef3ef519139ca60af1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Grandas?= Date: Thu, 24 Oct 2024 23:53:04 +0000 Subject: [PATCH 46/46] updating max tokens lenght for inference service --- .../rag_langchain/huggingface_inference_model.py | 2 +- applications/rag/frontend/main.tf | 2 +- modules/inference-service/main.tf | 9 +++++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py b/applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py index 2902270aa..8de9972a5 100644 --- a/applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py +++ b/applications/rag/frontend/container/application/rag_langchain/huggingface_inference_model.py @@ -67,7 +67,7 @@ def _call( raise ValueError("stop kwargs are not permitted.") api_endpoint = f"http://{INFERENCE_ENDPOINT}/generate" - body = {"inputs": prompt} + body = {"inputs": prompt, "parameters":{ "max_new_tokens": 2048 }} headers = {"Content-Type": "application/json"} generated_output = post_request(api_endpoint, body, headers) generated_text = generated_output.get("generated_text", "") diff --git a/applications/rag/frontend/main.tf b/applications/rag/frontend/main.tf index 65544646b..d101b1a2d 100644 --- a/applications/rag/frontend/main.tf +++ b/applications/rag/frontend/main.tf @@ -118,7 +118,7 @@ resource "kubernetes_deployment" "rag_frontend_deployment" { service_account_name = var.google_service_account container { # image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/frontend@sha256:d65b538742ee29826ee629cfe05c0008e7c09ce5357ddc08ea2eaf3fd6cefe4b" - image = "us-docker.pkg.dev/globant-gke-ai-resources/gke-ai-text-to-text/gke-rag-frontend@sha256:6dfc6612ad43d0d5f26e83b1c7c3fc6d8fafe869c21e031c971284cecd9a6bda" + image = "us-docker.pkg.dev/globant-gke-ai-resources/gke-ai-text-to-text/gke-rag-frontend@sha256:e6960ded132211f02d7442177bd6454ccd65b29768e62ebb47f14c17a30a46f2" name = "rag-frontend" diff --git a/modules/inference-service/main.tf b/modules/inference-service/main.tf index 91e558369..14058f360 100644 --- a/modules/inference-service/main.tf +++ b/modules/inference-service/main.tf @@ -110,6 +110,15 @@ resource "kubernetes_deployment" "inference_deployment" { value = "2" } + env { + name = "MAX_INPUT_LENGTH" + value = "2048" + } + env { + name = "MAX_TOTAL_TOKENS" + value = "4096" + } + env { name = "PORT" value = "8080"