diff --git a/app/workflows/detect_case_patterns/workflow.py b/app/workflows/detect_case_patterns/workflow.py index 03860537..2d9b61f2 100644 --- a/app/workflows/detect_case_patterns/workflow.py +++ b/app/workflows/detect_case_patterns/workflow.py @@ -242,12 +242,7 @@ def create(sv: ap_variables.SessionVariables, workflow): tdf = dcp.time_series_df tdf = tdf[tdf["pattern"] == selected_pattern] sv.detect_case_patterns_selected_pattern_df.value = tdf - sv.detect_case_patterns_selected_pattern_att_counts.value = ( - dcp.compute_attribute_counts( - selected_pattern, - time_col, - ) - ) + count_ct = dcp.create_time_series_chart( selected_pattern, selected_pattern_period @@ -266,6 +261,12 @@ def create(sv: ap_variables.SessionVariables, workflow): else: c1, c2 = st.columns([2, 3]) with c1: + sv.detect_case_patterns_selected_pattern_att_counts.value = ( + dcp.compute_attribute_counts( + selected_pattern, + selected_pattern_period, + ) + ) variables = { "pattern": sv.detect_case_patterns_selected_pattern.value, "period": sv.detect_case_patterns_selected_pattern_period.value, diff --git a/app/workflows/query_text_data/workflow.py b/app/workflows/query_text_data/workflow.py index c32f80fd..aba5f417 100644 --- a/app/workflows/query_text_data/workflow.py +++ b/app/workflows/query_text_data/workflow.py @@ -162,62 +162,68 @@ async def create(sv: SessionVariables, workflow=None): st.warning(f"Process files to continue.") else: with st.expander("Options", expanded=False): - c1, c2, c3, c4, c5, c6, c7 = st.columns(7) - with c1: - st.number_input( - "Relevance test budget", - value=sv.relevance_test_budget.value, - key=sv.relevance_test_budget.key, - min_value=0, - help="The query method works by asking an LLM to evaluate the relevance of potentially-relevant text chunks, returning a single token, yes/no judgement. This parameter allows the user to cap the number of relvance tests that may be performed prior to generating an answer using all relevant chunks. Larger budgets will generally give better answers for a greater cost." - ) - with c2: - st.number_input( - "Tests/topic/round", - value=sv.relevance_test_batch_size.value, - key=sv.relevance_test_batch_size.key, - min_value=0, - help="How many relevant tests to perform for each topic in each round. Larger values reduce the likelihood of prematurely discarding topics whose relevant chunks may not be at the top of the similarity-based ranking, but may result in smaller values of `Relevance test budget` being spread across fewer topics and thus not capturing the full breadth of the data." - ) - with c3: - st.number_input( - "Restart on irrelevant topics", - value=sv.irrelevant_community_restart.value, - key=sv.irrelevant_community_restart.key, - min_value=0, - help="When this number of topics in a row fail to return any relevant chunks in their `Tests/topic/round`, return to the start of the topic ranking and continue testing `Tests/topic/round` text chunks from each topic with (a) relevance in the previous round and (b) previously untested text chunks. Higher values can avoid prematurely discarding topics that are relevant but whose relevant chunks are not at the top of the similarity-based ranking, but may result in a larger number of irrelevant topics being tested multiple times." 
- ) - with c4: - st.number_input( - "Test relevant neighbours", - value=sv.adjacent_test_steps.value, - key=sv.adjacent_test_steps.key, - min_value=0, - help="If a text chunk is relevant to the query, then adjacent text chunks in the original document may be able to add additional context to the relevant points. The value of this parameter determines how many chunks before and after each relevant text chunk will be evaluated at the end of the process (or `Relevance test budget`) if they are yet to be tested." - ) - with c5: - st.number_input( - "Target chunks per cluster", - value=sv.target_chunks_per_cluster.value, - key=sv.target_chunks_per_cluster.key, - min_value=0, - help="The average number of text chunks to target per cluster, which determines the text chunks that will be evaluated together and in parallel to other clusters. Larger values will generally result in more related text chunks being evaluated in parallel, but may also result in information loss from unprocessed content." - ) - with c6: - st.radio( - label="Evidence type", - options=["Source text", "Extracted claims"], - key=sv.search_type.key, - help="If the evidence type is set to 'Source text', the system will generate an answer directly from the text chunks. If the search type is set to 'Extracted claims', the system will extract claims from the text chunks and generate an answer based on the extracted claims in addition to the source text.", - ) - with c7: - st.number_input( - "Claim search depth", - value=sv.claim_search_depth.value, - key=sv.claim_search_depth.key, - min_value=0, - help="If the evidence type is set to 'Extracted claims', this parameter sets the number of most-similar text chunks to analyze for each extracted claim, looking for both supporting and contradicting evidence." - ) + cl, cr = st.columns([5, 2]) + with cl: + st.markdown("**Search options**") + c1, c2, c3, c4, c5 = st.columns(5) + with c1: + st.number_input( + "Relevance test budget", + value=sv.relevance_test_budget.value, + key=sv.relevance_test_budget.key, + min_value=0, + help="The query method works by asking an LLM to evaluate the relevance of potentially-relevant text chunks, returning a single-token yes/no judgement. This parameter allows the user to cap the number of relevance tests that may be performed prior to generating an answer using all relevant chunks. Larger budgets will generally give better answers for a greater cost." + ) + with c2: + st.number_input( + "Tests/topic/round", + value=sv.relevance_test_batch_size.value, + key=sv.relevance_test_batch_size.key, + min_value=0, + help="How many relevance tests to perform for each topic in each round. Larger values reduce the likelihood of prematurely discarding topics whose relevant chunks may not be at the top of the similarity-based ranking, but may result in smaller values of `Relevance test budget` being spread across fewer topics and thus not capturing the full breadth of the data." + ) + with c3: + st.number_input( + "Restart on irrelevant topics", + value=sv.irrelevant_community_restart.value, + key=sv.irrelevant_community_restart.key, + min_value=0, + help="When this number of topics in a row fail to return any relevant chunks in their `Tests/topic/round`, return to the start of the topic ranking and continue testing `Tests/topic/round` text chunks from each topic with (a) relevance in the previous round and (b) previously untested text chunks. 
Higher values can avoid prematurely discarding topics that are relevant but whose relevant chunks are not at the top of the similarity-based ranking, but may result in a larger number of irrelevant topics being tested multiple times." + ) + with c4: + st.number_input( + "Test relevant neighbours", + value=sv.adjacent_test_steps.value, + key=sv.adjacent_test_steps.key, + min_value=0, + help="If a text chunk is relevant to the query, then adjacent text chunks in the original document may be able to add additional context to the relevant points. The value of this parameter determines how many chunks before and after each relevant text chunk will be evaluated at the end of the process (or `Relevance test budget`) if they are yet to be tested." + ) + with c5: + st.number_input( + "Target chunks per cluster", + value=sv.target_chunks_per_cluster.value, + key=sv.target_chunks_per_cluster.key, + min_value=0, + help="The average number of text chunks to target per cluster, which determines the text chunks that will be evaluated together and in parallel to other clusters. Larger values will generally result in more related text chunks being evaluated in parallel, but may also result in information loss from unprocessed content." + ) + with cr: + st.markdown("**Answer options**") + c6, c7 = st.columns([1, 1]) + with c6: + st.radio( + label="Evidence type", + options=["Source text", "Extracted claims"], + key=sv.search_type.key, + help="If the evidence type is set to 'Source text', the system will generate an answer directly from the text chunks. If the evidence type is set to 'Extracted claims', the system will extract claims from the text chunks and generate an answer based on the extracted claims in addition to the source text.", + ) + with c7: + st.number_input( + "Claim search depth", + value=sv.claim_search_depth.value, + key=sv.claim_search_depth.key, + min_value=0, + help="If the evidence type is set to 'Extracted claims', this parameter sets the number of most-similar text chunks to analyze for each extracted claim, looking for both supporting and contradicting evidence." 
+ ) c1, c2 = st.columns([6, 1]) with c1: st.text_input( diff --git a/toolkit/detect_case_patterns/model.py b/toolkit/detect_case_patterns/model.py index eb2a75a3..7ddab48f 100644 --- a/toolkit/detect_case_patterns/model.py +++ b/toolkit/detect_case_patterns/model.py @@ -46,11 +46,11 @@ def generate_graph_model(df, period_col, type_val_sep): def compute_attribute_counts(df, pattern, period_col, period, type_val_sep): + print(f"Computing attribute counts for pattern: {pattern} with period: {period} for period column: {period_col}") atts = pattern.split(" & ") # Combine astype and replace operations fdf = df_functions.fix_null_ints(df) fdf = fdf[fdf[period_col] == period] - # Pre-filter columns to avoid unnecessary processing relevant_columns = [c for c in fdf.columns if c not in ["Subject ID", period_col]] # fdf = fdf[["Subject ID", period_col, *relevant_columns]] @@ -73,14 +73,15 @@ def compute_attribute_counts(df, pattern, period_col, period, type_val_sep): ) melted = melted[melted["Value"] != ""] melted["AttributeValue"] = melted["Attribute"] + type_val_sep + melted["Value"] - + print(melted) # Directly use nunique in groupby - return ( + count_df = ( melted.groupby("AttributeValue")["Subject ID"] .nunique() .reset_index(name="Count") .sort_values(by="Count", ascending=False) ) + return count_df def create_time_series_df(model, pattern_df): diff --git a/toolkit/query_text_data/answer_builder.py b/toolkit/query_text_data/answer_builder.py index 5dec060e..09e52254 100644 --- a/toolkit/query_text_data/answer_builder.py +++ b/toolkit/query_text_data/answer_builder.py @@ -3,7 +3,6 @@ import re from json import loads, dumps -import numpy as np import asyncio import string from tqdm.asyncio import tqdm_asyncio @@ -14,23 +13,32 @@ import toolkit.query_text_data.answer_schema as answer_schema import toolkit.query_text_data.prompts as prompts import sklearn.cluster as cluster -import scipy.spatial from sklearn.neighbors import NearestNeighbors from toolkit.query_text_data.classes import AnswerObject -def extract_chunk_references(text): +def extract_and_link_chunk_references(text, link=True): source_spans = list(re.finditer(r'\[source: ([^\]]+)\]', text, re.MULTILINE)) references = set() for source_span in source_spans: - parts = source_span.group(1).split(', ') - references.update(parts) - ref_list = sorted(references) - return ref_list + old_span = source_span.group(0) + parts = [x.strip() for x in source_span.group(1).split(',')] + matched_parts = [x for x in parts if re.match(r'^\d+$', x)] + references.update(matched_parts) + if link: + new_span = source_span.group(0) + for part in matched_parts: + new_span = new_span.replace(part, f"[{part}](#source-{part})") + text = text.replace(old_span, new_span) + references = [int(cid) for cid in references if cid.isdigit()] + references = sorted(references) + return text, references -def link_chunk_references(text, references): - for ix, reference in enumerate(references): - text = text.replace(reference, f"[{ix+1}](#source-{ix+1})") - return text +def create_cid_to_label(processed_chunks): + cid_to_label = {} + for cid, text in enumerate(processed_chunks.cid_to_text.values()): + chunk = loads(text) + cid_to_label[cid] = f"{chunk['title']} ({chunk['chunk_id']})" + return cid_to_label async def answer_question( ai_configuration, @@ -42,6 +50,7 @@ async def answer_question( embedding_cache, answer_config, ): + cid_to_label = create_cid_to_label(processed_chunks) target_clusters = len(relevant_cids) // answer_config.target_chunks_per_cluster if 
len(relevant_cids) / answer_config.target_chunks_per_cluster > target_clusters: target_clusters += 1 @@ -50,8 +59,7 @@ async def answer_question( cid_to_vector, target_clusters ) - clustered_texts = [[processed_chunks.cid_to_text[cid] for cid in cids] for cids in clustered_cids.values()] - source_to_text = {f"{text['title']} ({text['chunk_id']})": text for text in [loads(text) for text in processed_chunks.cid_to_text.values()]} + clustered_texts = [[f"{cid}: {processed_chunks.cid_to_text[cid]}" for cid in cids] for cids in clustered_cids.values()] source_to_supported_claims = defaultdict(set) source_to_contradicted_claims = defaultdict(set) net_new_sources = 0 @@ -91,25 +99,15 @@ async def answer_question( claim_to_vector = {claims_to_embed[cix]: vector for cix, vector in cix_to_vector.items()} units = sorted([(cid, vector) for cid, vector in (cid_to_vector.items())], key=lambda x: x[0]) neighbours = NearestNeighbors(n_neighbors=answer_config.claim_search_depth, metric='cosine').fit([vector for cid, vector in units]) - cix = 0 + for claim_sets in json_extracted_claims: for claim_set in claim_sets['claim_analysis']: claim_context = claim_set['claim_context'] for claim in claim_set['claims']: claim_statement = claim['claim_statement'] claim_key = f"{claim_statement} (context: {claim_context})" - supporting_sources = set() - for ss in claim['supporting_sources']: - tt = ss['text_title'] - for sc in ss['chunk_ids']: - supporting_sources.add(f"{tt} ({sc})") - contradicting_sources = set() - for cs in claim['contradicting_sources']: - tt = cs['text_title'] - for sc in cs['chunk_ids']: - contradicting_sources.add(f"{tt} ({sc})") - claim_context_to_claim_supporting_sources[claim_context][claim_statement] = supporting_sources - claim_context_to_claim_contradicting_sources[claim_context][claim_statement] = contradicting_sources + claim_context_to_claim_supporting_sources[claim_context][claim_statement] = set(claim['supporting_sources']) + claim_context_to_claim_contradicting_sources[claim_context][claim_statement] = set(claim['contradicting_sources']) if claim_key in claim_to_vector: tasks.append(asyncio.create_task(requery_claim( ai_configuration, @@ -119,7 +117,8 @@ async def answer_question( claim_to_vector[claim_key], claim_context, claim_statement, - processed_chunks.cid_to_text + processed_chunks.cid_to_text, + cid_to_label ))) else: print(f'No vector for claim: {claim_key}') @@ -159,14 +158,22 @@ async def answer_question( } ], 'sources': { - source: text for source, text in source_to_text.items() if source in supporting_sources.union(contradicting_sources) + source: text for source, text in processed_chunks.cid_to_text.items() if source in supporting_sources.union(contradicting_sources) } } ) - batched_summarization_messages = [utils.prepare_messages(prompts.claim_summarization_prompt, {'analysis': dumps(claims, ensure_ascii=False, indent=2), 'data': '', 'question': question}) + def extract_relevant_chunks(claims): + relevant_cids = set() + for claim in claims['claims']: + ss = set(claim['supporting_sources']) + cs = set(claim['contradicting_sources']) + relevant_cids.update(ss.union(cs)) + relevant_cids = sorted(relevant_cids) + return [f"{cid}: {processed_chunks.cid_to_text[cid]}" for cid in relevant_cids] + batched_summarization_messages = [utils.prepare_messages(prompts.claim_summarization_prompt, {'analysis': dumps(claims, ensure_ascii=False, indent=2), 'chunks': extract_relevant_chunks(claims), 'question': question}) for i, claims in enumerate(claim_summaries)] else: - 
batched_summarization_messages = [utils.prepare_messages(prompts.claim_summarization_prompt, {'analysis': [], 'data': clustered_texts[i], 'question': question}) + batched_summarization_messages = [utils.prepare_messages(prompts.claim_summarization_prompt, {'analysis': [], 'chunks': clustered_texts[i], 'question': question}) for i in range(len(clustered_texts))] summarized_claims = await utils.map_generate_text( @@ -209,23 +216,24 @@ def build_report_markdown(question, content_items_dict, content_structure, cid_t report += f'##### {item["content_title"]}\n\n{item["content_summary"]}\n\n{item["content_commentary"]}\n\n' report += f'##### AI theme commentary\n\n{theme["theme_commentary"]}\n\n' report += f'#### AI report commentary\n\n{content_structure["report_commentary"]}\n\n' - references = extract_chunk_references(report) - report = link_chunk_references(report, references) + report, references = extract_and_link_chunk_references(report) + print(f'Extracted references: {references}') report += f'## Sources\n\n' - for ix, source_label in enumerate(references): - if source_label in matched_chunks: - supports_claims = source_to_supported_claims[source_label] - contradicts_claims = source_to_contradicted_claims[source_label] + for cid in references: + if cid in cid_to_text.keys(): + supports_claims = source_to_supported_claims[cid] + contradicts_claims = source_to_contradicted_claims[cid] supports_claims_str = '- ' + '\n- '.join([claim_statement for _, claim_statement in supports_claims]) contradicts_claims_str = '- ' + '\n- '.join([claim_statement for _, claim_statement in contradicts_claims]) - report += f'#### Source {ix+1}\n\n
\n\n##### Text chunk: {source_label}\n\n{matched_chunks[source_label]["text_chunk"]}\n\n' + chunk = loads(cid_to_text[cid]) + report += f'#### Source {cid}\n\n
\n\n##### Text chunk: {chunk["title"]}: {chunk["chunk_id"]}\n\n{chunk["text_chunk"]}\n\n' if len(supports_claims) > 0: report += f'##### Supports claims\n\n{supports_claims_str}\n\n' if len(contradicts_claims) > 0: report += f'##### Contradicts claims\n\n{contradicts_claims_str}\n\n' report += f'
\n\n[Back to top]({home_link})\n\n' else: - print(f'No match for {source_label}') + print(f'No match for {cid}') return report, references, matched_chunks @@ -250,7 +258,7 @@ def cluster_cids( clustered_cids[cluster_assignment].append(cid) return clustered_cids -async def requery_claim(ai_configuration, question, units, neighbours, claim_embedding, claim_context, claim_statement, cid_to_text): +async def requery_claim(ai_configuration, question, units, neighbours, claim_embedding, claim_context, claim_statement, cid_to_text, cid_to_label): contextualized_claim = f"{claim_statement} (context: {claim_context})" # Find the nearest neighbors of the claim embedding indices = neighbours.kneighbors([claim_embedding], return_distance=False) @@ -265,17 +273,11 @@ async def requery_claim(ai_configuration, question, units, neighbours, claim_emb # ) # batch cids into batches of size batch_size # cids = [cid for cid, dist in cosine_distances[:search_depth]] - chunks = dumps({i: cid_to_text[cid] for i, cid in enumerate(cids)}, ensure_ascii=False, indent=2) + chunks = dumps({cid: cid_to_text[cid] for i, cid in enumerate(cids)}, ensure_ascii=False, indent=2) messages = utils.prepare_messages(prompts.claim_requery_prompt, {'question': question, 'claim': contextualized_claim, 'chunks': chunks}) response = await utils.generate_text_async(ai_configuration, messages, response_format=answer_schema.claim_requery_format) - chunk_titles = [] - for cid in cids: - text_json = loads(cid_to_text[cid]) - chunk_titles.append(f"{text_json['title']} ({text_json['chunk_id']})") response_json = loads(response) - supporting_sources = response_json['supporting_source_indicies'] - contradicting_sources = response_json['contradicting_source_indicies'] - supporting_source_labels = [chunk_titles[i] for i in supporting_sources] - contradicting_source_labels = [chunk_titles[i] for i in contradicting_sources] - return claim_context, claim_statement, supporting_source_labels, contradicting_source_labels \ No newline at end of file + supporting_sources = response_json['supporting_sources'] + contradicting_sources = response_json['contradicting_sources'] + return claim_context, claim_statement, supporting_sources, contradicting_sources \ No newline at end of file diff --git a/toolkit/query_text_data/answer_schema.py b/toolkit/query_text_data/answer_schema.py index 623c96f6..1a668132 100644 --- a/toolkit/query_text_data/answer_schema.py +++ b/toolkit/query_text_data/answer_schema.py @@ -81,39 +81,13 @@ "supporting_sources": { "type": "array", "items": { - "type": "object", - "properties": { - "text_title": { - "type": "string" - }, - "chunk_ids": { - "type": "array", - "items": { - "type": "number" - } - } - }, - "required": ["text_title", "chunk_ids"], - "additionalProperties": False, + "type": "number" } }, "contradicting_sources": { "type": "array", "items": { - "type": "object", - "properties": { - "text_title": { - "type": "string" - }, - "chunk_ids": { - "type": "array", - "items": { - "type": "number" - } - } - }, - "required": ["text_title", "chunk_ids"], - "additionalProperties": False, + "type": "number" } } }, @@ -241,20 +215,20 @@ "type": "object", "properties": { - "supporting_source_indicies": { + "supporting_sources": { "type": "array", "items": { "type": "number", } }, - "contradicting_source_indicies": { + "contradicting_sources": { "type": "array", "items": { "type": "number", } } }, - "required": ["supporting_source_indicies", "contradicting_source_indicies"], + "required": ["supporting_sources", 
"contradicting_sources"], "additionalProperties": False, } } diff --git a/toolkit/query_text_data/prompts.py b/toolkit/query_text_data/prompts.py index f3878bce..4de5122d 100644 --- a/toolkit/query_text_data/prompts.py +++ b/toolkit/query_text_data/prompts.py @@ -63,10 +63,10 @@ Given a question, the output object should extract claims from the input text chunks as follows: - "claim_context": an overall description of the context in which claims are made -- "claim_statement": a statement-based formatting of a claim that is relevant to the user question +- "claim_statement": a statement-based formatting of a claim that is relevant to the user question and includes relevant contextual information (e.g., time and place) - "claim_attribution": any named source or author of a claim, beyond the title of the text -- "text_title": the title of the text from which the chunk was exracted -- "chunk_id": the id of the chunk within the text +- "supporting_sources": a list of source IDs that support the claim +- "contradicting_sources": a list of source IDs that contradict the claim --TASK-- @@ -74,7 +74,7 @@ {question} -Input text chunks JSON: +Input text chunks JSON, in the form ": ": {chunks} @@ -84,19 +84,19 @@ claim_summarization_prompt = """\ You are a helpful assistant tasked with creating a JSON object that summarizes claims relevant to a given user question. +When presenting source evidence, support each sentence with a source reference to the file and text chunk: "[source: , ]. + The output object should summarize all claims from input text chunks as follows: - "content_title": a title for a specific content item spanning related claims, in the form of a derived claim statement - "content_summary": a paragraph, starting with "**Source evidence**:", describing each of the individual claims and the balance of evidence supporting or contradicting them - "content_commentary": a paragraph, starting with "**AI commentary**:", suggesting inferences, implications, or conclusions that could be drawn from the source evidence -When presenting source evidence, support each sentence with a source reference to the file and text chunk: "[source: (), ()]. Always use the full name of the file - do not abbreviate - and enter the full filename before each chunk id, even if the same file contains multiple relevant chunks. - --TASK-- -Input text chunks: +Input text chunks JSON, in the form ": ": -{data} +{chunks} Input claim analysis JSON: @@ -121,7 +121,7 @@ - "theme_commentary": a concluding paragraph that summarizes the content items in the theme and their relevance to the user question, with additional interpretation - "report_commentary": a concluding paragraph that summarizes the themes and their relevance to the user question, with additional interpretation -When presenting evidence, support each sentence with a source reference to the file and text chunk: "[source: (), ()]. Always use the full name of the file - do not abbreviate - and enter the full filename before each chunk id, even if the same file contains multiple relevant chunks. +When presenting evidence, support each sentence with one or more source references: "[source: , ]. 
--TASK-- @@ -141,8 +141,8 @@ The output object should summarize all claims from input text chunks as follows: -- "supporting_source_indicies": the indices of the input text chunks that support the claim (starting at 0) -- "contradicting_source_indicies": the indices of the input text chunks that contradict the claim (starting at 0) +- "supporting_sources": the IDs of the input text chunks that support the claim +- "contradicting_sources": the IDs of the input text chunks that contradict the claim --TASK-- @@ -150,7 +150,7 @@ {claim} -Input text chunks JSON: +Input text chunks JSON, in the form "<chunk_id>: <chunk_text>": {chunks} diff --git a/toolkit/query_text_data/relevance_assessor.py b/toolkit/query_text_data/relevance_assessor.py index 9f71ffdb..95ac3ca1 100644 --- a/toolkit/query_text_data/relevance_assessor.py +++ b/toolkit/query_text_data/relevance_assessor.py @@ -229,8 +229,8 @@ async def detect_relevant_chunks( eliminated_communities.add(community) successive_irrelevant += 1 if successive_irrelevant == chunk_search_config.irrelevant_community_restart: - successive_irrelevant = 0 print(f'{successive_irrelevant} successive irrelevant communities; restarting') + successive_irrelevant = 0 break else: successive_irrelevant = 0
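Reviewer note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the new numeric source-reference handling added in toolkit/query_text_data/answer_builder.py, mirroring the diff's extract_and_link_chunk_references logic so the linking behavior can be exercised in isolation; the example answer string is hypothetical.

# Sketch of the new "[source: <chunk_id>, ...]" handling; mirrors the logic in the diff above.
import re

def extract_and_link_chunk_references(text, link=True):
    # Find every "[source: ...]" span in the generated answer text.
    source_spans = list(re.finditer(r'\[source: ([^\]]+)\]', text, re.MULTILINE))
    references = set()
    for source_span in source_spans:
        old_span = source_span.group(0)
        parts = [x.strip() for x in source_span.group(1).split(',')]
        # Keep only purely numeric chunk IDs.
        matched_parts = [x for x in parts if re.match(r'^\d+$', x)]
        references.update(matched_parts)
        if link:
            # Turn each chunk ID into a markdown link targeting its "#source-<id>" anchor.
            new_span = old_span
            for part in matched_parts:
                new_span = new_span.replace(part, f"[{part}](#source-{part})")
            text = text.replace(old_span, new_span)
    references = sorted(int(cid) for cid in references)
    return text, references

if __name__ == "__main__":
    answer = "Chunk-level evidence is cited inline [source: 12, 7]."  # hypothetical answer text
    linked, refs = extract_and_link_chunk_references(answer)
    print(linked)  # ...inline [source: [12](#source-12), [7](#source-7)].
    print(refs)    # [7, 12]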