Skip to content

Commit

Permalink
Add option to change the embedding model for both local and OpenAI embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
dayesouza committed Sep 27, 2024
1 parent 034baea commit 28c5f7e
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 13 deletions.
53 changes: 52 additions & 1 deletion app/pages/Settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,44 @@

import streamlit as st
from components.app_loader import load_multipage_app
from util.constants import MAX_SIZE_EMBEDDINGS_KEY
from util.constants import LOCAL_EMBEDDING_MODEL_KEY, MAX_SIZE_EMBEDDINGS_KEY
from util.enums import Mode
from util.openai_wrapper import (
UIOpenAIConfiguration,
key,
openai_azure_auth_type,
openai_embedding_model,
openai_endpoint_key,
openai_model_key,
openai_type_key,
openai_version_key,
)
from util.secrets_handler import SecretsHandler

from toolkit.AI.defaults import DEFAULT_LOCAL_EMBEDDING_MODEL
from toolkit.AI.vector_store import VectorStore
from toolkit.helpers.constants import CACHE_PATH

# OpenAI embedding model names offered in the "OpenAI embedding model"
# selectbox on this page. The currently configured model is located with
# list.index(), so it must be one of these entries or a ValueError is raised.
openai_embedding_models = [
"text-embedding-3-large",
"text-embedding-3-small",
"text-embedding-ada-002",
]

# Local embedding model names offered in the "Local embedding model"
# selectbox on this page. The stored secret (or DEFAULT_LOCAL_EMBEDDING_MODEL
# when none is stored) is located with list.index(), so it must be one of
# these entries or a ValueError is raised.
local_embedding_models = [
"all-mpnet-base-v2",
"multi-qa-mpnet-base-dot-v1",
"all-distilroberta-v1",
"all-MiniLM-L12-v2",
"multi-qa-distilbert-cos-v1",
"paraphrase-multilingual-mpnet-base-v2",
"paraphrase-albert-small-v2",
"paraphrase-multilingual-MiniLM-L12-v2",
"paraphrase-MiniLM-L3-v2",
"distiluse-base-multilingual-cased-v1",
"distiluse-base-multilingual-cased-v2",
]


def on_change(handler, key=None, value=None):
def change():
Expand Down Expand Up @@ -148,6 +170,35 @@ def main():

st.subheader("Embeddings")
max_size = int(secrets_handler.get_secret(MAX_SIZE_EMBEDDINGS_KEY) or 0)
local_embedding_value = (
secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY)
or DEFAULT_LOCAL_EMBEDDING_MODEL
)

st.markdown("Select the embedding model you want to use.")
e1, e2 = st.columns(2)
with e1:
openai_model = st.selectbox(
"OpenAI embedding model",
openai_embedding_models,
index=openai_embedding_models.index(openai_config.embedding_model),
key="openai_embedding_model",
help="Select the embedding model you want to use.",
)
if openai_model != openai_config.embedding_model:
on_change(secrets_handler, openai_embedding_model, openai_model)()
st.rerun()
with e2:
local_model = st.selectbox(
"Local embedding model",
local_embedding_models,
index=local_embedding_models.index(local_embedding_value),
key="local_embedding_model",
help="Select the embedding model you want to use.",
)
if local_model != local_embedding_value:
on_change(secrets_handler, LOCAL_EMBEDDING_MODEL_KEY, local_model)()
st.rerun()

c1, c2 = st.columns(2)
with c1:
Expand Down
1 change: 1 addition & 0 deletions app/util/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
PDF_ENCODING = "UTF-8"
PDF_WKHTMLTOPDF_PATH = "C:\\Program Files\\wkhtmltopdf\\bin\\wkhtmltopdf.exe"
MAX_SIZE_EMBEDDINGS_KEY = "max_embedding"
LOCAL_EMBEDDING_MODEL_KEY = "local_embedding_model"
3 changes: 3 additions & 0 deletions app/util/openai_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
openai_endpoint_key = "openai_endpoint"
openai_model_key = "openai_model"
openai_azure_auth_type = "openai_azure_auth_type"
openai_embedding_model = "openai_embedding_model"


class UIOpenAIConfiguration:
Expand All @@ -26,6 +27,7 @@ def get_configuration(self) -> OpenAIConfiguration:
secret_key = self._secrets.get_secret(key) or None
model = self._secrets.get_secret(openai_model_key) or None
az_auth_type = self._secrets.get_secret(openai_azure_auth_type) or None
embedding_model = self._secrets.get_secret(openai_embedding_model) or None

config = {
"api_type": api_type,
Expand All @@ -34,6 +36,7 @@ def get_configuration(self) -> OpenAIConfiguration:
"api_key": secret_key,
"model": model,
"az_auth_type": az_auth_type,
"embedding_model": embedding_model,
}
values = {k: v for k, v in config.items() if v is not None}
return OpenAIConfiguration(values)
2 changes: 1 addition & 1 deletion app/util/secrets_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def read_values_from_file(self):

def get_secret(self, value):
values = self.read_values_from_file()
for key in self.read_values_from_file():
for key in values:
if key == value:
return values[key]
return ""
Expand Down
4 changes: 4 additions & 0 deletions app/workflows/detect_entity_networks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from util.openai_wrapper import UIOpenAIConfiguration

import toolkit.detect_entity_networks.config as config
from app.util.constants import LOCAL_EMBEDDING_MODEL_KEY
from app.util.secrets_handler import SecretsHandler
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder
Expand All @@ -13,11 +15,13 @@
def embedder(local_embedding: bool | None = True) -> BaseEmbedder:
try:
ai_configuration = UIOpenAIConfiguration().get_configuration()
secrets_handler = SecretsHandler()
if local_embedding:
return LocalEmbedder(
db_name=config.cache_name,
max_tokens=ai_configuration.max_tokens,
concurrent_coroutines=80,
model=secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY) or None,
)
return OpenAIEmbedder(
configuration=ai_configuration,
Expand Down
13 changes: 7 additions & 6 deletions app/workflows/detect_entity_networks/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,13 @@ async def create(sv: rn_variables.SessionVariables, workflow=None):
help="Select the node types to embed into a multi-dimensional semantic space for fuzzy matching.",
)

total_embeddings = sum(
1
for _, data in sv.network_overall_graph.value.nodes(data=True)
if data["type"] in network_indexed_node_types
)
st.caption(f"Total of {total_embeddings} nodes to index")
if sv.network_overall_graph.value and network_indexed_node_types:
total_embeddings = sum(
1
for _, data in sv.network_overall_graph.value.nodes(data=True)
if data["type"] in network_indexed_node_types
)
st.caption(f"Total of {total_embeddings} nodes to index")
local_embedding = st.toggle(
"Use local embeddings",
sv.network_local_embedding_enabled.value,
Expand Down
4 changes: 4 additions & 0 deletions app/workflows/match_entity_records/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#
import streamlit as st

from app.util.constants import LOCAL_EMBEDDING_MODEL_KEY
from app.util.openai_wrapper import UIOpenAIConfiguration
from app.util.secrets_handler import SecretsHandler
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder
Expand All @@ -13,10 +15,12 @@
def embedder(local_embedding: bool | None = False) -> BaseEmbedder:
try:
ai_configuration = UIOpenAIConfiguration().get_configuration()
secrets_handler = SecretsHandler()
if local_embedding:
return LocalEmbedder(
db_name=config.cache_name,
max_tokens=ai_configuration.max_tokens,
model=secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY) or None,
)
return OpenAIEmbedder(
configuration=ai_configuration,
Expand Down
4 changes: 4 additions & 0 deletions app/workflows/query_text_data/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#
import streamlit as st

from app.util.constants import LOCAL_EMBEDDING_MODEL_KEY
from app.util.openai_wrapper import UIOpenAIConfiguration
from app.util.secrets_handler import SecretsHandler
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder
Expand All @@ -13,10 +15,12 @@
def embedder(local_embedding: bool | None = False) -> BaseEmbedder:
try:
ai_configuration = UIOpenAIConfiguration().get_configuration()
secrets_handler = SecretsHandler()
if local_embedding:
return LocalEmbedder(
db_name=config.cache_name,
max_tokens=ai_configuration.max_tokens,
model=secrets_handler.get_secret(LOCAL_EMBEDDING_MODEL_KEY) or None,
)
return OpenAIEmbedder(
configuration=ai_configuration,
Expand Down
2 changes: 1 addition & 1 deletion app/workflows/query_text_data/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def create(sv: SessionVariables, workflow=None):
sv.cid_to_vector.value = await helper_functions.embed_texts(
sv.cid_to_explained_text.value,
text_embedder,
config.cache_name,
sv_home.save_cache.value,
callbacks=[embed_callback],
)
chunk_pb.empty()
Expand Down
3 changes: 2 additions & 1 deletion toolkit/AI/local_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def __init__(
db_path=CACHE_PATH,
max_tokens=DEFAULT_LLM_MAX_TOKENS,
concurrent_coroutines: int | None = DEFAULT_CONCURRENT_COROUTINES,
model: str | None = DEFAULT_LOCAL_EMBEDDING_MODEL,
):
super().__init__(db_name, db_path, max_tokens, concurrent_coroutines)
self.local_client = SentenceTransformer(DEFAULT_LOCAL_EMBEDDING_MODEL)
self.local_client = SentenceTransformer(model)

def _generate_embedding(self, text: str | list[str]) -> list | Any:
return self.local_client.encode(text).tolist()
Expand Down
12 changes: 12 additions & 0 deletions toolkit/AI/openai_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .defaults import (
DEFAULT_AZ_AUTH_TYPE,
DEFAULT_EMBEDDING_MODEL,
DEFAULT_LLM_MAX_TOKENS,
DEFAULT_LLM_MODEL,
DEFAULT_OPENAI_VERSION,
Expand Down Expand Up @@ -33,6 +34,7 @@ class OpenAIConfiguration:
_max_tokens: int | None
_api_type: str
_az_auth_type: str
_embedding_model: str

def __init__(
self,
Expand All @@ -53,6 +55,9 @@ def __init__(
self._max_tokens = config.get("max_tokens", DEFAULT_LLM_MAX_TOKENS)
self._az_auth_type = config.get("az_auth_type", self._get_az_auth_type())
self._api_type = config.get("api_type", oai_type)
self._embedding_model = config.get(
"embedding_model", self._get_embedding_model()
)

def _get_openai_type(self):
return os.environ.get("OPENAI_TYPE", "OpenAI")
Expand All @@ -66,6 +71,9 @@ def _get_azure_openai_version(self):
def _get_chat_model(self):
return os.environ.get("OPENAI_API_MODEL", DEFAULT_LLM_MODEL)

def _get_embedding_model(self):
    """Return the OPENAI_EMBEDDING_MODEL environment variable, or DEFAULT_EMBEDDING_MODEL when it is unset."""
    return os.environ.get("OPENAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)

def _get_azure_api_base(self):
return os.environ.get("AZURE_OPENAI_ENDPOINT", "")

Expand Down Expand Up @@ -104,6 +112,10 @@ def max_tokens(self) -> int | None:
"""Max tokens property definition."""
return self._max_tokens

@property
def embedding_model(self) -> str | None:
    """Embedding model name stored on this configuration."""
    return self._embedding_model

@property
def api_type(self) -> str | None:
"""Type of the AI connection."""
Expand Down
8 changes: 6 additions & 2 deletions toolkit/AI/openai_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ def __init__(
self.openai_client = OpenAIClient(configuration)

def _generate_embedding(self, text: str) -> list[float]:
return self.openai_client.generate_embedding(text)
return self.openai_client.generate_embedding(
text, model=self.configuration.embedding_model
)

async def _generate_embedding_async(self, text: str) -> list[float]:
return await self.openai_client.generate_embedding_async(text)
return await self.openai_client.generate_embedding_async(
text, model=self.configuration.embedding_model
)
2 changes: 1 addition & 1 deletion toolkit/detect_entity_networks/index_and_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def index_nodes(
if data["type"] in indexed_node_types
]
text_types.sort()
texts = [t[0] for t in text_types]
texts = [text_type[0] for text_type in text_types]

data = [
{
Expand Down

0 comments on commit 28c5f7e

Please sign in to comment.