Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CU-86956a6y5 improve comparison #15

Merged
merged 7 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions medcat/compare_models/comp_nbhelper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from ipyfilechooser import FileChooser
from ipywidgets import widgets
from IPython.display import display
import os
from typing import List, Optional


from compare import get_diffs_for
from output import parse_and_show, show_dict_deep, compare_dicts


_def_path = '../../models/modelpack'
_def_path = _def_path if os.path.exists(_def_path) else '.'


class NBComparer:

def __init__(self, model_path_1: str, model_path_2: str,
documents_file: str, doc_limit: int, is_mct_export_compare: bool,
cui_filter: str, filter_children: bool) -> None:
self.model_path_1 = model_path_1
self.model_path_2 = model_path_2
self.documents_file = documents_file
self.doc_limit = doc_limit
self.is_mct_export_compare = is_mct_export_compare
self.cui_filter = cui_filter
self.filter_children = filter_children
self._run_comparison()

def _run_comparison(self):
(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs) = get_diffs_for(
self.model_path_1, self.model_path_2, self.documents_file,
cui_filter=self.cui_filter, include_children_in_filter=self.filter_children,
supervised_train_comparison_model=self.is_mct_export_compare, doc_limit=self.doc_limit)

def show_all(self):
parse_and_show(self.cdb_comp, self.tally1, self.tally2, self.ann_diffs)

def show_per_document(self, limit: int = -1, print_delimiter: bool = True,
ignore_empty: bool = True):
cnt = 0
for key in self.ann_diffs.per_doc_results.keys():
comp_dict = self.ann_diffs.per_doc_results[key].nr_of_comparisons
if not ignore_empty or comp_dict: # ignore empty ones
if print_delimiter:
print('='*20,f'\n{key}', f'\n{"="*20}')
show_dict_deep(self.ann_diffs.per_doc_results[key].nr_of_comparisons)
cnt += 1
if limit > -1 and cnt == limit:
break

def diffs_to_csv(self, file_path: str) -> None:
self.ann_diffs.to_csv(file_path)

def compare_for_cui(self, cui: str, include_children: int = 2) -> None:
per_cui1 = self.tally1.get_for_cui(cui, include_children=include_children)
per_cui2 = self.tally2.get_for_cui(cui, include_children=include_children)
compare_dicts(per_cui1, per_cui2)

def show_docs(self, docs: List[str], show_delimiter: bool = True,
omit_identical: bool = True):
for doc_name, pair in self.ann_diffs.iter_ann_pairs(docs=docs, omit_identical=omit_identical):
if show_delimiter:
print('='*20,f'\n{doc_name} ({pair.comparison_type})', f'\n{"="*20}')
# NOTE: if only one of the two has an annotation, the other one will be None
# the following will deal with that automatically, though
compare_dicts(pair.one, pair.two)


class NBInputter:
models_overall_title = "Models and data"
mc1_title = "Choose model 1"
mc2_title = "Choose model 2 (or an MCT export)"
docs_title = "Choose the documents file (.csv with 'text' field)"
docs_limit_title = "Limit the number of documents to run (-1 to disable)"
mct_export_title = "Is the 2nd path an MCT export (instead of a model)?"
cui_filter_title_overall = "CUI Filter"
cui_filter_title_file_chooser = "Choose file with comma-separated CUIs"
cui_filter_title_text = "List comma-separated CUIs"
cui_children_title = "How many layers of children of concepts to include?"

def __init__(self) -> None:
self.model1_chooser = FileChooser(_def_path)
self.model2_chooser = FileChooser(_def_path)
self.documents_chooser = FileChooser(".")
self.doc_limit = widgets.IntText(-1)
self.ckbox = widgets.Checkbox(description="MCT export compare")

self.cui_filter_chooser = FileChooser(".", description="The CUI filter file")
self.cui_filter_box = widgets.Textarea(description="CUI list")
self.cui_children = widgets.IntText(description="Children", value=-1)

def show_all(self):
model_choosers = widgets.VBox([
widgets.HTML(f"<h2>{self.models_overall_title}</h2>"),
widgets.VBox([widgets.Label(self.mc1_title), self.model1_chooser]),
widgets.VBox([widgets.Label(self.mc2_title), self.model2_chooser]),
widgets.VBox([widgets.Label(self.docs_title), self.documents_chooser]),
widgets.VBox([widgets.Label(self.docs_limit_title), self.doc_limit]),
widgets.VBox([widgets.Label(self.mct_export_title), self.ckbox])
])

cui_filter = widgets.VBox([
widgets.HTML(f"<h2>{self.cui_filter_title_overall}</h2>"),
widgets.VBox([widgets.Label(self.cui_filter_title_file_chooser), self.cui_filter_chooser]),
widgets.VBox([widgets.Label(self.cui_filter_title_text), self.cui_filter_box]),
widgets.VBox([widgets.Label(self.cui_children_title), self.cui_children])
])

# Combine all sections into a main VBox
main_box = widgets.VBox([
model_choosers,
cui_filter
])
display(main_box)


def _get_params(self):
model_path_1 = self.model1_chooser.selected
model_path_2 = self.model2_chooser.selected
documents_file = self.documents_chooser.selected
doc_limit = self.doc_limit.value
is_mct_export_compare = self.ckbox.value
if not is_mct_export_compare:
print(f"For models, selected:\nModel1: {model_path_1}\nModel2: {model_path_2}"
f"\nDocuments: {documents_file}")
else:
print(f"Selected:\nModel: {model_path_1}\nMCT export: {model_path_2}"
f"\nDocuments: {documents_file}")
# CUI filter
cui_filter = None
filter_children = None
if self.cui_filter_chooser.selected:
cui_filter = self.cui_filter_chooser.selected
elif self.cui_filter_box.value:
cui_filter = self.cui_filter_box.value
if self.cui_children.value and self.cui_children.value > 0:
filter_children = self.cui_children.value
print(f"For CUI filter, selected:\nFilter: {cui_filter}\nChildren: {filter_children}")
return (model_path_1, model_path_2, documents_file, doc_limit, is_mct_export_compare, cui_filter, filter_children)

def get_comparison(self) -> NBComparer:
return NBComparer(*self._get_params())
27 changes: 18 additions & 9 deletions medcat/compare_models/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import tqdm
import tempfile
import os
from itertools import islice

from compare_cdb import compare as compare_cdbs, CDBCompareResults
from compare_annotations import ResultsTally, PerAnnotationDifferences
Expand All @@ -17,18 +17,22 @@



def load_documents(file_name: str) -> Iterator[Tuple[str, str]]:
def load_documents(file_name: str, doc_limit: int = -1) -> Iterator[Tuple[str, str]]:
with open(file_name) as f:
df = pd.read_csv(f, names=["id", "text"])
if df.iloc[0].id == "id" and df.iloc[0].text == "text":
# removes the header
# but also messes up the index a little
df = df.iloc[1:, :]
yield from df.itertuples(index=False)
if doc_limit == -1:
yield from df.itertuples(index=False)
else:
yield from islice(df.itertuples(index=False), doc_limit)


def do_counting(cat1: CAT, cat2: CAT,
ann_diffs: PerAnnotationDifferences) -> ResultsTally:
ann_diffs: PerAnnotationDifferences,
doc_limit: int = -1) -> ResultsTally:
def cui2name(cat, cui):
if cui in cat.cdb.cui2preferred_name:
return cat.cdb.cui2preferred_name[cui]
Expand All @@ -39,7 +43,8 @@ def cui2name(cat, cui):
cui2name=partial(cui2name, cat1))
res2 = ResultsTally(pt2ch=_get_pt2ch(cat2), cat_data=cat2.cdb.make_stats(),
cui2name=partial(cui2name, cat2))
for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values()):
total = doc_limit if doc_limit != -1 else None
for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values(), total=total):
res1.count(per_doc.raw1)
res2.count(per_doc.raw2)
return res1, res2
Expand All @@ -52,6 +57,7 @@ def _get_pt2ch(cat: CAT) -> Optional[Dict]:
def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str, str]],
show_progress: bool = True,
keep_raw: bool = True,
doc_limit: int = -1
) -> PerAnnotationDifferences:
pt2ch1: Optional[Dict] = _get_pt2ch(cat1)
pt2ch2: Optional[Dict] = _get_pt2ch(cat2)
Expand All @@ -63,7 +69,8 @@ def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str
model2_cuis=set(cat2.cdb.cui2names),
keep_raw=keep_raw,
save_options=save_opts)
for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress):
total = doc_limit if doc_limit != -1 else None
for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress, total=total):
pad.look_at_doc(cat1.get_entities(doc), cat2.get_entities(doc), doc_id, doc)
pad.finalise()
return pad
Expand Down Expand Up @@ -107,9 +114,10 @@ def get_diffs_for(model_pack_path_1: str,
include_children_in_filter: Optional[int] = None,
supervised_train_comparison_model: bool = False,
keep_raw: bool = True,
doc_limit: int = -1,
) -> Tuple[CDBCompareResults, ResultsTally, ResultsTally, PerAnnotationDifferences]:
validate_input(model_pack_path_1, model_pack_path_2, documents_file, cui_filter, supervised_train_comparison_model)
documents = load_documents(documents_file)
documents = load_documents(documents_file, doc_limit=doc_limit)
if show_progress:
print("Loading [1]", model_pack_path_1)
cat1 = CAT.load_model_pack(model_pack_path_1)
Expand Down Expand Up @@ -145,10 +153,11 @@ def get_diffs_for(model_pack_path_1: str,
len(cui_filter), "CUIs")
cat1.config.linking.filters.cuis = cui_filter
cat2.config.linking.filters.cuis = cui_filter
ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw)
ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw,
doc_limit=doc_limit)
if show_progress:
print("Counting [1&2]")
res1, res2 = do_counting(cat1, cat2, ann_diffs)
res1, res2 = do_counting(cat1, cat2, ann_diffs, doc_limit=doc_limit)
if show_progress:
print("CDB compare")
cdb_diff = compare_cdbs(cat1.cdb, cat2.cdb)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
289001005,289004002,226207007,165224005,129043005,129040008,129041007,704440004,704439001,704437004,129065005,129039006,129035000,165232002,1080000000000000,716422006,45850009,284908004,129045003,129062008,714887007,714916007,719024002,715127003,714915006,282882001,302040002,302043000,165243005,365112008,165255004,105504002,301563003,301497008,160680006,301589003,165248001,165249009,362000000000000,270469004,160729004,248000000000000,160734000,405807003,160685001,285035003,285038001,285034004,428483004,1912002,431188001,864000000000000,301563003,1070000000000000,301497008,78459008,284915007,307439001,229798009,229799001,229797004,31000000000000,818000000000000,763692001,404934007,863721000000101,1100000017,1100000016,1100000030,1100000030,1100000011,863721000000101,863721000000101,863721000000101,282884000,282884000,282884000,1100000027,361721000000103,1100000028,361721000000103,1100000028,361721000000103,1100000027,361721000000103,361721000000103,1100000027,1100000028,1100000027,361721000000103,361721000000103,1100000028,1100000031,31031000119102,310131003,895488007,895488007,394923006,394923006,394923006,394923006,394923006,248171000000108,248171000000108,248171000000108,248171000000108,1100000015,1100000015,1100000012,1100000012,1100000013,1100000012,302046008,25711000087100,165233007,1100000029,718705001,718360006,282871009,895486006,895486006,895486006,895486006,699650006,699650006,8510008,273302005,306171006,257301003,404930003,404930003,404930003,224221006,261001000,184156005,184156005,183376001,154091000119106,301627005,301627005,725594005,445414007,165803005,323701000000101,72042002,24029004,282971008,10610811000001107,161903000,979501000000100,301477003,282966001,1149222004,371153006,311925007,225602000,763264000,249902000,249902000,223600005,386323002,37013008,205511000000108,325831000000100,1073861000000108,273469003,129032002,286489001,761481000000107,129072006,1073311000000100,286489001,286490005,129026007,160689007,286493007,129031009,1069991000000102,1071641000000109,960681000000109
Loading
Loading