Mer api (microsoft#67)
* add option getter for filtering

* add mer api

* add api funcs

* remove print

* fix tests

* fix api

* mer results

* fix imports
dayesouza authored and scrt-dev committed Oct 30, 2024
1 parent e57affb commit ab24e78
Showing 32 changed files with 4,170 additions and 1,197 deletions.
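At the center of this commit is the new `MatchEntityRecords` class (`toolkit.match_entity_records.api`), which absorbs the loose helper functions the Streamlit workflow previously imported one by one. A minimal sketch of the call sequence, reconstructed from the workflow diff below — the class, method, and attribute names all appear in this commit, but the file name, column names, `embed()` stand-in, and threshold value are illustrative assumptions:

```python
import polars as pl

from toolkit.match_entity_records.api import MatchEntityRecords
from toolkit.match_entity_records.classes import RecordsModel

mer = MatchEntityRecords()

# Register one RecordsModel per uploaded dataset (columns are illustrative).
model = RecordsModel(
    dataframe=pl.read_csv("companies_d1.csv"),  # hypothetical input file
    name_column="name",
    columns=["address", "city", "email"],
    dataframe_name="D1",
    id_column="entity_id",
)
mer.add_df_to_model(model)

# Merge the datasets on the attributes mapped in the UI, embed the merged
# records as sentences, then group records by embedding distance.
merged_df = mer.build_model_df(attributes)  # attributes: AttributeToMatch list from the UI
sentences = mer.sentences_vector_data
mer.embeddings = [embed(s) for s in sentences]  # embed() stands in for the configured embedder
mer.all_sentences = sentences
matches_df = mer.detect_record_groups(0.05)  # max cosine distance between matched records
```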
2 changes: 0 additions & 2 deletions app/util/ui_components.py
@@ -1,11 +1,9 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
import json
import math
import os
import random
import re
import sys
from collections import defaultdict
from typing import Any

import numpy as np
25 changes: 3 additions & 22 deletions app/workflows/compare_case_groups/workflow.py
@@ -68,29 +68,8 @@ def create(sv: gn_variables.SessionVariables, workflow=None):
c1, c2 = st.columns([1, 2])
with c1:
st.markdown("##### Define summary model")
sorted_atts = []
sorted_cols = sorted(sv.case_groups_final_df.value.columns)

for col in sorted_cols:
vals = [
f"{col}:{x}"
for x in sorted(
sv.case_groups_final_df.value[col].astype(str).unique()
)
if x
not in [
"",
"<NA>",
"nan",
"NaN",
"None",
"none",
"NULL",
"null",
]
]
sorted_atts.extend(vals)

groups = st.multiselect(
"Compare groups of records with different combinations of these attributes:",
sorted_cols,
@@ -109,7 +88,9 @@ def create(sv: gn_variables.SessionVariables, workflow=None):
)
filters = st.multiselect(
"After filtering to records matching these values (optional):",
sorted_atts,
ccg.get_filter_options(
pl.from_pandas(sv.case_groups_final_df.value)
),
default=sv.case_groups_filters.value,
)

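The inline option-building deleted above now lives behind `ccg.get_filter_options`. Assuming the new helper mirrors the removed logic one-for-one, a Polars re-implementation would look roughly like this (a sketch, not the committed code):

```python
import polars as pl

NULLISH = {"", "<NA>", "nan", "NaN", "None", "none", "NULL", "null"}

def get_filter_options(df: pl.DataFrame) -> list[str]:
    # One "column:value" option per distinct, non-null-like value,
    # sorted by column and then by value, as the removed UI code did.
    options: list[str] = []
    for col in sorted(df.columns):
        values = df.get_column(col).cast(pl.Utf8).drop_nulls().unique().sort()
        options.extend(f"{col}:{v}" for v in values if v not in NULLISH)
    return options
```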
3 changes: 2 additions & 1 deletion app/workflows/detect_case_patterns/workflow.py
@@ -19,11 +19,12 @@

import app.util.example_outputs_ui as example_outputs_ui
import app.workflows.detect_case_patterns.variables as ap_variables
import toolkit.detect_case_patterns.config as config
from app.util import ui_components
from app.util.download_pdf import add_download_pdf
from toolkit.AI.classes import LLMCallback
from toolkit.detect_case_patterns import prompts
import toolkit.detect_case_patterns.config as config


def get_intro():
file_path = os.path.join(os.path.dirname(__file__), "README.md")
5 changes: 3 additions & 2 deletions app/workflows/match_entity_records/README.md
@@ -88,7 +88,9 @@ Navigate to the `Detect record groups` tab to continue.

### Configuring the text embedding model

The interface on the left shows an empty selection box for `Attribute 1`. Within this field, the selectable values all have a suffix indicating their source dataset (here, `D1` or `D2`). Select `address::D1` and `street_address::D2` as the values for `Attribute 1`, and optionally enter either label (or a new label) for this attribute in the `Label (optional)` field. If no label is provided, the first value alphabetically will be used as the attribute label in the unified dataset.
The interface on the left shows an empty selection box for `Attribute 1`. Within this field, the selectable values all have a suffix indicating their source dataset (here, `D1` or `D2`).

Select `address::D1` and `street_address::D2` as the values for `Attribute 1`, and optionally enter either label (or a new label) for this attribute in the `Label (optional)` field. If no label is provided, the first value alphabetically will be used as the attribute label in the unified dataset.

Repeat this process to match the following pairs of attributes:

@@ -97,7 +99,6 @@ Repeat this process to match the following pairs of attributes:
- `country::D1` and `country_address::D2`
- `sector::D1` and `industry_sector::D2`
- `owner::D1` and `company_owner::D2`
- `city::D1` and `city_address::D2`
- `email::D1` and `email_address::D2`
- `phone::D1` and `phone_number::D2`

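The attribute pairs above ultimately feed the workflow's attribute model. A purely hypothetical sketch of how one such pairing might be expressed with the `AttributeToMatch` class this commit introduces — the field names are assumptions, not confirmed by the diff:

```python
from toolkit.match_entity_records.classes import AttributeToMatch

# Hypothetical: each entry pairs one column from each dataset under a
# shared label (field names assumed for illustration only).
attributes = [
    AttributeToMatch(label="address", columns=["address::D1", "street_address::D2"]),
    AttributeToMatch(label="sector", columns=["sector::D1", "industry_sector::D2"]),
]
```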
3 changes: 0 additions & 3 deletions app/workflows/match_entity_records/configure_embedding.py

This file was deleted.

2 changes: 1 addition & 1 deletion app/workflows/match_entity_records/functions.py
@@ -3,13 +3,13 @@
#
import streamlit as st

import toolkit.match_entity_records.config as config
from app.util.constants import LOCAL_EMBEDDING_MODEL_KEY
from app.util.openai_wrapper import UIOpenAIConfiguration
from app.util.secrets_handler import SecretsHandler
from toolkit.AI.base_embedder import BaseEmbedder
from toolkit.AI.local_embedder import LocalEmbedder
from toolkit.AI.openai_embedder import OpenAIEmbedder
from toolkit.match_entity_records import config


def embedder(local_embedding: bool | None = False) -> BaseEmbedder:
2 changes: 1 addition & 1 deletion app/workflows/match_entity_records/variables.py
@@ -26,7 +26,7 @@ def create_session(self, prefix):
self.matching_max_rows_to_process = SessionVariable(0, prefix)
self.matching_mapped_atts = SessionVariable([], prefix)
self.matching_sentence_pair_scores = SessionVariable([], prefix)
self.matching_sentence_pair_jaccard_threshold = SessionVariable(0.0, prefix)
self.matching_sentence_pair_jaccard_threshold = SessionVariable(0.75, prefix)
self.matching_sentence_pair_embedding_threshold = SessionVariable(
DEFAULT_MAX_RECORD_DISTANCE, prefix
)
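Note the default `matching_sentence_pair_jaccard_threshold` moves from 0.0 (accept every candidate pair) to 0.75. For reference, Jaccard similarity compares the token sets of two records; a minimal illustration, not the toolkit's implementation:

```python
def jaccard(a: set[str], b: set[str]) -> float:
    # |A ∩ B| / |A ∪ B|: 1.0 for identical token sets, 0.0 for disjoint ones.
    return len(a & b) / len(a | b) if (a | b) else 0.0

r1 = {"acme", "ltd", "london", "uk"}
r2 = {"acme", "limited", "london", "uk"}
print(jaccard(r1, r2))  # 0.6 -> below the new 0.75 default, so not grouped
```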
126 changes: 37 additions & 89 deletions app/workflows/match_entity_records/workflow.py
@@ -5,7 +5,6 @@
import os

import numpy as np
import pandas as pd
import polars as pl
import streamlit as st

@@ -17,20 +16,10 @@
from app.util import ui_components
from app.util.download_pdf import add_download_pdf
from toolkit.helpers.progress_batch_callback import ProgressBatchCallback
from toolkit.match_entity_records.config import AttributeToMatch
from toolkit.match_entity_records.detect import (
build_attributes_dataframe,
build_matches,
build_matches_dataset,
build_near_map,
build_nearest_neighbors,
build_sentence_pair_scores,
convert_to_sentences,
)
from toolkit.match_entity_records.prepare_model import (
build_attribute_list,
build_attribute_options,
format_dataset,
from toolkit.match_entity_records.api import MatchEntityRecords
from toolkit.match_entity_records.classes import (
AttributeToMatch,
RecordsModel,
)


@@ -42,6 +31,7 @@ def get_intro():
async def create(sv: rm_variables.SessionVariable, workflow=None) -> None:
sv_home = home_vars.SessionVariables("home")
ui_components.check_ai_configuration()
mer = MatchEntityRecords()

intro_tab, uploader_tab, process_tab, evaluate_tab, examples_tab = st.tabs(
[
@@ -76,8 +66,8 @@ async def create(sv: rm_variables.SessionVariable, workflow=None) -> None:
st.warning("Upload and select a file to continue")
else:
selected_df = pl.from_pandas(selected_df).lazy()
cols = ["", *selected_df.columns]
entity_col = ""
cols = selected_df.columns
entity_id_col = ""
ready = False
dataset = st.text_input(
"Dataset name",
Expand All @@ -90,12 +80,14 @@ async def create(sv: rm_variables.SessionVariable, workflow=None) -> None:
cols,
help="The column containing the name of the entity to be matched. This column is required.",
)
entity_col = st.selectbox(
entity_id_col = st.selectbox(
"Entity ID column (optional)",
cols,
help="The column containing the unique identifier of the entity to be matched. If left blank, a unique ID will be generated for each entity based on the row number.",
)
filtered_cols = [c for c in cols if c not in [entity_col, name_col, ""]]
filtered_cols = [
c for c in cols if c not in [entity_id_col, name_col, ""]
]
att_cols = st.multiselect(
"Entity attribute columns",
filtered_cols,
Expand All @@ -114,19 +106,22 @@ async def create(sv: rm_variables.SessionVariable, workflow=None) -> None:
disabled=not ready,
use_container_width=True,
):
sv.matching_dfs.value[dataset] = format_dataset(
selected_df.collect(),
att_cols,
name_col,
entity_col,
sv.matching_max_rows_to_process.value,
model = RecordsModel(
dataframe=selected_df.collect(),
name_column=name_col,
columns=att_cols,
dataframe_name=dataset,
id_column=entity_id_col,
)
dataset_added = mer.add_df_to_model(model)
sv.matching_dfs.value[dataset] = dataset_added
with b2:
if st.button(
"Reset data model",
disabled=len(sv.matching_dfs.value) == 0,
use_container_width=True,
):
mer.clear_model_dfs()
sv.matching_dfs.value = {}
sv.matching_merged_df.value = pl.DataFrame()
st.rerun()
Expand All @@ -135,6 +130,8 @@ async def create(sv: rm_variables.SessionVariable, workflow=None) -> None:
st.success(
f"Data model has **{len(sv.matching_dfs.value)}** datasets with **{recs}** total records."
)
if not mer.model_dfs:
mer.model_dfs = sv.matching_dfs.value

with process_tab:
if len(sv.matching_dfs.value) == 0:
Expand All @@ -143,7 +140,7 @@ async def create(sv: rm_variables.SessionVariable, workflow=None) -> None:
c1, c2 = st.columns([1, 1])
with c1:
st.markdown("##### Configure text embedding model")
attr_options = build_attribute_options(sv.matching_dfs.value)
attr_options = mer.attribute_options
sv.matching_mapped_atts.value = []

num_atts = 0
@@ -208,7 +205,6 @@ def att_ui(i, any_empty, changed, attsaa):
_, changed, attsa = att_ui(num_atts, any_empty, changed, attsa)
if changed:
st.rerun()
attributes_list = build_attribute_list(attsa)

local_embedding = st.toggle(
"Use local embeddings",
Expand All @@ -222,7 +218,7 @@ def att_ui(i, any_empty, changed, attsaa):
"Matching record distance (max)",
min_value=0.001,
max_value=1.0,
step=0.001,
step=0.01,
format="%f",
value=sv.matching_sentence_pair_embedding_threshold.value,
help="The maximum cosine distance between two records in the embedding space for them to be considered a match. Lower values will result in fewer closer matches overall.",
@@ -264,19 +260,9 @@ def att_ui(i, any_empty, changed, attsaa):
sv.matching_last_sentence_pair_embedding_threshold.value = (
sv.matching_sentence_pair_embedding_threshold.value
)
sv.matching_merged_df.value = build_attributes_dataframe(
sv.matching_dfs.value, attributes_list
)
sv.matching_merged_df.value = (
sv.matching_merged_df.value.with_columns(
(pl.col("Entity ID").cast(pl.Utf8))
+ "::"
+ pl.col("Dataset").alias("Unique ID")
)
) ###??
all_sentences_data = convert_to_sentences(
sv.matching_merged_df.value
)
sv.matching_merged_df.value = mer.build_model_df(attsa)
all_sentences_data = mer.sentences_vector_data

pb = st.progress(0, "Embedding text batches...")

def on_embedding_batch_change(current, total):
@@ -304,51 +290,24 @@ def on_embedding_batch_change(current, total):
)
for f in all_sentences
]
mer.embeddings = all_embeddings
mer.all_sentences = all_sentences

pb.empty()

distances, indices = build_nearest_neighbors(all_embeddings)
near_map = build_near_map(
distances,
indices,
all_sentences,
sv.matching_matches_df.value = mer.detect_record_groups(
sv.matching_sentence_pair_embedding_threshold.value,
)

sv.matching_sentence_pair_scores.value = (
build_sentence_pair_scores(
near_map, sv.matching_merged_df.value
)
)

merged_df = sv.matching_merged_df.value
entity_to_group, matches, pair_to_match = build_matches(
sv.matching_sentence_pair_scores.value,
merged_df,
sv.matching_sentence_pair_jaccard_threshold.value,
)


sv.matching_matches_df.value = pl.DataFrame(
list(matches),
schema=["Group ID", *sv.matching_merged_df.value.columns],
).sort(
by=["Group ID", "Entity name", "Dataset"], descending=False
)


sv.matching_matches_df.value = build_matches_dataset(
sv.matching_matches_df.value, pair_to_match, entity_to_group
)
st.rerun()
if len(sv.matching_matches_df.value) > 0:
st.markdown(
f"Identified **{len(sv.matching_matches_df.value['Group ID'].unique())}** record groups."
)
with c2:
data = sv.matching_matches_df.value
st.markdown("##### Record groups")
if len(sv.matching_matches_df.value) > 0:
data = sv.matching_matches_df.value
st.dataframe(
data, height=700, use_container_width=True, hide_index=True
)
@@ -416,11 +375,11 @@ def on_embedding_batch_change(current, total):

if len(sv.matching_evaluations.value) > 0:
try:
csv = pl.read_csv(io.StringIO(sv.matching_evaluations.value))
value = csv.drop_nulls()
jdf = sv.matching_matches_df.value.join(
value, on="Group ID", how="inner"
mer.evaluations_df = pl.read_csv(
io.StringIO(sv.matching_evaluations.value)
)
value = mer.evaluations_df.drop_nulls()

st.dataframe(
value.to_pandas(),
height=700,
Expand All @@ -431,14 +390,14 @@ def on_embedding_batch_change(current, total):
with c1:
st.download_button(
"Download AI match reports",
data=csv.write_csv(),
data=value.write_csv(),
file_name="record_group_match_reports.csv",
mime="text/csv",
)
with c2:
st.download_button(
"Download integrated results",
data=jdf.write_csv(),
data=mer.integrated_results.write_csv(),
file_name="integrated_record_match_results.csv",
mime="text/csv",
)
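The integrated-results download previously joined the match groups to the parsed AI evaluations on `Group ID`; the new `mer.integrated_results` property presumably encapsulates that same join. A sketch of the equivalent logic, inferred from the removed lines:

```python
import polars as pl

matches_df = pl.DataFrame({"Group ID": [1, 1], "Entity name": ["Acme Ltd", "ACME Limited"]})
evaluations_df = pl.DataFrame({"Group ID": [1], "Evaluation": ["match"]})

# Equivalent of the removed inline join: attach each group's AI
# evaluation row to every record in that group.
integrated = matches_df.join(evaluations_df.drop_nulls(), on="Group ID", how="inner")
```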
Expand All @@ -450,16 +409,5 @@ def on_embedding_batch_change(current, total):
"Download AI match report",
)

report = (
pd.DataFrame(sv.matching_evaluations.value).to_json()
if type(sv.matching_evaluations.value) == pl.DataFrame
else sv.matching_evaluations.value
)
# ui_components.build_validation_ui(
# sv.matching_report_validation.value,
# sv.matching_report_validation_messages.value,
# report,
# workflow,
# )
with examples_tab:
example_outputs_ui.create_example_outputs_ui(examples_tab, workflow)