From 783f8fd3622caa7aeaa9b797d4237ff687b7b1af Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Tue, 16 Apr 2024 20:42:48 -0300 Subject: [PATCH] fix risk networks rerun and file download names --- app/util/ui_components.py | 7 ++++--- app/workflows/record_matching/variables.py | 2 ++ app/workflows/record_matching/workflow.py | 23 +++++++++++++++++++++- app/workflows/risk_networks/workflow.py | 14 +++++++++---- 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/app/util/ui_components.py b/app/util/ui_components.py index 718c4536..64b017a9 100644 --- a/app/util/ui_components.py +++ b/app/util/ui_components.py @@ -79,7 +79,7 @@ def generative_ai_component(system_prompt_var, variables): with b2: message = f'AI input uses {tokens}/{util.AI_API.max_input_tokens} ({round(ratio, 2)}%) of token limit' if ratio <= 100: - st.success(message) + st.info(message) else: st.warning(message) return generate, messages, reset_prompt @@ -470,7 +470,8 @@ def convert(x): else: st.warning('Generate final dataset to continue.') -def validate_ai_report(messages, result): - st.status('Validating AI report and generating groundedness score...', expanded=False, state='running') +def validate_ai_report(messages, result, show_status = True): + if show_status: + st.status('Validating AI report and generating groundedness score...', expanded=False, state='running') validation, messages_to_llm = util.AI_API.validate_report(messages, result) return re.sub(r"```json\n|\n```", "", validation), messages_to_llm \ No newline at end of file diff --git a/app/workflows/record_matching/variables.py b/app/workflows/record_matching/variables.py index 5d6648c6..0a7af720 100644 --- a/app/workflows/record_matching/variables.py +++ b/app/workflows/record_matching/variables.py @@ -18,4 +18,6 @@ def __init__(self, prefix): self.matching_sentence_pair_embedding_threshold = SessionVariable(0.05, prefix) self.matching_last_sentence_pair_embedding_threshold = SessionVariable(0.05, prefix) self.matching_evaluations = SessionVariable(pl.DataFrame(), prefix) + self.matching_report_validation = SessionVariable({}, prefix) + self.matching_report_validation_messages = SessionVariable('', prefix) self.matching_system_prompt = SessionVariable(prompts.list_prompts, prefix) diff --git a/app/workflows/record_matching/workflow.py b/app/workflows/record_matching/workflow.py index ec4aa8f2..c0cd8575 100644 --- a/app/workflows/record_matching/workflow.py +++ b/app/workflows/record_matching/workflow.py @@ -1,4 +1,6 @@ # Copyright (c) 2024 Microsoft Corporation. All rights reserved. +import json +import pandas as pd import streamlit as st import polars as pl @@ -327,6 +329,8 @@ def att_ui(i): prefix = '```\nGroup ID,Relatedness,Explanation\n' placeholder = st.empty() gen_placeholder = st.empty() + get_current_time = pd.Timestamp.now().strftime('%Y%m%d%H%M%S') + if generate: for messages in batch_messages: response = util.AI_API.generate_text_from_message_list(messages, placeholder, prefix=prefix) @@ -334,6 +338,10 @@ def att_ui(i): prefix = prefix + response + '\n' result = prefix.replace('```\n', '').strip() sv.matching_evaluations.value = pl.read_csv(io.StringIO(result)) + + validation, messages_to_llm = util.ui_components.validate_ai_report(messages, sv.matching_evaluations.value) + sv.matching_report_validation.value = json.loads(validation) + sv.matching_report_validation_messages.value = messages_to_llm st.rerun() else: if len(sv.matching_evaluations.value) == 0: @@ -342,4 +350,17 @@ def att_ui(i): if len(sv.matching_evaluations.value) > 0: st.dataframe(sv.matching_evaluations.value.to_pandas(), height=700, use_container_width=True, hide_index=True) jdf = sv.matching_matches_df.value.join(sv.matching_evaluations.value, on='Group ID', how='inner') - st.download_button('Download AI match report', data=jdf.write_csv(), file_name='record_groups_evaluated.csv', mime='text/csv') \ No newline at end of file + st.download_button('Download AI match report', data=jdf.write_csv(), file_name='record_groups_evaluated.csv', mime='text/csv') + + if sv.matching_report_validation.value != {}: + validation_status = st.status(label=f"LLM faithfulness score: {sv.matching_report_validation.value['score']}/5", state='complete') + with validation_status: + st.write(sv.matching_report_validation.value['explanation']) + + if sv_home.mode.value == 'dev': + obj = json.dumps({ + "message": sv.matching_report_validation_messages.value, + "result": sv.matching_report_validation.value, + "report": sv.matching_evaluations.value + }, indent=4) + st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'matching_{get_current_time}_messages.json', mime='text/json') diff --git a/app/workflows/risk_networks/workflow.py b/app/workflows/risk_networks/workflow.py index 0b95834a..26c0b5d1 100644 --- a/app/workflows/risk_networks/workflow.py +++ b/app/workflows/risk_networks/workflow.py @@ -502,27 +502,33 @@ def create(): gen_placeholder = st.empty() get_current_time = pd.Timestamp.now().strftime('%Y%m%d%H%M%S') + generated = False if generate: + generated = True result = util.AI_API.generate_text_from_message_list( placeholder=report_placeholder, messages=messages, prefix='' ) sv.network_report.value = result + validation_status = st.status('Validating AI report and generating groundedness score...', expanded=False, state='running') - validation, messages_to_llm = util.ui_components.validate_ai_report(messages, result) + validation, messages_to_llm = util.ui_components.validate_ai_report(messages, result, False) sv.network_report_validation.value = json.loads(validation) sv.network_report_validation_messages.value = messages_to_llm - st.rerun() else: if len(sv.network_report.value) == 0: gen_placeholder.warning('Press the Generate button to create an AI report for the current network.') - report_placeholder.markdown(sv.network_report.value) + report_data = sv.network_report.value + report_placeholder.markdown(report_data) util.ui_components.report_download_ui(sv.network_report, 'network_report') if sv.network_report_validation.value != {}: - validation_status = st.status(label=f"LLM faithfulness score: {sv.network_report_validation.value['score']}/5", state='complete') + if generated: + validation_status.update(label=f"LLM faithfulness score: {sv.network_report_validation.value['score']}/5", state='complete') + else: + validation_status = st.status(label=f"LLM faithfulness score: {sv.network_report_validation.value['score']}/5", state='complete') with validation_status: st.write(sv.network_report_validation.value['explanation'])