From fe3c8b3bd4726995164a933256c27d85c7a1e8ee Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Wed, 25 Sep 2024 17:51:58 -0300 Subject: [PATCH] add per workflow local embedding mode (#55) * add local embeddings in workflow for MER * Add inline local embedding for QTD --- app/components/app_mode.py | 12 +---- app/util/session_variables.py | 1 - .../detect_entity_networks/functions.py | 3 -- .../match_entity_records/functions.py | 7 +-- .../match_entity_records/variables.py | 1 + .../match_entity_records/workflow.py | 7 ++- app/workflows/query_text_data/functions.py | 27 +++++++++++ app/workflows/query_text_data/variables.py | 2 + app/workflows/query_text_data/workflow.py | 46 ++++++------------- toolkit/AI/base_chat.py | 7 ++- toolkit/AI/local_embedder.py | 8 +++- .../query_text_data/config.py | 0 12 files changed, 65 insertions(+), 56 deletions(-) create mode 100644 app/workflows/query_text_data/functions.py rename {app/workflows => toolkit}/query_text_data/config.py (100%) diff --git a/app/components/app_mode.py b/app/components/app_mode.py index 825425dc..c73b3f79 100644 --- a/app/components/app_mode.py +++ b/app/components/app_mode.py @@ -22,16 +22,6 @@ def config(self): value=self.sv.save_cache.value, help="Enable caching of embeddings to speed up the application.", ) - local_embed = st.sidebar.toggle( - "Use local embeddings", - value=self.sv.local_embeddings.value, - help="Don't call OpenAI to embed, use a local library.", - ) - if cache != self.sv.save_cache.value: self.sv.save_cache.value = cache - st.rerun() - - if local_embed != self.sv.local_embeddings.value: - self.sv.local_embeddings.value = local_embed - st.rerun() + st.rerun() \ No newline at end of file diff --git a/app/util/session_variables.py b/app/util/session_variables.py index 86153080..c3410312 100644 --- a/app/util/session_variables.py +++ b/app/util/session_variables.py @@ -14,4 +14,3 @@ def __init__(self, prefix=""): self.embedding_model = sv.SessionVariable("text-embedding-ada-002") self.max_embedding_size = sv.SessionVariable(500) self.save_cache = sv.SessionVariable(True) - self.local_embeddings = sv.SessionVariable(False) diff --git a/app/workflows/detect_entity_networks/functions.py b/app/workflows/detect_entity_networks/functions.py index 91793328..9bdc1f49 100644 --- a/app/workflows/detect_entity_networks/functions.py +++ b/app/workflows/detect_entity_networks/functions.py @@ -3,15 +3,12 @@ # import streamlit as st from util.openai_wrapper import UIOpenAIConfiguration -from util.session_variables import SessionVariables import toolkit.detect_entity_networks.config as config from toolkit.AI.base_embedder import BaseEmbedder from toolkit.AI.local_embedder import LocalEmbedder from toolkit.AI.openai_embedder import OpenAIEmbedder -sv_home = SessionVariables("home") - def embedder(local_embedding: bool | None = True) -> BaseEmbedder: try: diff --git a/app/workflows/match_entity_records/functions.py b/app/workflows/match_entity_records/functions.py index 9100c16f..d2fc69cd 100644 --- a/app/workflows/match_entity_records/functions.py +++ b/app/workflows/match_entity_records/functions.py @@ -4,19 +4,16 @@ import streamlit as st from app.util.openai_wrapper import UIOpenAIConfiguration -from app.util.session_variables import SessionVariables from toolkit.AI.base_embedder import BaseEmbedder from toolkit.AI.local_embedder import LocalEmbedder from toolkit.AI.openai_embedder import OpenAIEmbedder from toolkit.match_entity_records import config -sv_home = SessionVariables("home") - -def embedder() -> BaseEmbedder: +def embedder(local_embedding: bool | None = False) -> BaseEmbedder: try: ai_configuration = UIOpenAIConfiguration().get_configuration() - if sv_home.local_embeddings.value: + if local_embedding: return LocalEmbedder( db_name=config.cache_name, max_tokens=ai_configuration.max_tokens, diff --git a/app/workflows/match_entity_records/variables.py b/app/workflows/match_entity_records/variables.py index 63dc51e9..50221752 100644 --- a/app/workflows/match_entity_records/variables.py +++ b/app/workflows/match_entity_records/variables.py @@ -38,6 +38,7 @@ def create_session(self, prefix): self.matching_report_validation_messages = SessionVariable("", prefix) self.matching_system_prompt = SessionVariable(prompts.list_prompts, prefix) self.matching_upload_key = SessionVariable(random.randint(1, 100), prefix) + self.matching_local_embedding_enabled = SessionVariable(False, prefix) def reset_workflow(self): for key in st.session_state: diff --git a/app/workflows/match_entity_records/workflow.py b/app/workflows/match_entity_records/workflow.py index 684ac4f4..9328da59 100644 --- a/app/workflows/match_entity_records/workflow.py +++ b/app/workflows/match_entity_records/workflow.py @@ -204,6 +204,11 @@ def att_ui(i, any_empty, changed, attsaa): st.rerun() attributes_list = build_attribute_list(attsa) + local_embedding = st.toggle( + "Use local embeddings", + sv.matching_local_embedding_enabled.value, + help="Use local embeddings to index nodes. If disabled, the model will use OpenAI embeddings.", + ) st.markdown("##### Configure similarity thresholds") b1, b2 = st.columns([1, 1]) with b1: @@ -277,7 +282,7 @@ def on_embedding_batch_change(current, total): callback = ProgressBatchCallback() callback.on_batch_change = on_embedding_batch_change - functions_embedder = functions.embedder() + functions_embedder = functions.embedder(local_embedding) data_embeddings = await functions_embedder.embed_store_many( all_sentences_data, [callback], sv_home.save_cache.value ) diff --git a/app/workflows/query_text_data/functions.py b/app/workflows/query_text_data/functions.py new file mode 100644 index 00000000..35411a97 --- /dev/null +++ b/app/workflows/query_text_data/functions.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project. +# +import streamlit as st + +from app.util.openai_wrapper import UIOpenAIConfiguration +from toolkit.AI.base_embedder import BaseEmbedder +from toolkit.AI.local_embedder import LocalEmbedder +from toolkit.AI.openai_embedder import OpenAIEmbedder +from toolkit.query_text_data import config + + +def embedder(local_embedding: bool | None = False) -> BaseEmbedder: + try: + ai_configuration = UIOpenAIConfiguration().get_configuration() + if local_embedding: + return LocalEmbedder( + db_name=config.cache_name, + max_tokens=ai_configuration.max_tokens, + ) + return OpenAIEmbedder( + configuration=ai_configuration, + db_name=config.cache_name, + ) + except Exception as e: + st.error(f"Error creating connection: {e}") + st.stop() diff --git a/app/workflows/query_text_data/variables.py b/app/workflows/query_text_data/variables.py index c92622bb..57bf0e1b 100644 --- a/app/workflows/query_text_data/variables.py +++ b/app/workflows/query_text_data/variables.py @@ -56,6 +56,8 @@ def create_session(self, prefix): self.chunk_progress = SessionVariable("", prefix) self.answer_progress = SessionVariable("", prefix) + self.answer_local_embedding_enabled = SessionVariable(False, prefix) + def reset_workflow(self): for key in st.session_state: if key.startswith(self.prefix): diff --git a/app/workflows/query_text_data/workflow.py b/app/workflows/query_text_data/workflow.py index 696e3245..d213f25c 100644 --- a/app/workflows/query_text_data/workflow.py +++ b/app/workflows/query_text_data/workflow.py @@ -8,27 +8,25 @@ from seaborn import color_palette from streamlit_agraph import Config, Edge, Node, agraph +import app.util.example_outputs_ui as example_outputs_ui +import app.workflows.query_text_data.functions as functions +import toolkit.query_text_data.answer_builder as answer_builder +import toolkit.query_text_data.graph_builder as graph_builder import toolkit.query_text_data.helper_functions as helper_functions import toolkit.query_text_data.input_processor as input_processor import toolkit.query_text_data.prompts as prompts -import toolkit.query_text_data.answer_builder as answer_builder import toolkit.query_text_data.relevance_assessor as relevance_assessor -import toolkit.query_text_data.graph_builder as graph_builder -import app.util.example_outputs_ui as example_outputs_ui -from toolkit.helpers.progress_batch_callback import ProgressBatchCallback from app.util import ui_components from app.util.download_pdf import add_download_pdf from app.util.openai_wrapper import UIOpenAIConfiguration from app.util.session_variables import SessionVariables -from app.workflows.query_text_data import config -from toolkit.AI.base_embedder import BaseEmbedder from toolkit.AI.defaults import CHUNK_SIZE -from toolkit.AI.local_embedder import LocalEmbedder -from toolkit.AI.openai_embedder import OpenAIEmbedder from toolkit.graph.graph_fusion_encoder_embedding import ( + create_concept_to_community_hierarchy, generate_graph_fusion_encoder_embedding, - create_concept_to_community_hierarchy ) +from toolkit.helpers.progress_batch_callback import ProgressBatchCallback +from toolkit.query_text_data import config from toolkit.query_text_data.pattern_detector import ( combine_chunk_text_and_explantion, detect_converging_pairs, @@ -60,27 +58,6 @@ def on_change(current, total): return pb, callback -def embedder() -> BaseEmbedder: - try: - ai_configuration = UIOpenAIConfiguration().get_configuration() - if sv_home.local_embeddings.value: - return LocalEmbedder( - db_name=config.cache_name, - max_tokens=ai_configuration.max_tokens, - ) - - return OpenAIEmbedder( - configuration=ai_configuration, - db_name=config.cache_name, - ) - except Exception as e: - st.error(f"Error creating connection: {e}") - st.stop() - - -text_embedder = embedder() - - def get_concept_graph( placeholder, G, concept_to_community, community_to_concepts, width, height, key ): @@ -172,6 +149,11 @@ async def create(sv: SessionVariables, workflow=None): # ) # window_period = input_processor.PeriodOption[window_size] window_period = input_processor.PeriodOption.NONE + local_embedding = st.toggle( + "Use local embeddings", + sv.answer_local_embedding_enabled.value, + help="Use local embeddings to index nodes. If disabled, the model will use OpenAI embeddings.", + ) if files is not None and st.button("Process files"): file_pb, file_callback = create_progress_callback( "Loaded {} of {} files..." @@ -250,6 +232,8 @@ async def create(sv: SessionVariables, workflow=None): embed_pb, embed_callback = create_progress_callback( "Embedded {} of {} text chunks..." ) + text_embedder = functions.embedder(local_embedding) + sv.cid_to_vector.value = await helper_functions.embed_texts( sv.cid_to_explained_text.value, text_embedder, @@ -454,7 +438,7 @@ def on_answer(message): concept_to_community=sv.concept_to_community.value, previous_cid=sv.previous_cid.value, next_cid=sv.next_cid.value, - embedder=embedder(), + embedder=functions.embedder(local_embedding), embedding_cache=sv_home.save_cache.value, select_logit_bias=5, adjacent_search_steps=sv.adjacent_chunk_steps.value, diff --git a/toolkit/AI/base_chat.py b/toolkit/AI/base_chat.py index 85272f69..57c2d32c 100644 --- a/toolkit/AI/base_chat.py +++ b/toolkit/AI/base_chat.py @@ -9,13 +9,16 @@ from toolkit.AI.base_batch_async import BaseBatchAsync from toolkit.AI.client import OpenAIClient +from toolkit.AI.defaults import DEFAULT_CONCURRENT_COROUTINES from toolkit.helpers.decorators import retry_with_backoff from toolkit.helpers.progress_batch_callback import ProgressBatchCallback class BaseChat(BaseBatchAsync, OpenAIClient): - def __init__(self, configuration=None, concurrent_coroutines=20) -> None: - OpenAIClient.__init__(self, configuration, concurrent_coroutines) + def __init__( + self, configuration=None, concurrent_coroutines=DEFAULT_CONCURRENT_COROUTINES + ) -> None: + OpenAIClient.__init__(self, configuration) self.semaphore = asyncio.Semaphore(concurrent_coroutines) @retry_with_backoff() diff --git a/toolkit/AI/local_embedder.py b/toolkit/AI/local_embedder.py index 13d8b530..6f94fa87 100644 --- a/toolkit/AI/local_embedder.py +++ b/toolkit/AI/local_embedder.py @@ -8,7 +8,11 @@ from sentence_transformers import SentenceTransformer from toolkit.AI.base_embedder import BaseEmbedder -from toolkit.AI.defaults import DEFAULT_LLM_MAX_TOKENS, DEFAULT_LOCAL_EMBEDDING_MODEL +from toolkit.AI.defaults import ( + DEFAULT_CONCURRENT_COROUTINES, + DEFAULT_LLM_MAX_TOKENS, + DEFAULT_LOCAL_EMBEDDING_MODEL, +) from toolkit.helpers.constants import CACHE_PATH @@ -18,7 +22,7 @@ def __init__( db_name: str = "embeddings", db_path=CACHE_PATH, max_tokens=DEFAULT_LLM_MAX_TOKENS, - concurrent_coroutines: int | None = None, + concurrent_coroutines: int | None = DEFAULT_CONCURRENT_COROUTINES, ): super().__init__(db_name, db_path, max_tokens, concurrent_coroutines) self.local_client = SentenceTransformer(DEFAULT_LOCAL_EMBEDDING_MODEL) diff --git a/app/workflows/query_text_data/config.py b/toolkit/query_text_data/config.py similarity index 100% rename from app/workflows/query_text_data/config.py rename to toolkit/query_text_data/config.py