Skip to content

Commit

Permalink
add per workflow local embedding mode (#55)
Browse files Browse the repository at this point in the history
* add local embeddings in workflow for MER

* Add inline local embedding for QTD
  • Loading branch information
dayesouza authored Sep 25, 2024
1 parent f063e2d commit fe3c8b3
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 56 deletions.
12 changes: 1 addition & 11 deletions app/components/app_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@ def config(self):
value=self.sv.save_cache.value,
help="Enable caching of embeddings to speed up the application.",
)
local_embed = st.sidebar.toggle(
"Use local embeddings",
value=self.sv.local_embeddings.value,
help="Don't call OpenAI to embed, use a local library.",
)

if cache != self.sv.save_cache.value:
self.sv.save_cache.value = cache
st.rerun()

if local_embed != self.sv.local_embeddings.value:
self.sv.local_embeddings.value = local_embed
st.rerun()
st.rerun()
1 change: 0 additions & 1 deletion app/util/session_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,3 @@ def __init__(self, prefix=""):
self.embedding_model = sv.SessionVariable("text-embedding-ada-002")
self.max_embedding_size = sv.SessionVariable(500)
self.save_cache = sv.SessionVariable(True)
self.local_embeddings = sv.SessionVariable(False)
3 changes: 0 additions & 3 deletions app/workflows/detect_entity_networks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
#
import streamlit as st
from util.openai_wrapper import UIOpenAIConfiguration
from util.session_variables import SessionVariables

import toolkit.detect_entity_networks.config as config
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder

sv_home = SessionVariables("home")


def embedder(local_embedding: bool | None = True) -> BaseEmbedder:
try:
Expand Down
7 changes: 2 additions & 5 deletions app/workflows/match_entity_records/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,16 @@
import streamlit as st

from app.util.openai_wrapper import UIOpenAIConfiguration
from app.util.session_variables import SessionVariables
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder
from toolkit.match_entity_records import config

sv_home = SessionVariables("home")


def embedder() -> BaseEmbedder:
def embedder(local_embedding: bool | None = False) -> BaseEmbedder:
try:
ai_configuration = UIOpenAIConfiguration().get_configuration()
if sv_home.local_embeddings.value:
if local_embedding:
return LocalEmbedder(
db_name=config.cache_name,
max_tokens=ai_configuration.max_tokens,
Expand Down
1 change: 1 addition & 0 deletions app/workflows/match_entity_records/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def create_session(self, prefix):
self.matching_report_validation_messages = SessionVariable("", prefix)
self.matching_system_prompt = SessionVariable(prompts.list_prompts, prefix)
self.matching_upload_key = SessionVariable(random.randint(1, 100), prefix)
self.matching_local_embedding_enabled = SessionVariable(False, prefix)

def reset_workflow(self):
for key in st.session_state:
Expand Down
7 changes: 6 additions & 1 deletion app/workflows/match_entity_records/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ def att_ui(i, any_empty, changed, attsaa):
st.rerun()
attributes_list = build_attribute_list(attsa)

local_embedding = st.toggle(
"Use local embeddings",
sv.matching_local_embedding_enabled.value,
help="Use local embeddings to index nodes. If disabled, the model will use OpenAI embeddings.",
)
st.markdown("##### Configure similarity thresholds")
b1, b2 = st.columns([1, 1])
with b1:
Expand Down Expand Up @@ -277,7 +282,7 @@ def on_embedding_batch_change(current, total):
callback = ProgressBatchCallback()
callback.on_batch_change = on_embedding_batch_change

functions_embedder = functions.embedder()
functions_embedder = functions.embedder(local_embedding)
data_embeddings = await functions_embedder.embed_store_many(
all_sentences_data, [callback], sv_home.save_cache.value
)
Expand Down
27 changes: 27 additions & 0 deletions app/workflows/query_text_data/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project.
#
import streamlit as st

from app.util.openai_wrapper import UIOpenAIConfiguration
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder
from toolkit.query_text_data import config


def embedder(local_embedding: bool | None = False) -> BaseEmbedder:
    """Create the embedder for the Query Text Data workflow.

    Args:
        local_embedding: When truthy, a ``LocalEmbedder`` is returned;
            otherwise (``False`` or ``None``) an ``OpenAIEmbedder`` is
            returned.

    Returns:
        An embedder whose ``db_name`` is ``config.cache_name``.

    Any exception raised while reading the configuration or constructing
    the embedder is shown with ``st.error`` and followed by ``st.stop()``.
    """
    try:
        settings = UIOpenAIConfiguration().get_configuration()

        # Guard clause: the remote embedder is the default path.
        if not local_embedding:
            return OpenAIEmbedder(configuration=settings, db_name=config.cache_name)

        return LocalEmbedder(
            db_name=config.cache_name,
            max_tokens=settings.max_tokens,
        )
    except Exception as e:
        # UI boundary: report the failure to the user rather than raising.
        st.error(f"Error creating connection: {e}")
        st.stop()
2 changes: 2 additions & 0 deletions app/workflows/query_text_data/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def create_session(self, prefix):
self.chunk_progress = SessionVariable("", prefix)
self.answer_progress = SessionVariable("", prefix)

self.answer_local_embedding_enabled = SessionVariable(False, prefix)

def reset_workflow(self):
for key in st.session_state:
if key.startswith(self.prefix):
Expand Down
46 changes: 15 additions & 31 deletions app/workflows/query_text_data/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,25 @@
from seaborn import color_palette
from streamlit_agraph import Config, Edge, Node, agraph

import app.util.example_outputs_ui as example_outputs_ui
import app.workflows.query_text_data.functions as functions
import toolkit.query_text_data.answer_builder as answer_builder
import toolkit.query_text_data.graph_builder as graph_builder
import toolkit.query_text_data.helper_functions as helper_functions
import toolkit.query_text_data.input_processor as input_processor
import toolkit.query_text_data.prompts as prompts
import toolkit.query_text_data.answer_builder as answer_builder
import toolkit.query_text_data.relevance_assessor as relevance_assessor
import toolkit.query_text_data.graph_builder as graph_builder
import app.util.example_outputs_ui as example_outputs_ui
from toolkit.helpers.progress_batch_callback import ProgressBatchCallback
from app.util import ui_components
from app.util.download_pdf import add_download_pdf
from app.util.openai_wrapper import UIOpenAIConfiguration
from app.util.session_variables import SessionVariables
from app.workflows.query_text_data import config
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.defaults import CHUNK_SIZE
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder
from toolkit.graph.graph_fusion_encoder_embedding import (
create_concept_to_community_hierarchy,
generate_graph_fusion_encoder_embedding,
create_concept_to_community_hierarchy
)
from toolkit.helpers.progress_batch_callback import ProgressBatchCallback
from toolkit.query_text_data import config
from toolkit.query_text_data.pattern_detector import (
combine_chunk_text_and_explantion,
detect_converging_pairs,
Expand Down Expand Up @@ -60,27 +58,6 @@ def on_change(current, total):
return pb, callback


# Create the embedder for this workflow module.
# NOTE(review): this looks like the removed side of the diff hunk; the call at
# the end of this file's diff uses functions.embedder(local_embedding) instead.
# Returns a LocalEmbedder when sv_home.local_embeddings.value is truthy,
# otherwise an OpenAIEmbedder; both use config.cache_name as db_name.
def embedder() -> BaseEmbedder:
try:
ai_configuration = UIOpenAIConfiguration().get_configuration()
# The choice is driven by a session-level flag, not a per-call argument.
if sv_home.local_embeddings.value:
return LocalEmbedder(
db_name=config.cache_name,
max_tokens=ai_configuration.max_tokens,
)

return OpenAIEmbedder(
configuration=ai_configuration,
db_name=config.cache_name,
)
except Exception as e:
# Any failure is shown via st.error, then st.stop() is called.
st.error(f"Error creating connection: {e}")
st.stop()


text_embedder = embedder()


def get_concept_graph(
placeholder, G, concept_to_community, community_to_concepts, width, height, key
):
Expand Down Expand Up @@ -172,6 +149,11 @@ async def create(sv: SessionVariables, workflow=None):
# )
# window_period = input_processor.PeriodOption[window_size]
window_period = input_processor.PeriodOption.NONE
local_embedding = st.toggle(
"Use local embeddings",
sv.answer_local_embedding_enabled.value,
help="Use local embeddings to index nodes. If disabled, the model will use OpenAI embeddings.",
)
if files is not None and st.button("Process files"):
file_pb, file_callback = create_progress_callback(
"Loaded {} of {} files..."
Expand Down Expand Up @@ -250,6 +232,8 @@ async def create(sv: SessionVariables, workflow=None):
embed_pb, embed_callback = create_progress_callback(
"Embedded {} of {} text chunks..."
)
text_embedder = functions.embedder(local_embedding)

sv.cid_to_vector.value = await helper_functions.embed_texts(
sv.cid_to_explained_text.value,
text_embedder,
Expand Down Expand Up @@ -454,7 +438,7 @@ def on_answer(message):
concept_to_community=sv.concept_to_community.value,
previous_cid=sv.previous_cid.value,
next_cid=sv.next_cid.value,
embedder=embedder(),
embedder=functions.embedder(local_embedding),
embedding_cache=sv_home.save_cache.value,
select_logit_bias=5,
adjacent_search_steps=sv.adjacent_chunk_steps.value,
Expand Down
7 changes: 5 additions & 2 deletions toolkit/AI/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@

from toolkit.AI.base_batch_async import BaseBatchAsync
from toolkit.AI.client import OpenAIClient
from toolkit.AI.defaults import DEFAULT_CONCURRENT_COROUTINES
from toolkit.helpers.decorators import retry_with_backoff
from toolkit.helpers.progress_batch_callback import ProgressBatchCallback


class BaseChat(BaseBatchAsync, OpenAIClient):
def __init__(self, configuration=None, concurrent_coroutines=20) -> None:
OpenAIClient.__init__(self, configuration, concurrent_coroutines)
def __init__(
self, configuration=None, concurrent_coroutines=DEFAULT_CONCURRENT_COROUTINES
) -> None:
OpenAIClient.__init__(self, configuration)
self.semaphore = asyncio.Semaphore(concurrent_coroutines)

@retry_with_backoff()
Expand Down
8 changes: 6 additions & 2 deletions toolkit/AI/local_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from sentence_transformers import SentenceTransformer

from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.defaults import DEFAULT_LLM_MAX_TOKENS, DEFAULT_LOCAL_EMBEDDING_MODEL
from toolkit.AI.defaults import (
DEFAULT_CONCURRENT_COROUTINES,
DEFAULT_LLM_MAX_TOKENS,
DEFAULT_LOCAL_EMBEDDING_MODEL,
)
from toolkit.helpers.constants import CACHE_PATH


Expand All @@ -18,7 +22,7 @@ def __init__(
db_name: str = "embeddings",
db_path=CACHE_PATH,
max_tokens=DEFAULT_LLM_MAX_TOKENS,
concurrent_coroutines: int | None = None,
concurrent_coroutines: int | None = DEFAULT_CONCURRENT_COROUTINES,
):
super().__init__(db_name, db_path, max_tokens, concurrent_coroutines)
self.local_client = SentenceTransformer(DEFAULT_LOCAL_EMBEDDING_MODEL)
Expand Down
File renamed without changes.

0 comments on commit fe3c8b3

Please sign in to comment.