Commit

disable save embeddings and bug fixes (#82)

fixes

dayesouza authored Dec 5, 2024
1 parent 29a06fe commit bf8170e

Showing 10 changed files with 154 additions and 123 deletions.
17 changes: 9 additions & 8 deletions app/components/app_mode.py
@@ -17,11 +17,12 @@ def __init__(self, sv=None):
         self.sv = SessionVariables("home")
 
     def config(self):
-        cache = st.sidebar.toggle(
-            "Save embeddings",
-            value=self.sv.save_cache.value,
-            help="Enable caching of embeddings to speed up the application.",
-        )
-        if cache != self.sv.save_cache.value:
-            self.sv.save_cache.value = cache
-            st.rerun()
+        print("save embeddings not available")
+        # cache = st.sidebar.toggle(
+        #     "Save embeddings",
+        #     value=self.sv.save_cache.value,
+        #     help="Enable caching of embeddings to speed up the application.",
+        # )
+        # if cache != self.sv.save_cache.value:
+        #     self.sv.save_cache.value = cache
+        #     st.rerun()
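Note: the replacement body only prints to the server console. For reference, the disabled block followed the common Streamlit toggle-and-rerun pattern; a minimal sketch of that pattern, using plain `st.session_state` in place of the app's `SessionVariables` wrapper:

```python
import streamlit as st

# Illustrative sketch (not the app's code) of the toggle-and-rerun pattern
# used by the now-commented-out block.
if "save_cache" not in st.session_state:
    st.session_state["save_cache"] = False

cache = st.sidebar.toggle(
    "Save embeddings",
    value=st.session_state["save_cache"],
    help="Enable caching of embeddings to speed up the application.",
)
if cache != st.session_state["save_cache"]:
    st.session_state["save_cache"] = cache
    st.rerun()  # re-render immediately so widgets that depend on the flag update
```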
34 changes: 26 additions & 8 deletions app/util/ui_components.py
@@ -54,7 +54,7 @@ def dataframe_with_selections(df, selections, selection_col, label, key, height=
     return selected_rows.drop(label, axis=1)
 
 
-def report_download_ui(report_var, name):
+def report_download_ui(report_var, name) -> None:
     if type(report_var) == str:
         if len(report_var) == 0:
             return
@@ -89,15 +89,18 @@ def report_download_ui(report_var, name):
         )
 
 
-def generative_ai_component(system_prompt_var, variables):
+def generative_ai_component(
+    system_prompt_var, variables
+) -> tuple[bool, list[dict[str, str]], bool]:
     st.markdown("##### Generative AI instructions")
 
     with st.expander(
         "Edit AI system prompt used to generate output report", expanded=True
     ):
+        reset_prompt = st.button("Discard prompt text changes")
         instructions_text = st.text_area(
             "Prompt text", value=system_prompt_var.value["user_prompt"], height=200
         )
-        reset_prompt = st.button("Reset to default")
 
     st.warning(
         "AI outputs may contain errors. Please verify details independently."
@@ -122,21 +125,21 @@ def generative_ai_component(system_prompt_var, variables):
     if ratio <= 100:
         st.info(message)
     else:
-        st.warning(message)
+        st.error(message)
     return generate, messages, reset_prompt
 
 
 def generative_batch_ai_component(
     system_prompt_var, variables, batch_name, batch_val, batch_size
-):
+) -> tuple[bool, list, bool]:
     st.markdown("##### Generative AI instructions")
     with st.expander("Edit AI System Prompt", expanded=True):
+        reset_prompt = st.button("Discard prompt text changes")
         instructions_text = st.text_area(
             "Contents of System Prompt used to generate AI outputs.",
             value=system_prompt_var.value["user_prompt"],
             height=200,
         )
-        reset_prompt = st.button("Reset to default")
 
     st.warning(
         "AI outputs may contain errors. Please verify details independently."
@@ -182,7 +185,7 @@ def single_csv_uploader(
     key,
     show_rows=10000,
    height=250,
-):
+) -> None:
     if f"{workflow}_uploader_index" not in st.session_state:
         st.session_state[f"{workflow}_uploader_index"] = str(random.randint(0, 100))
     file = st.file_uploader(
@@ -799,12 +802,27 @@ def build_validation_ui(
         mime="text/json",
     )
 
-def check_ai_configuration():
+def check_ai_configuration(enforce_structured_output=False):
     ai_configuration = UIOpenAIConfiguration().get_configuration()
     if ai_configuration.api_key == "":
         st.warning("Please set your OpenAI API key in the Settings page.")
     if ai_configuration.model == "":
         st.warning("Please set your OpenAI model in the Settings page.")
 
+    list_enforce_structured_output = [
+        "gpt-4o",
+        "gpt-4o-2024-08-06",
+        "gpt-4o-mini",
+        "gpt-4o-mini-2024-07-18",
+    ]
+    if (
+        enforce_structured_output
+        and ai_configuration.model not in list_enforce_structured_output
+    ):
+        st.warning(
+            "Your current OpenAI model does not support this workflow. Please use the Settings page to use `gpt-4o` or `gpt-4o-mini` as OpenAI Deployment Name."
+        )
+
 
 def format_report_group_options(group_dict, existing_groups) -> str:
     return " & ".join([f"{key}: {group_dict[key]}" for key in existing_groups])
4 changes: 3 additions & 1 deletion app/workflows/compare_case_groups/README.md
@@ -1,6 +1,8 @@
 # Compare Case Groups
 
-The **Compare Case Groups** workflow generates intelligence reports by defining and comparing groups of case records.
+The [`Compare Case Groups`](https://github.com/microsoft/intelligence-toolkit/blob/main/app/workflows/compare_case_groups/README.md) workflow generates intelligence reports by defining and comparing groups of case records.
+
+Select the `View example outputs` tab (in app) or navigate to [example_outputs/compare_case_groups](https://github.com/microsoft/intelligence-toolkit/tree/main/example_outputs/compare_case_groups) (on GitHub) for examples.
 
 ## How it works
2 changes: 1 addition & 1 deletion app/workflows/detect_entity_networks/workflow.py
@@ -316,7 +316,7 @@ def on_inferring_batch_change(
         st.markdown(f"*Number of links inferred*: {inferred_links_count}")
         inferred_df = den.inferred_nodes_df()
         st.dataframe(
-            inferred_df.to_pandas(), hide_index=True, use_container_width=True
+            inferred_df.to_pandas(), hide_index=True, use_container_width=False
         )
     else:
         st.markdown(f"*No inferred links*")
5 changes: 2 additions & 3 deletions app/workflows/extract_record_data/workflow.py
@@ -1,5 +1,5 @@
 import os
-from json import dumps, loads
+from json import dumps
 
 import pandas as pd
 import streamlit as st
@@ -8,7 +8,6 @@
 import app.util.schema_ui as schema_ui
 import app.util.ui_components as ui_components
 import app.workflows.extract_record_data.variables as variables
-import intelligence_toolkit.extract_record_data.data_extractor as data_extractor
 from app.util.download_pdf import add_download_pdf
 from app.util.openai_wrapper import UIOpenAIConfiguration
 
@@ -21,7 +20,7 @@ def get_intro():
 
 
 async def create(sv: variables.SessionVariables, workflow: None):
-    ui_components.check_ai_configuration()
+    ui_components.check_ai_configuration(enforce_structured_output=True)
     erd = sv.workflow_object.value
     erd.set_ai_configuration(ai_configuration)
     intro_tab, schema_tab, generator_tab, mock_tab = st.tabs(['Extract Record Data workflow:', 'Prepare data schema', 'Extract structured records', 'View example outputs'])
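For context: structured outputs (schema-enforced JSON responses) are only available on certain models, which is what the stricter check guards. A hedged sketch of the kind of call involved, using the public `openai` SDK with an illustrative schema rather than the toolkit's actual extractor:

```python
from openai import OpenAI
from pydantic import BaseModel

# Illustrative schema; the toolkit derives its schema from user input instead.
class Record(BaseModel):
    name: str
    age: int

class RecordList(BaseModel):
    records: list[Record]

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
completion = client.beta.chat.completions.parse(
    model="gpt-4o-mini",  # must be a structured-output-capable model
    messages=[{"role": "user", "content": "Extract records: Ana, 34; Bo, 51."}],
    response_format=RecordList,  # the SDK converts the model to a JSON schema
)
records = completion.choices[0].message.parsed  # a RecordList instance
```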
2 changes: 1 addition & 1 deletion app/workflows/generate_mock_data/workflow.py
@@ -21,7 +21,7 @@ def get_intro():
 
 
 async def create(sv: bds_variables.SessionVariables, workflow: None):
-    ui_components.check_ai_configuration()
+    ui_components.check_ai_configuration(enforce_structured_output=True)
     gmd: GenerateMockData = sv.workflow_object.value
     gmd.set_ai_configuration(ai_configuration)
     intro_tab, schema_tab, record_generator_tab, text_generator_tab, mock_tab = st.tabs(['Generate Mock Data workflow:', 'Prepare data schema', 'Generate mock records', 'Generate mock texts', 'View example outputs'])
168 changes: 87 additions & 81 deletions app/workflows/match_entity_records/workflow.py
@@ -82,7 +82,7 @@ async def create(sv: rm_variables.SessionVariable, workflow=None) -> None:
             )
             entity_id_col = st.selectbox(
                 "Entity ID column (optional)",
-                cols,
+                ["", *selected_df.columns],
                 help="The column containing the unique identifier of the entity to be matched. If left blank, a unique ID will be generated for each entity based on the row number.",
             )
             filtered_cols = [
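The help text explains the fallback: a blank selection means IDs are synthesized from row numbers. A sketch of that behavior with a hypothetical helper (not the workflow's code; the app uses Polars, where `with_row_index` is the current name and older releases call it `with_row_count`):

```python
import polars as pl

def with_entity_ids(df: pl.DataFrame, entity_id_col: str) -> pl.DataFrame:
    # Hypothetical helper: blank selection -> synthesize an ID per row.
    if entity_id_col == "":
        return df.with_row_index(name="Entity ID")  # 0, 1, 2, ...
    return df.rename({entity_id_col: "Entity ID"})
```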
@@ -319,95 +319,101 @@ def on_embedding_batch_change(current, total):
             )
 
     with evaluate_tab:
-        b1, b2 = st.columns([2, 3])
-        with b1:
-            batch_size = 100
-            data = sv.matching_matches_df.value.drop(
-                [
-                    "Entity ID",
-                    "Dataset",
-                    "Name similarity",
-                ]
-            ).to_pandas()
-            generate, batch_messages, reset = (
-                ui_components.generative_batch_ai_component(
-                    sv.matching_system_prompt, {}, "data", data, batch_size
-                )
-            )
-            if reset:
-                sv.matching_system_prompt.value["user_prompt"] = prompts.user_prompt
-                st.rerun()
-        with b2:
-            st.markdown("##### AI evaluation of record groups")
-            prefix = "```\nGroup ID,Relatedness,Explanation\n"
-            placeholder = st.empty()
-            gen_placeholder = st.empty()
-
-            if generate:
-                for messages in batch_messages:
-                    callback = ui_components.create_markdown_callback(
-                        placeholder, prefix
-                    )
-                    response = ui_components.generate_text(messages, [callback])
-
-                    if len(response.strip()) > 0:
-                        prefix = prefix + response + "\n"
-                result = prefix.replace("```\n", "").strip()
-                sv.matching_evaluations.value = result
-                lines = result.split("\n")
-
-                if len(lines) > 30:
-                    lines = lines[:30]
-                    result = "\n".join(lines)
-
-                # validation, messages_to_llm = ui_components.validate_ai_report(
-                #     batch_messages[0], result
-                # )
-                # sv.matching_report_validation.value = validation
-                # sv.matching_report_validation_messages.value = messages_to_llm
-                st.rerun()
-            else:
-                if len(sv.matching_evaluations.value) == 0:
-                    gen_placeholder.warning(
-                        "Press the Generate button to create an AI report for the current record matches."
-                    )
-            placeholder.empty()
-
-            if len(sv.matching_evaluations.value) > 0:
-                try:
-                    mer.evaluations_df = pl.read_csv(
-                        io.StringIO(sv.matching_evaluations.value)
-                    )
-                    value = mer.evaluations_df.drop_nulls()
-
-                    st.dataframe(
-                        value.to_pandas(),
-                        height=700,
-                        use_container_width=True,
-                        hide_index=True,
-                    )
-                    c1, c2 = st.columns([1, 1])
-                    with c1:
-                        st.download_button(
-                            "Download AI match reports",
-                            data=value.write_csv(),
-                            file_name="record_group_match_reports.csv",
-                            mime="text/csv",
-                        )
-                    with c2:
-                        st.download_button(
-                            "Download integrated results",
-                            data=mer.integrated_results.write_csv(),
-                            file_name="integrated_record_match_results.csv",
-                            mime="text/csv",
-                        )
-                except:
-                    st.markdown(sv.matching_evaluations.value)
-                    add_download_pdf(
-                        "record_groups_evaluated.pdf",
-                        sv.matching_evaluations.value,
-                        "Download AI match report",
-                    )
+        if (
+            sv.matching_matches_df.value is None
+            or len(sv.matching_matches_df.value) == 0
+        ):
+            st.warning("Detect record groups to continue.")
+        else:
+            b1, b2 = st.columns([2, 3])
+            with b1:
+                batch_size = 100
+                data = sv.matching_matches_df.value.drop(
+                    [
+                        "Entity ID",
+                        "Dataset",
+                        "Name similarity",
+                    ]
+                ).to_pandas()
+                generate, batch_messages, reset = (
+                    ui_components.generative_batch_ai_component(
+                        sv.matching_system_prompt, {}, "data", data, batch_size
+                    )
+                )
+                if reset:
+                    sv.matching_system_prompt.value["user_prompt"] = prompts.user_prompt
+                    st.rerun()
+            with b2:
+                st.markdown("##### AI evaluation of record groups")
+                prefix = "```\nGroup ID,Relatedness,Explanation\n"
+                placeholder = st.empty()
+                gen_placeholder = st.empty()
+
+                if generate:
+                    for messages in batch_messages:
+                        callback = ui_components.create_markdown_callback(
+                            placeholder, prefix
+                        )
+                        response = ui_components.generate_text(messages, [callback])
+
+                        if len(response.strip()) > 0:
+                            prefix = prefix + response + "\n"
+                    result = prefix.replace("```\n", "").strip()
+                    sv.matching_evaluations.value = result
+                    lines = result.split("\n")
+
+                    if len(lines) > 30:
+                        lines = lines[:30]
+                        result = "\n".join(lines)
+
+                    # validation, messages_to_llm = ui_components.validate_ai_report(
+                    #     batch_messages[0], result
+                    # )
+                    # sv.matching_report_validation.value = validation
+                    # sv.matching_report_validation_messages.value = messages_to_llm
+                    st.rerun()
+                else:
+                    if len(sv.matching_evaluations.value) == 0:
+                        gen_placeholder.warning(
+                            "Press the Generate button to create an AI report for the current record matches."
+                        )
+                placeholder.empty()
+
+                if len(sv.matching_evaluations.value) > 0:
+                    try:
+                        mer.evaluations_df = pl.read_csv(
+                            io.StringIO(sv.matching_evaluations.value)
+                        )
+                        value = mer.evaluations_df.drop_nulls()
+
+                        st.dataframe(
+                            value.to_pandas(),
+                            height=700,
+                            use_container_width=True,
+                            hide_index=True,
+                        )
+                        c1, c2 = st.columns([1, 1])
+                        with c1:
+                            st.download_button(
+                                "Download AI match reports",
+                                data=value.write_csv(),
+                                file_name="record_group_match_reports.csv",
+                                mime="text/csv",
+                            )
+                        with c2:
+                            st.download_button(
+                                "Download integrated results",
+                                data=mer.integrated_results.write_csv(),
+                                file_name="integrated_record_match_results.csv",
+                                mime="text/csv",
+                            )
+                    except:
+                        st.markdown(sv.matching_evaluations.value)
+                        add_download_pdf(
+                            "record_groups_evaluated.pdf",
+                            sv.matching_evaluations.value,
+                            "Download AI match report",
+                        )
 
     with examples_tab:
         example_outputs_ui.create_example_outputs_ui(examples_tab, workflow)
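Most of this hunk is re-indentation; the substantive fix is the new emptiness guard at the top of the tab, which avoids calling `.drop()` on a missing or empty matches dataframe. Reduced to its shape as an early-return variant of the diff's if/else (illustrative only):

```python
import streamlit as st

def render_evaluate_tab(matches_df) -> None:
    # Bail out with guidance instead of crashing when no record groups
    # have been detected yet.
    if matches_df is None or len(matches_df) == 0:
        st.warning("Detect record groups to continue.")
        return
    ...  # the pre-existing evaluation UI, now one indent level deeper
```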