Skip to content

Commit

Permalink
improve dataset schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
baixiac committed Apr 5, 2024
1 parent 1bdf43e commit 2eeb343
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion app/api/routers/preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_rendered_entities_from_trainer_export(request: Request,
entities.append({
"start": annotation["start"],
"end": annotation["end"],
"label": f"{annotation['cui']} ({'correct' if annotation['correct'] else 'incorrect'}{'; terminated' if annotation['killed'] else ''})",
"label": f"{annotation['cui']} ({'correct' if annotation.get('correct', True) else 'incorrect'}{'; terminated' if annotation.get('deleted', False) and annotation.get('killed', False) else ''})",
"kb_id": annotation["cui"],
"kb_url": "#",
})
Expand Down
4 changes: 2 additions & 2 deletions app/api/routers/supervised_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Union
from typing_extensions import Annotated

from fastapi import APIRouter, Depends, UploadFile, Query, Request, File
from fastapi import APIRouter, Depends, UploadFile, Query, Request, File, Form
from fastapi.responses import JSONResponse
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE

Expand All @@ -30,7 +30,7 @@ async def train_supervised(request: Request,
epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1,
lr_override: Annotated[Union[float, None], Query(description="The override of the initial learning rate", gt=0.0)] = None,
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None,
description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None,
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
files = []
file_names = []
Expand Down
10 changes: 6 additions & 4 deletions app/data/anno_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datasets
import json
from pathlib import Path
from typing import List, Iterable, Tuple, Dict
from typing import List, Iterable, Tuple, Dict, Optional
from utils import filter_by_concept_ids


Expand All @@ -24,7 +24,8 @@ def _info(self) -> datasets.DatasetInfo:
description="Annotation Dataset. This is a dataset containing flattened MedCAT Trainer export",
features=datasets.Features(
{
"name": datasets.Value("string"),
"project": datasets.Value("string"),
"name":datasets.Value("string"),
"text": datasets.Value("string"),
"starts": datasets.Value("string"), # Mlflow ColSpec schema does not support HF Dataset Sequence
"ends": datasets.Value("string"), # Mlflow ColSpec schema does not support HF Dataset Sequence
Expand Down Expand Up @@ -57,8 +58,9 @@ def generate_examples(filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]:
ends.append(str(annotation["end"]))
labels.append(annotation["cui"])
yield str(id_), {
"name": document["name"],
"text": document["text"],
"project": project.get("name"),
"name": document.get("name"),
"text": document.get("text"),
"starts": ",".join(starts),
"ends": ",".join(ends),
"labels": ",".join(labels),
Expand Down
2 changes: 1 addition & 1 deletion app/data/doc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ def generate_examples(filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]:
with open(str(filepath), "r") as f:
texts = ijson.items(f, "item")
for text in texts:
yield str(id_), {"name": f"doc_{str(id_)}", "text": text}
yield str(id_), {"name": f"{str(id_)}", "text": text}
id_ += 1
4 changes: 2 additions & 2 deletions app/processors/metrics_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def get_iaa_scores_per_concept(export_file: Union[str, TextIO],
per_cui_metaanno_iia_pct = {}
per_cui_metaanno_cohens_kappa = {}
for cui, cui_metastate_pairs in cui_metastates.items():
per_cui_metaanno_iia_pct[cui] = len([1 for csp in cui_metastate_pairs if csp[0] == csp[1]]) / len(cui_metastate_pairs) * 100
per_cui_metaanno_iia_pct[cui] = len([1 for cmp in cui_metastate_pairs if cmp[0] == cmp[1]]) / len(cui_metastate_pairs) * 100
per_cui_metaanno_cohens_kappa[cui] = _get_cohens_kappa_coefficient(*map(list, zip(*cui_metastate_pairs)))

if return_df:
Expand Down Expand Up @@ -286,7 +286,7 @@ def get_iaa_scores_per_doc(export_file: Union[str, TextIO],
per_doc_metaanno_iia_pct = {}
per_doc_metaanno_cohens_kappa = {}
for doc_id, doc_metastate_pairs in doc_metastates.items():
per_doc_metaanno_iia_pct[str(doc_id)] = len([1 for dsp in doc_metastate_pairs if dsp[0] == dsp[1]]) / len(doc_metastate_pairs) * 100
per_doc_metaanno_iia_pct[str(doc_id)] = len([1 for dmp in doc_metastate_pairs if dmp[0] == dmp[1]]) / len(doc_metastate_pairs) * 100
per_doc_metaanno_cohens_kappa[str(doc_id)] = _get_cohens_kappa_coefficient(*map(list, zip(*doc_metastate_pairs)))

if return_df:
Expand Down
3 changes: 3 additions & 0 deletions tests/app/data/test_anno_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ def test_load_dataset():
trainer_export = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export_multi_projs.json")
dataset = datasets.load_dataset(anno_dataset.__file__, data_files={"annotations": trainer_export}, split="train", cache_dir="/tmp")
assert dataset.features.to_dict() == {
"project": {"dtype": "string", "_type": "Value"},
"name": {"dtype": "string", "_type": "Value"},
"text": {"dtype": "string", "_type": "Value"},
"starts": {"dtype": "string", "_type": "Value"},
"ends": {"dtype": "string", "_type": "Value"},
"labels": {"dtype": "string", "_type": "Value"},
}
assert len(dataset.to_list()) == 4
assert dataset.to_list()[0]["project"] == "MT Samples (Clone)"
assert dataset.to_list()[0]["name"] == "1687"
assert dataset.to_list()[0]["starts"] == "332,255,276,272"
assert dataset.to_list()[0]["ends"] == "355,267,282,275"
Expand All @@ -24,6 +26,7 @@ def test_generate_examples():
example_gen = anno_dataset.generate_examples([os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json")])
example = next(example_gen)
assert example[0] == "1"
assert "project" in example[1]
assert "name" in example[1]
assert "text" in example[1]
assert "starts" in example[1]
Expand Down
2 changes: 1 addition & 1 deletion tests/app/data/test_doc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_load_dataset():
dataset = datasets.load_dataset(doc_dataset.__file__, data_files={"documents": sample_texts}, split="train", cache_dir="/tmp")
assert dataset.features.to_dict() == {"name": {"dtype": "string", "_type": "Value"}, "text": {"dtype": "string", "_type": "Value"}}
assert len(dataset.to_list()) == 15
assert dataset.to_list()[0]["name"] == "doc_1"
assert dataset.to_list()[0]["name"] == "1"


def test_generate_examples():
Expand Down
8 changes: 6 additions & 2 deletions tests/app/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,16 @@ def test_augment_annotations_case_insensitive():
[r"^\d{2,4}\s*[.\/]\s*\d{1,2}\s*[.\/]\s*\d{1,2}$"],
[r"^\d{1,2}$", r"^[-.\/]$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{2,4}$"],
[r"^\d{2,4}$", r"^[-.\/]$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{1,2}$"],
[r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{4}$"],
[r"^\d{4}$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$"],
[r"^\d{1,2}\s*$", r"-", r"^\s*\d{4}$"],
[r"^\d{1,2}\s*[\/]\s*\d{4}$"],
[r"^\d{4}\s*$", r"-", r"^\s*\d{1,2}$"],
[r"^\d{4}\s*[\/]\s*\d{1,2}$"],
[r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{4}$"],
[r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)(\s+\d{1,2})*$", r",", r"^\d{4}$"],
[r"^\d{4}\s*[-.\/]\s*(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$"],
[r"^\d{4}$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$"],
[r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$", r"^\d{4}$"],
[r"^(?:19\d\d|20\d\d)$"],
]
}, False)

Expand Down

0 comments on commit 2eeb343

Please sign in to comment.