From 28c5f7e347e3d908668d7e165ab92d9a02805e60 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Thu, 26 Sep 2024 21:20:36 -0300 Subject: [PATCH] add option to change embedding model for both local and openai --- app/pages/Settings.py | 53 ++++++++++++++++++- app/util/constants.py | 1 + app/util/openai_wrapper.py | 3 ++ app/util/secrets_handler.py | 2 +- .../detect_entity_networks/functions.py | 4 ++ .../detect_entity_networks/workflow.py | 13 ++--- .../match_entity_records/functions.py | 4 ++ app/workflows/query_text_data/functions.py | 4 ++ app/workflows/query_text_data/workflow.py | 2 +- toolkit/AI/local_embedder.py | 3 +- toolkit/AI/openai_configuration.py | 12 +++++ toolkit/AI/openai_embedder.py | 8 ++- .../detect_entity_networks/index_and_infer.py | 2 +- 13 files changed, 98 insertions(+), 13 deletions(-) diff --git a/app/pages/Settings.py b/app/pages/Settings.py index 080ec63c..f1b16ffa 100644 --- a/app/pages/Settings.py +++ b/app/pages/Settings.py @@ -7,12 +7,13 @@ import streamlit as st from components.app_loader import load_multipage_app -from util.constants import MAX_SIZE_EMBEDDINGS_KEY +from util.constants import LOCAL_EMBEDDING_MODEL_KEY, MAX_SIZE_EMBEDDINGS_KEY from util.enums import Mode from util.openai_wrapper import ( UIOpenAIConfiguration, key, openai_azure_auth_type, + openai_embedding_model, openai_endpoint_key, openai_model_key, openai_type_key, @@ -20,9 +21,30 @@ ) from util.secrets_handler import SecretsHandler +from toolkit.AI.defaults import DEFAULT_LOCAL_EMBEDDING_MODEL from toolkit.AI.vector_store import VectorStore from toolkit.helpers.constants import CACHE_PATH +openai_embedding_models = [ + "text-embedding-3-large", + "text-embedding-3-small", + "text-embedding-ada-002", +] + +local_embedding_models = [ + "all-mpnet-base-v2", + "multi-qa-mpnet-base-dot-v1", + "all-distilroberta-v1", + "all-MiniLM-L12-v2", + "multi-qa-distilbert-cos-v1", + "paraphrase-multilingual-mpnet-base-v2", + "paraphrase-albert-small-v2", + "paraphrase-multilingual-MiniLM-L12-v2", + "paraphrase-MiniLM-L3-v2", + "distiluse-base-multilingual-cased-v1", + "distiluse-base-multilingual-cased-v2", +] + def on_change(handler, key=None, value=None): def change(): @@ -148,6 +170,35 @@ def main(): st.subheader("Embeddings") max_size = int(secrets_handler.get_secret(MAX_SIZE_EMBEDDINGS_KEY) or 0) + local_embedding_value = ( + secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY) + or DEFAULT_LOCAL_EMBEDDING_MODEL + ) + + st.markdown("Select the embedding model you want to use.") + e1, e2 = st.columns(2) + with e1: + openai_model = st.selectbox( + "OpenAI embedding model", + openai_embedding_models, + index=openai_embedding_models.index(openai_config.embedding_model), + key="openai_embedding_model", + help="Select the embedding model you want to use.", + ) + if openai_model != openai_config.embedding_model: + on_change(secrets_handler, openai_embedding_model, openai_model)() + st.rerun() + with e2: + local_model = st.selectbox( + "Local embedding model", + local_embedding_models, + index=local_embedding_models.index(local_embedding_value), + key="local_embedding_model", + help="Select the embedding model you want to use.", + ) + if local_model != local_embedding_value: + on_change(secrets_handler, LOCAL_EMBEDDING_MODEL_KEY, local_model)() + st.rerun() c1, c2 = st.columns(2) with c1: diff --git a/app/util/constants.py b/app/util/constants.py index 0cd2cd6d..b2592fde 100644 --- a/app/util/constants.py +++ b/app/util/constants.py @@ -7,3 +7,4 @@ PDF_ENCODING = "UTF-8" PDF_WKHTMLTOPDF_PATH = "C:\\Program Files\\wkhtmltopdf\\bin\\wkhtmltopdf.exe" MAX_SIZE_EMBEDDINGS_KEY = "max_embedding" +LOCAL_EMBEDDING_MODEL_KEY = "local_embedding_model" diff --git a/app/util/openai_wrapper.py b/app/util/openai_wrapper.py index 27e79efc..af8b7d0c 100644 --- a/app/util/openai_wrapper.py +++ b/app/util/openai_wrapper.py @@ -11,6 +11,7 @@ openai_endpoint_key = "openai_endpoint" openai_model_key = "openai_model" openai_azure_auth_type = "openai_azure_auth_type" +openai_embedding_model = "openai_embedding_model" class UIOpenAIConfiguration: @@ -26,6 +27,7 @@ def get_configuration(self) -> OpenAIConfiguration: secret_key = self._secrets.get_secret(key) or None model = self._secrets.get_secret(openai_model_key) or None az_auth_type = self._secrets.get_secret(openai_azure_auth_type) or None + embedding_model = self._secrets.get_secret(openai_embedding_model) or None config = { "api_type": api_type, @@ -34,6 +36,7 @@ def get_configuration(self) -> OpenAIConfiguration: "api_key": secret_key, "model": model, "az_auth_type": az_auth_type, + "embedding_model": embedding_model, } values = {k: v for k, v in config.items() if v is not None} return OpenAIConfiguration(values) diff --git a/app/util/secrets_handler.py b/app/util/secrets_handler.py index 1f8706dd..bcee1533 100644 --- a/app/util/secrets_handler.py +++ b/app/util/secrets_handler.py @@ -28,7 +28,7 @@ def read_values_from_file(self): def get_secret(self, value): values = self.read_values_from_file() - for key in self.read_values_from_file(): + for key in values: if key == value: return values[key] return "" diff --git a/app/workflows/detect_entity_networks/functions.py b/app/workflows/detect_entity_networks/functions.py index 9bdc1f49..977eb58c 100644 --- a/app/workflows/detect_entity_networks/functions.py +++ b/app/workflows/detect_entity_networks/functions.py @@ -5,6 +5,8 @@ from util.openai_wrapper import UIOpenAIConfiguration import toolkit.detect_entity_networks.config as config +from app.util.constants import LOCAL_EMBEDDING_MODEL_KEY +from app.util.secrets_handler import SecretsHandler from toolkit.AI.base_embedder import BaseEmbedder from toolkit.AI.local_embedder import LocalEmbedder from toolkit.AI.openai_embedder import OpenAIEmbedder @@ -13,11 +15,13 @@ def embedder(local_embedding: bool | None = True) -> BaseEmbedder: try: ai_configuration = UIOpenAIConfiguration().get_configuration() + secrets_handler = SecretsHandler() if local_embedding: return LocalEmbedder( db_name=config.cache_name, max_tokens=ai_configuration.max_tokens, concurrent_coroutines=80, + model=secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY) or None, ) return OpenAIEmbedder( configuration=ai_configuration, diff --git a/app/workflows/detect_entity_networks/workflow.py b/app/workflows/detect_entity_networks/workflow.py index 18129526..35e63145 100644 --- a/app/workflows/detect_entity_networks/workflow.py +++ b/app/workflows/detect_entity_networks/workflow.py @@ -233,12 +233,13 @@ async def create(sv: rn_variables.SessionVariables, workflow=None): help="Select the node types to embed into a multi-dimensional semantic space for fuzzy matching.", ) - total_embeddings = sum( - 1 - for _, data in sv.network_overall_graph.value.nodes(data=True) - if data["type"] in network_indexed_node_types - ) - st.caption(f"Total of {total_embeddings} nodes to index") + if sv.network_overall_graph.value and network_indexed_node_types: + total_embeddings = sum( + 1 + for _, data in sv.network_overall_graph.value.nodes(data=True) + if data["type"] in network_indexed_node_types + ) + st.caption(f"Total of {total_embeddings} nodes to index") local_embedding = st.toggle( "Use local embeddings", sv.network_local_embedding_enabled.value, diff --git a/app/workflows/match_entity_records/functions.py b/app/workflows/match_entity_records/functions.py index d2fc69cd..052c8522 100644 --- a/app/workflows/match_entity_records/functions.py +++ b/app/workflows/match_entity_records/functions.py @@ -3,7 +3,9 @@ # import streamlit as st +from app.util.constants import LOCAL_EMBEDDING_MODEL_KEY from app.util.openai_wrapper import UIOpenAIConfiguration +from app.util.secrets_handler import SecretsHandler from toolkit.AI.base_embedder import BaseEmbedder from toolkit.AI.local_embedder import LocalEmbedder from toolkit.AI.openai_embedder import OpenAIEmbedder @@ -13,10 +15,12 @@ def embedder(local_embedding: bool | None = False) -> BaseEmbedder: try: ai_configuration = UIOpenAIConfiguration().get_configuration() + secrets_handler = SecretsHandler() if local_embedding: return LocalEmbedder( db_name=config.cache_name, max_tokens=ai_configuration.max_tokens, + model=secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY) or None, ) return OpenAIEmbedder( configuration=ai_configuration, diff --git a/app/workflows/query_text_data/functions.py b/app/workflows/query_text_data/functions.py index 35411a97..ad0e80eb 100644 --- a/app/workflows/query_text_data/functions.py +++ b/app/workflows/query_text_data/functions.py @@ -3,7 +3,9 @@ # import streamlit as st +from app.util.constants import LOCAL_EMBEDDING_MODEL_KEY from app.util.openai_wrapper import UIOpenAIConfiguration +from app.util.secrets_handler import SecretsHandler from toolkit.AI.base_embedder import BaseEmbedder from toolkit.AI.local_embedder import LocalEmbedder from toolkit.AI.openai_embedder import OpenAIEmbedder @@ -13,10 +15,12 @@ def embedder(local_embedding: bool | None = False) -> BaseEmbedder: try: ai_configuration = UIOpenAIConfiguration().get_configuration() + secrets_handler = SecretsHandler() if local_embedding: return LocalEmbedder( db_name=config.cache_name, max_tokens=ai_configuration.max_tokens, + model=secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY) or None, ) return OpenAIEmbedder( configuration=ai_configuration, diff --git a/app/workflows/query_text_data/workflow.py b/app/workflows/query_text_data/workflow.py index d213f25c..b386d495 100644 --- a/app/workflows/query_text_data/workflow.py +++ b/app/workflows/query_text_data/workflow.py @@ -237,7 +237,7 @@ async def create(sv: SessionVariables, workflow=None): sv.cid_to_vector.value = await helper_functions.embed_texts( sv.cid_to_explained_text.value, text_embedder, - config.cache_name, + sv_home.save_cache.value, callbacks=[embed_callback], ) chunk_pb.empty() diff --git a/toolkit/AI/local_embedder.py b/toolkit/AI/local_embedder.py index 6f94fa87..476dce40 100644 --- a/toolkit/AI/local_embedder.py +++ b/toolkit/AI/local_embedder.py @@ -23,9 +23,10 @@ def __init__( db_path=CACHE_PATH, max_tokens=DEFAULT_LLM_MAX_TOKENS, concurrent_coroutines: int | None = DEFAULT_CONCURRENT_COROUTINES, + model: str | None = DEFAULT_LOCAL_EMBEDDING_MODEL, ): super().__init__(db_name, db_path, max_tokens, concurrent_coroutines) - self.local_client = SentenceTransformer(DEFAULT_LOCAL_EMBEDDING_MODEL) + self.local_client = SentenceTransformer(model) def _generate_embedding(self, text: str | list[str]) -> list | Any: return self.local_client.encode(text).tolist() diff --git a/toolkit/AI/openai_configuration.py b/toolkit/AI/openai_configuration.py index be8cd039..b635d61e 100644 --- a/toolkit/AI/openai_configuration.py +++ b/toolkit/AI/openai_configuration.py @@ -5,6 +5,7 @@ from .defaults import ( DEFAULT_AZ_AUTH_TYPE, + DEFAULT_EMBEDDING_MODEL, DEFAULT_LLM_MAX_TOKENS, DEFAULT_LLM_MODEL, DEFAULT_OPENAI_VERSION, @@ -33,6 +34,7 @@ class OpenAIConfiguration: _max_tokens: int | None _api_type: str _az_auth_type: str + _embedding_model: str def __init__( self, @@ -53,6 +55,9 @@ def __init__( self._max_tokens = config.get("max_tokens", DEFAULT_LLM_MAX_TOKENS) self._az_auth_type = config.get("az_auth_type", self._get_az_auth_type()) self._api_type = config.get("api_type", oai_type) + self._embedding_model = config.get( + "embedding_model", self._get_embedding_model() + ) def _get_openai_type(self): return os.environ.get("OPENAI_TYPE", "OpenAI") @@ -66,6 +71,9 @@ def _get_azure_openai_version(self): def _get_chat_model(self): return os.environ.get("OPENAI_API_MODEL", DEFAULT_LLM_MODEL) + def _get_embedding_model(self): + return os.environ.get("OPENAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL) + def _get_azure_api_base(self): return os.environ.get("AZURE_OPENAI_ENDPOINT", "") @@ -104,6 +112,10 @@ def max_tokens(self) -> int | None: """Max tokens property definition.""" return self._max_tokens + @property + def embedding_model(self) -> str | None: + return self._embedding_model + @property def api_type(self) -> str | None: """Type of the AI connection.""" diff --git a/toolkit/AI/openai_embedder.py b/toolkit/AI/openai_embedder.py index 5d43a293..80e290ef 100644 --- a/toolkit/AI/openai_embedder.py +++ b/toolkit/AI/openai_embedder.py @@ -25,7 +25,11 @@ def __init__( self.openai_client = OpenAIClient(configuration) def _generate_embedding(self, text: str) -> list[float]: - return self.openai_client.generate_embedding(text) + return self.openai_client.generate_embedding( + text, model=self.configuration.embedding_model + ) async def _generate_embedding_async(self, text: str) -> list[float]: - return await self.openai_client.generate_embedding_async(text) \ No newline at end of file + return await self.openai_client.generate_embedding_async( + text, model=self.configuration.embedding_model + ) \ No newline at end of file diff --git a/toolkit/detect_entity_networks/index_and_infer.py b/toolkit/detect_entity_networks/index_and_infer.py index bd673e4f..00c06846 100644 --- a/toolkit/detect_entity_networks/index_and_infer.py +++ b/toolkit/detect_entity_networks/index_and_infer.py @@ -37,7 +37,7 @@ async def index_nodes( if data["type"] in indexed_node_types ] text_types.sort() - texts = [t[0] for t in text_types] + texts = [text_type[0] for text_type in text_types] data = [ {