Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Darren Edge committed Nov 11, 2024
2 parents 1bb56f2 + 9ae6a32 commit 1d3ec9e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
7 changes: 6 additions & 1 deletion app/util/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from app.util.openai_wrapper import UIOpenAIConfiguration
from app.util.secrets_handler import SecretsHandler
from intelligence_toolkit.AI.base_embedder import BaseEmbedder
from intelligence_toolkit.AI.defaults import DEFAULT_CONCURRENT_COROUTINES
from intelligence_toolkit.AI.local_embedder import LocalEmbedder
from intelligence_toolkit.AI.openai_embedder import OpenAIEmbedder
from intelligence_toolkit.query_text_data import config


def create_embedder(local_embedding: bool | None = False) -> BaseEmbedder:
def create_embedder(
local_embedding: bool | None = False,
concurrent_coroutines: int = DEFAULT_CONCURRENT_COROUTINES,
) -> BaseEmbedder:
try:
ai_configuration = UIOpenAIConfiguration().get_configuration()
secrets_handler = SecretsHandler()
Expand All @@ -20,6 +24,7 @@ def create_embedder(local_embedding: bool | None = False) -> BaseEmbedder:
return OpenAIEmbedder(
configuration=ai_configuration,
db_name=config.cache_name,
concurrent_coroutines=concurrent_coroutines,
)
except Exception as e:
print(f"Error creating connection: {e}")
3 changes: 2 additions & 1 deletion app/workflows/detect_case_patterns/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ def create(sv: ap_variables.SessionVariables, workflow):
if sv.detect_case_patterns_selected_pattern.value != "":
count_ct = dcp.create_time_series_chart(
sv.detect_case_patterns_selected_pattern.value,
sv.detect_case_patterns_selected_pattern_period.value
sv.detect_case_patterns_selected_pattern_period.value,
True,
)
st.altair_chart(count_ct, use_container_width=True)
report_placeholder = st.empty()
Expand Down
2 changes: 1 addition & 1 deletion app/workflows/query_text_data/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def create(sv: SessionVariables, workflow=None):
value=sv.answer_local_embedding_enabled.value,
help="Use local embeddings to index nodes. If disabled, the model will use OpenAI embeddings.",
)
qtd.set_embedder(embedder.create_embedder(local_embedding))
qtd.set_embedder(embedder.create_embedder(local_embedding, 20))

if st.button("Process files") and (
files is not None or file_chunks is not None
Expand Down
17 changes: 17 additions & 0 deletions intelligence_toolkit/detect_case_patterns/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,28 @@ def create_time_series_chart(
self,
selected_pattern,
selected_pattern_period,
resize_title=False,
):
selected_pattern_df = self.time_series_df[
(self.time_series_df["pattern"] == selected_pattern)
]
title = "Pattern: " + selected_pattern + " (" + selected_pattern_period + ")"
if resize_title and len(title) > 100:
# Find the last occurrence of '&' within the first half of the title
split_index = title.rfind("&", 0, len(title) // 2)

# If '&' is found, break the title there; otherwise, split by length
if split_index != -1:
title = [
title[: split_index + 1].strip(),
title[split_index + 1 :].strip(),
]
else:
title = [
title[: len(title) // 2].strip(),
title[len(title) // 2 :].strip(),
]

count_ct = (
alt.Chart(selected_pattern_df)
.mark_line()
Expand Down

0 comments on commit 1d3ec9e

Please sign in to comment.