Skip to content

Commit

Permalink
fix risk networks rerun and file download names
Browse files Browse the repository at this point in the history
  • Loading branch information
dayesouza committed Apr 16, 2024
1 parent cd4def1 commit 783f8fd
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
7 changes: 4 additions & 3 deletions app/util/ui_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def generative_ai_component(system_prompt_var, variables):
with b2:
message = f'AI input uses {tokens}/{util.AI_API.max_input_tokens} ({round(ratio, 2)}%) of token limit'
if ratio <= 100:
st.success(message)
st.info(message)
else:
st.warning(message)
return generate, messages, reset_prompt
Expand Down Expand Up @@ -470,7 +470,8 @@ def convert(x):
else:
st.warning('Generate final dataset to continue.')

def validate_ai_report(messages, result):
st.status('Validating AI report and generating groundedness score...', expanded=False, state='running')
def validate_ai_report(messages, result, show_status = True):
if show_status:
st.status('Validating AI report and generating groundedness score...', expanded=False, state='running')
validation, messages_to_llm = util.AI_API.validate_report(messages, result)
return re.sub(r"```json\n|\n```", "", validation), messages_to_llm
2 changes: 2 additions & 0 deletions app/workflows/record_matching/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ def __init__(self, prefix):
self.matching_sentence_pair_embedding_threshold = SessionVariable(0.05, prefix)
self.matching_last_sentence_pair_embedding_threshold = SessionVariable(0.05, prefix)
self.matching_evaluations = SessionVariable(pl.DataFrame(), prefix)
self.matching_report_validation = SessionVariable({}, prefix)
self.matching_report_validation_messages = SessionVariable('', prefix)
self.matching_system_prompt = SessionVariable(prompts.list_prompts, prefix)
23 changes: 22 additions & 1 deletion app/workflows/record_matching/workflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
import json
import pandas as pd
import streamlit as st
import polars as pl

Expand Down Expand Up @@ -327,13 +329,19 @@ def att_ui(i):
prefix = '```\nGroup ID,Relatedness,Explanation\n'
placeholder = st.empty()
gen_placeholder = st.empty()
get_current_time = pd.Timestamp.now().strftime('%Y%m%d%H%M%S')

if generate:
for messages in batch_messages:
response = util.AI_API.generate_text_from_message_list(messages, placeholder, prefix=prefix)
if len(response.strip()) > 0:
prefix = prefix + response + '\n'
result = prefix.replace('```\n', '').strip()
sv.matching_evaluations.value = pl.read_csv(io.StringIO(result))

validation, messages_to_llm = util.ui_components.validate_ai_report(messages, sv.matching_evaluations.value)
sv.matching_report_validation.value = json.loads(validation)
sv.matching_report_validation_messages.value = messages_to_llm
st.rerun()
else:
if len(sv.matching_evaluations.value) == 0:
Expand All @@ -342,4 +350,17 @@ def att_ui(i):
if len(sv.matching_evaluations.value) > 0:
st.dataframe(sv.matching_evaluations.value.to_pandas(), height=700, use_container_width=True, hide_index=True)
jdf = sv.matching_matches_df.value.join(sv.matching_evaluations.value, on='Group ID', how='inner')
st.download_button('Download AI match report', data=jdf.write_csv(), file_name='record_groups_evaluated.csv', mime='text/csv')
st.download_button('Download AI match report', data=jdf.write_csv(), file_name='record_groups_evaluated.csv', mime='text/csv')

if sv.matching_report_validation.value != {}:
validation_status = st.status(label=f"LLM faithfulness score: {sv.matching_report_validation.value['score']}/5", state='complete')
with validation_status:
st.write(sv.matching_report_validation.value['explanation'])

if sv_home.mode.value == 'dev':
obj = json.dumps({
"message": sv.matching_report_validation_messages.value,
"result": sv.matching_report_validation.value,
"report": sv.matching_evaluations.value
}, indent=4)
st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'matching_{get_current_time}_messages.json', mime='text/json')
14 changes: 10 additions & 4 deletions app/workflows/risk_networks/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,27 +502,33 @@ def create():
gen_placeholder = st.empty()
get_current_time = pd.Timestamp.now().strftime('%Y%m%d%H%M%S')

generated = False
if generate:
generated = True
result = util.AI_API.generate_text_from_message_list(
placeholder=report_placeholder,
messages=messages,
prefix=''
)
sv.network_report.value = result
validation_status = st.status('Validating AI report and generating groundedness score...', expanded=False, state='running')

validation, messages_to_llm = util.ui_components.validate_ai_report(messages, result)
validation, messages_to_llm = util.ui_components.validate_ai_report(messages, result, False)
sv.network_report_validation.value = json.loads(validation)
sv.network_report_validation_messages.value = messages_to_llm
st.rerun()
else:
if len(sv.network_report.value) == 0:
gen_placeholder.warning('Press the Generate button to create an AI report for the current network.')
report_placeholder.markdown(sv.network_report.value)
report_data = sv.network_report.value
report_placeholder.markdown(report_data)

util.ui_components.report_download_ui(sv.network_report, 'network_report')

if sv.network_report_validation.value != {}:
validation_status = st.status(label=f"LLM faithfulness score: {sv.network_report_validation.value['score']}/5", state='complete')
if generated:
validation_status.update(label=f"LLM faithfulness score: {sv.network_report_validation.value['score']}/5", state='complete')
else:
validation_status = st.status(label=f"LLM faithfulness score: {sv.network_report_validation.value['score']}/5", state='complete')
with validation_status:
st.write(sv.network_report_validation.value['explanation'])

Expand Down

0 comments on commit 783f8fd

Please sign in to comment.