CU-8693wtx4b: Add code and notebook to compare two models (#13)
* CU-8693wtx4b: Add code and notebook to compare two models

* CU-8693wtx4: Add working output (synthetic data) to notebook

* CU-8693wtx4: Show cat data in raw format

* CU-8693wtx4: Add documentation to notebook

* CU-8693wtx4: Add per-cui output to code and notebook

* CU-8693wtx4: Add test to GHA workflow

* CU-8693wtx4: Remove old comments/commented code

* CU-8693wtx4: Add option to filter tally results by CUI; along with relevant tests.

* CU-8693wtx4: Add option to filter by CUI when comparing

* CU-8693wtx4: Add option to filter by CUI when comparing

* CU-8693wtx4: Fix typing issue

* CU-8693wtx4: Fix typing issue [#2]

* CU-8693wtx4: Add typing temp-issue to non-related files

* CU-8693wtx4: Add typing temp-issue to non-related files [#2]

* CU-8693wtx4b: Remove unnecessary whitespace

* CU-8693wtx4b: Add option to iterate over annotation pairs

* CU-8693wtx4b: Add tests for iteration over annotation pairs

* CU-8693wtx4b: Add option to iterate over subset of annotation pairs (per document ids)

* CU-8693wtx4b: Add some documentation for annotation pair iteration

* CU-8693wtx4b: Add annotation pair iteration to notebook

* CU-8693wtx4b: Add option (defaulted to True) to omit identical to annotation pair iteration

* CU-8693wtx4b: Update notebook to omit identical in annotation pair iteration

* CU-8693wtx4b: Update notebook with some extra documentation

* CU-8693wtx4b: Add further assertion to iteration with filter when testing

* CU-8693wtx4b: Fix non-overlapping annotations

* CU-8693wtx4b: Add more comprehensive test for non-overlapping annotations

* CU-8693wtx4b: Allow none-valued dicts in dict comparison for output

* CU-8693wtx4b: Add tests for none-valued dicts in output

* CU-8693wtx4b: Add better handling of empty dicts in comparison in output

* CU-8693wtx4b: Remove debug output

* CU-8693wtx4b: Add more useful nulled dicts along with tests for them

* CU-8693wtx4b: Update notebook with recent changes

* CU-8693wtx4b: Remove debug output

* CU-8693wtx4b: Fix typing issues

* CU-8693wtx4b: Remove ignore file from mct_analysis

* CU-8693wtx4b: Add option to recognise children and grandchildren as

* CU-8693wtx4b: Fix typo

* CU-8693wtx4b: Fix typo (#2)

* CU-8693wtx4b: Avoid annotating twice over the dataset

* CU-8693wtx4b: Fix issue with raw annotations being empty at tally time

* CU-869475m5q: Allow differentiating CUIs that one of the two models does not include

* CU-869475m5q: Update notebook with latest output

* CU-869475h56: Allow CUI filter to be specified by a file with a list of CUIs; Also make sure CUI filtering is used during annotation

* CU-869475h56: Update notebook with file based CUI filter information

* CU-869475m5q: Fix tests: add necessary per model CUI sets where applicable

* CU-869475h38: Allow including children for per-cui view

* CU-869475h38: Make per-cui names easier to read

* CU-869475h38: Update notebook

* CU-86948wv58: Add method to get CSV output of annotations

* CU-86948wv58: Improve documentation

* CU-86948wv58: Update model comparison with CSV output part

* CU-86948wv58: Fix CSV output relative start and end for annotations

* CU-86948wv58: Fix typing for annotations when creating CSV

* CU-86948wv58: Fix typing when creating sub-text for CSV

* CU-869498jf4: Add option to add children to CUI filter along with tests for getting said filter

* CU-869498jf4: Update notebook with notes regarding children in filter

* CU-8693wtx4b: Fix whitespace

* CU-869498jf4: Add more output regarding filter after adding children

* CU-869498jf4: Add children from both models if possible

* CU-869475h0e: Add option to compare to a base model + supervised training

* CU-869475h0e: Add tests for base model + supervised training; Add fake test data

* CU-869475h0e: Add note for supervised training-based comparison

* CU-869475h0e: Add missing notes for supervised training-based comparison

* CU-869475h78: Add markdown to dict comparison

* CU-869475h78: Fix notebook output for per-annotation differences

* CU-869475h78: Automatically use markdown when notebook output required

* CU-869475h78: Automatically detect when in notebook (by default)

* CU-869475h78: Update notebook with tabular / markdown output

* CU-8694ezhrn: Load documents as iterator instead of holding them in memory

* CU-8694ezhrn: Add option to omit raw text

* CU-8694ezhrn: Add tests when omitting raw text

* CU-8694ezhrn: Propagate keeping raw option properly

* CU-8694ezhrn: Propagate keeping raw option properly (2)

* CU-8694ezhrn: Do not store copies of existing dicts in annotation pairs

* CU-8693wtx4: Improve memory usage by saving the annotation differences on disk (by default) and reading them 100 at a time

* CU-8693wtx4: Allow filtering by comparison type when converting to CSV

* CU-8693wtx4: Add tests for filtering by comparison type when iterating to save as CSV

* CU-8693wtx4: Fix model type in case of using DB for annotation pairs

* CU-8694ggtgm: Improve initial documentation for model input

* CU-8694ggtgm: Add more accurate differences for two approaches in initial documentation

* CU-8694ggtgm: Add code for supervised training comparison (but commented)

* CU-8694ggtgm: Add widget input

* CU-8694ggtgm: Small documentation fix

* CU-8694ggtgm: Add new dependencies to requirements file

* CU-8694ggtgm: Add input data validation

* CU-8694ggtgm: Move input data validation to its own module

* CU-8694ggtgm: Add further validation for medcat models

* CU-8694ggtgm: Add further validation for documents file

* CU-8694ggtgm: Add wildcard MCT export validation

* CU-8694ggtgm: Replace use of 3.9+ exclusive string method
mart-r authored Jul 18, 2024
1 parent 1740e27 commit 404ad0d
Showing 20 changed files with 4,923 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
@@ -35,6 +35,7 @@ jobs:
    - name: Test
      run: |
        python -m unittest discover
        python -m unittest discover -s medcat/compare_models
# TODO - in the future, we might want to add automated tests for notebooks as well
#        though it's not really possible right now since the notebooks are designed
#        in a way that assumes interaction (i.e specifying model pack names)
62 changes: 62 additions & 0 deletions medcat/compare_models/cmp_utils.py
@@ -0,0 +1,62 @@
from typing import Type, TypeVar, Generic, Iterable, Callable, Optional

import sqlite3
import re
from pydantic import BaseModel


T = TypeVar('T', bound=BaseModel)


def sanitize_table_name(name, max_length=64):
    # Replace any characters not allowed in table names with underscores
    name = re.sub(r'[^a-zA-Z0-9_$]', '_', name)
    # Truncate the name if it's too long
    name = name[:max_length]
    return name


class SaveOptions(BaseModel):
    use_db: bool = False
    db_file_name: Optional[str] = None
    clean_callback: Optional[Callable[[], None]] = None


class DifferenceDatabase(Generic[T]):

    def __init__(self, db_file: str, part: str, model_type: Type[T],
                 batch_size: int = 100):
        self.db_file = db_file
        self.part = sanitize_table_name(part)
        self.model_type = model_type
        self.conn = sqlite3.connect(self.db_file)
        self.cursor = self.conn.cursor()
        self._create_table()
        self._len = 0
        self._batch_size = batch_size

    def _create_table(self):
        self.cursor.execute(f'''CREATE TABLE IF NOT EXISTS differences_{self.part}
                            (id INTEGER PRIMARY KEY, data TEXT)''')
        self.conn.commit()

    def append(self, difference: T):
        data = difference.json()
        self.cursor.execute(f"INSERT INTO differences_{self.part} (data) VALUES (?)", (data,))
        self.conn.commit()
        self._len += 1

    def __iter__(self) -> Iterable[T]:
        self.cursor.execute(f"SELECT data FROM differences_{self.part}")
        while True:
            rows = self.cursor.fetchmany(self._batch_size)
            if not rows:
                break
            for row in rows:
                yield self.model_type.parse_raw(row[0])

    def __len__(self) -> int:
        return self._len

    def __del__(self):
        self.conn.close()
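
For readers skimming the diff, the storage pattern in cmp_utils.py (serialise each difference to text, insert it into a per-part SQLite table, read it back in batches with `fetchmany`) can be tried in isolation. The sketch below is not part of the commit: it assumes plain dicts serialised with `json` in place of the pydantic models, and `iter_in_batches` and `demo` are hypothetical helper names used only for illustration.

```python
import json
import re
import sqlite3
from typing import Dict, Iterator, List


def sanitize_table_name(name: str, max_length: int = 64) -> str:
    # Same rule as in cmp_utils.py: anything outside [a-zA-Z0-9_$] becomes "_"
    return re.sub(r'[^a-zA-Z0-9_$]', '_', name)[:max_length]


def iter_in_batches(conn: sqlite3.Connection, table: str,
                    batch_size: int = 100) -> Iterator[Dict]:
    # `table` must already be sanitised, since it is interpolated into the SQL
    cursor = conn.cursor()
    cursor.execute(f"SELECT data FROM {table} ORDER BY id")
    while True:
        # At most `batch_size` rows are fetched at a time
        rows = cursor.fetchmany(batch_size)
        if not rows:
            break
        for (data,) in rows:
            yield json.loads(data)


def demo(num_rows: int = 250, batch_size: int = 100) -> List[Dict]:
    # Hypothetical driver: in-memory database instead of a temporary file
    conn = sqlite3.connect(":memory:")
    table = "differences_" + sanitize_table_name("per-doc results")
    conn.execute(f"CREATE TABLE IF NOT EXISTS {table} "
                 "(id INTEGER PRIMARY KEY, data TEXT)")
    conn.executemany(f"INSERT INTO {table} (data) VALUES (?)",
                     [(json.dumps({"doc_id": i}),) for i in range(num_rows)])
    conn.commit()
    out = list(iter_in_batches(conn, table, batch_size))
    conn.close()
    return out
```

Calling `demo(250, 100)` reads the rows back in three `fetchmany` calls that return data plus a final empty one that ends the loop, which is the same stopping condition `DifferenceDatabase.__iter__` relies on.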
166 changes: 166 additions & 0 deletions medcat/compare_models/compare.py
@@ -0,0 +1,166 @@
from typing import List, Tuple, Dict, Set, Optional, Union, Iterator
from functools import partial
import glob

from medcat.cat import CAT

import pandas as pd
import tqdm
import tempfile
import os

from compare_cdb import compare as compare_cdbs, CDBCompareResults
from compare_annotations import ResultsTally, PerAnnotationDifferences
from output import parse_and_show
from cmp_utils import SaveOptions
from validation import validate_input



def load_documents(file_name: str) -> Iterator[Tuple[str, str]]:
    with open(file_name) as f:
        df = pd.read_csv(f, names=["id", "text"])
    if df.iloc[0].id == "id" and df.iloc[0].text == "text":
        # removes the header
        # but also messes up the index a little
        df = df.iloc[1:, :]
    yield from df.itertuples(index=False)


def do_counting(cat1: CAT, cat2: CAT,
                ann_diffs: PerAnnotationDifferences
                ) -> Tuple[ResultsTally, ResultsTally]:
    def cui2name(cat, cui):
        if cui in cat.cdb.cui2preferred_name:
            return cat.cdb.cui2preferred_name[cui]
        all_names = cat.cdb.cui2names[cui]
        # longest name
        return sorted(all_names, key=lambda name: len(name), reverse=True)[0]
    res1 = ResultsTally(pt2ch=_get_pt2ch(cat1), cat_data=cat1.cdb.make_stats(),
                        cui2name=partial(cui2name, cat1))
    res2 = ResultsTally(pt2ch=_get_pt2ch(cat2), cat_data=cat2.cdb.make_stats(),
                        cui2name=partial(cui2name, cat2))
    for per_doc in tqdm.tqdm(ann_diffs.per_doc_results.values()):
        res1.count(per_doc.raw1)
        res2.count(per_doc.raw2)
    return res1, res2


def _get_pt2ch(cat: CAT) -> Optional[Dict]:
    return cat.cdb.addl_info.get("pt2ch", None)


def get_per_annotation_diffs(cat1: CAT, cat2: CAT, documents: Iterator[Tuple[str, str]],
                             show_progress: bool = True,
                             keep_raw: bool = True,
                             ) -> PerAnnotationDifferences:
    pt2ch1: Optional[Dict] = _get_pt2ch(cat1)
    pt2ch2: Optional[Dict] = _get_pt2ch(cat2)
    temp_file = tempfile.NamedTemporaryFile()
    save_opts = SaveOptions(use_db=True, db_file_name=temp_file.name,
                            clean_callback=temp_file.close)
    pad = PerAnnotationDifferences(pt2ch1=pt2ch1, pt2ch2=pt2ch2,
                                   model1_cuis=set(cat1.cdb.cui2names),
                                   model2_cuis=set(cat2.cdb.cui2names),
                                   keep_raw=keep_raw,
                                   save_options=save_opts)
    for doc_id, doc in tqdm.tqdm(documents, disable=not show_progress):
        pad.look_at_doc(cat1.get_entities(doc), cat2.get_entities(doc), doc_id, doc)
    pad.finalise()
    return pad


def load_cui_filter(filter_file: str) -> Set[str]:
    with open(filter_file) as f:
        str_list = f.read().split(',')
    return set(item.strip() for item in str_list)


def _add_all_children(cat: CAT, cui_filter: Set[str], include_children: int) -> None:
    if include_children <= 0:
        return
    if "pt2ch" not in cat.cdb.addl_info:
        return
    pt2ch = cat.cdb.addl_info["pt2ch"]
    children = set(ch for cui in cui_filter for ch in pt2ch.get(cui, []))
    if include_children > 1:
        _add_all_children(cat, children, include_children=include_children-1)
    cui_filter.update(children)


def load_and_train(model_pack_path: str, mct_export_path: str) -> CAT:
    cat = CAT.load_model_pack(model_pack_path)
    # NOTE: Allowing mct_export_path to contain a wildcard ("*").
    #       In such a case, iterate over all matching files
    if "*" not in mct_export_path:
        cat.train_supervised_from_json(mct_export_path)
    else:
        for file in glob.glob(mct_export_path):
            cat.train_supervised_from_json(file)
    return cat


def get_diffs_for(model_pack_path_1: str,
                  model_pack_path_2: str,
                  documents_file: str,
                  cui_filter: Optional[Union[Set[str], str]] = None,
                  show_progress: bool = True,
                  include_children_in_filter: Optional[int] = None,
                  supervised_train_comparison_model: bool = False,
                  keep_raw: bool = True,
                  ) -> Tuple[CDBCompareResults, ResultsTally, ResultsTally, PerAnnotationDifferences]:
    validate_input(model_pack_path_1, model_pack_path_2, documents_file, cui_filter, supervised_train_comparison_model)
    documents = load_documents(documents_file)
    if show_progress:
        print("Loading [1]", model_pack_path_1)
    cat1 = CAT.load_model_pack(model_pack_path_1)
    if show_progress:
        print("Loading [2]", model_pack_path_2)
    if not supervised_train_comparison_model:
        cat2 = CAT.load_model_pack(model_pack_path_2)
    else:
        if show_progress:
            print("Reloading model pack 1", model_pack_path_1)
            print("And subsequently training on", model_pack_path_2)
            print("This may take a while, depending on the amount of "
                  "data being trained on")
        cat2 = load_and_train(model_pack_path_1, model_pack_path_2)
    if show_progress:
        print("Per annotations diff finding")
    if cui_filter:
        if isinstance(cui_filter, str):
            cui_filter = load_cui_filter(cui_filter)
        if show_progress:
            print("Applying filter to CATs:", len(cui_filter), 'CUIs')
        if include_children_in_filter:
            if show_progress:
                print("Adding all children of", include_children_in_filter,
                      "or lower level from first model")
            _add_all_children(cat1, cui_filter, include_children_in_filter)
            if show_progress:
                print("After adding children from 1st model have a total of",
                      len(cui_filter), "CUIs")
            _add_all_children(cat2, cui_filter, include_children_in_filter)
            if show_progress:
                print("After adding children from 2nd model have a total of",
                      len(cui_filter), "CUIs")
        cat1.config.linking.filters.cuis = cui_filter
        cat2.config.linking.filters.cuis = cui_filter
    ann_diffs = get_per_annotation_diffs(cat1, cat2, documents, keep_raw=keep_raw)
    if show_progress:
        print("Counting [1&2]")
    res1, res2 = do_counting(cat1, cat2, ann_diffs)
    if show_progress:
        print("CDB compare")
    cdb_diff = compare_cdbs(cat1.cdb, cat2.cdb)
    return cdb_diff, res1, res2, ann_diffs


def main(mpn1: str, mpn2: str, documents_file: str):
    cdb_diff, res1, res2, ann_diffs = get_diffs_for(mpn1, mpn2, documents_file, show_progress=False)
    print("Results:")
    parse_and_show(cdb_diff, res1, res2, ann_diffs)


if __name__ == "__main__":
    import sys
    main(*sys.argv[1:])
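
The CUI filter handling in compare.py (a comma-separated filter file, optionally expanded with children down to a given depth) can be checked without loading a model pack. The sketch below is an illustration and not part of the commit: it assumes a plain dict standing in for `cat.cdb.addl_info["pt2ch"]`, and `parse_cui_filter` and `add_children` are hypothetical names. Unlike `load_cui_filter`, this variant drops empty items, and it uses a loop rather than recursion, which should give the same set as `_add_all_children` for a given depth.

```python
from typing import Dict, Iterable, Set


def parse_cui_filter(raw: str) -> Set[str]:
    # Comma-separated CUIs; surrounding whitespace and newlines are stripped.
    # Empty items (e.g. from a trailing comma) are dropped in this sketch.
    return {item.strip() for item in raw.split(',') if item.strip()}


def add_children(pt2ch: Dict[str, Iterable[str]], cui_filter: Set[str],
                 levels: int) -> None:
    # Expand `cui_filter` in place with descendants, `levels` generations deep:
    # 1 adds children, 2 adds children and grandchildren, and so on.
    frontier = set(cui_filter)
    for _ in range(levels):
        frontier = {ch for cui in frontier for ch in pt2ch.get(cui, ())}
        if not frontier:
            break
        cui_filter.update(frontier)


# Toy parent-to-children mapping (made-up identifiers, not real CUIs)
toy_pt2ch = {"C1": {"C2"}, "C2": {"C3"}, "C3": {"C4"}}
cuis = parse_cui_filter("C1, C9\n")
add_children(toy_pt2ch, cuis, levels=2)
print(sorted(cuis))
```

With the toy mapping, two levels add "C2" and "C3" but not "C4", and a CUI with no entry in the mapping ("C9") is kept unchanged.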