From 20125f783345f2bfa7bd00c153e9d3ee05192b7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Fri, 25 Aug 2023 17:27:43 +0200
Subject: [PATCH] refacto: redesign pipeline scorers, add input and output spans params to trainable_ner #203

---
 docs/pipelines/trainable/ner.md             |   8 +-
 edsnlp/core/component.py                    |  28 +--
 edsnlp/core/pipeline.py                     | 128 +++++++---
 edsnlp/core/registry.py                     |  11 -
 edsnlp/pipelines/trainable/ner/__init__.py  |   2 +-
 edsnlp/pipelines/trainable/ner/factory.py   |  66 +++---
 edsnlp/pipelines/trainable/ner/ner.py       | 260 ++++++++++++---------
 edsnlp/scorers/__init__.py                  |   0
 edsnlp/scorers/ner.py                       | 116 +++++++++
 edsnlp/scorers/speed.py                     |  23 ++
 tests/pipelines/trainable/test_ner.py       |   2 +
 tests/training/config.cfg                   |  20 +-
 tests/training/test_train.py                |  81 ++-----
 13 files changed, 453 insertions(+), 292 deletions(-)
 create mode 100644 edsnlp/scorers/__init__.py
 create mode 100644 edsnlp/scorers/ner.py
 create mode 100644 edsnlp/scorers/speed.py

diff --git a/docs/pipelines/trainable/ner.md b/docs/pipelines/trainable/ner.md
index 9929378ae..b80c82efb 100644
--- a/docs/pipelines/trainable/ner.md
+++ b/docs/pipelines/trainable/ner.md
@@ -88,13 +88,7 @@ The pipeline component can be configured using the following parameters :
 
-::: edsnlp.pipelines.trainable.nested_ner.factory.create_component
-    options:
-        only_parameters: true
-
-The default model `eds.nested_ner_model.v1` can be configured using the following parameters :
-
-::: edsnlp.pipelines.trainable.nested_ner.stack_crf_ner.create_model
+::: edsnlp.pipelines.trainable.ner.factory.create_component
     options:
         only_parameters: true
diff --git a/edsnlp/core/component.py b/edsnlp/core/component.py
index f722bf482..6d7406226 100644
--- a/edsnlp/core/component.py
+++ b/edsnlp/core/component.py
@@ -43,7 +43,7 @@ def wrapped(self: "TorchComponent", doc: Doc):
         if self.nlp._cache is None:
             return fn(self, doc)
         cache_id = hash((id(self), "preprocess", id(doc)))
-        if cache_id in self.nlp._cache:
+        if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, doc)
         self.nlp._cache[cache_id] = res
@@ -58,7 +58,9 @@ def wrapped(self: "TorchComponent", doc: Doc):
         if self.nlp._cache is None:
             return fn(self, doc)
         cache_id = hash((id(self), "preprocess_supervised", id(doc)))
-        if cache_id in self.nlp._cache.setdefault(self, {}):
+        if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache.setdefault(
+            self, {}
+        ):
             return self.nlp._cache[cache_id]
         res = fn(self, doc)
         self.nlp._cache[cache_id] = res
@@ -75,7 +77,7 @@ def wrapped(self: "TorchComponent", batch: Dict, device: torch.device):
         cache_id = hash((id(self), "collate", hash_batch(batch)))
         if self.nlp._cache is None or cache_id is None:
             return fn(self, batch, device)
-        if cache_id in self.nlp._cache:
+        if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, batch, device)
         self.nlp._cache[cache_id] = res
@@ -92,7 +94,7 @@ def wrapped(self: "TorchComponent", batch):
         cache_id = hash((id(self), "collate", hash_batch(batch)))
         if self.nlp._cache is None or cache_id is None:
             return fn(self, batch)
-        if cache_id in self.nlp._cache:
+        if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, batch)
         self.nlp._cache[cache_id] = res
@@ -298,24 +300,6 @@ def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
         """
         return self.preprocess(doc)
 
-    def clean_gold_for_evaluation(self, gold: Doc) -> Doc:
-        """
-        Clean the gold document before evaluation.
-        Only the attributes that are predicted by the component should be removed.
-        By default, this is a no-op.
-
-        Parameters
-        ----------
-        gold: Doc
-            Gold document
-
-        Returns
-        -------
-        Doc
-            The document without attributes that should be predicted
-        """
-        return gold
-
     def pipe(self, docs: Iterable[Doc], batch_size=1) -> Iterable[Doc]:
         """
         Applies the component on a collection of documents. It is recommended to use
diff --git a/edsnlp/core/pipeline.py b/edsnlp/core/pipeline.py
index 72b28b0e3..aaba36b4a 100644
--- a/edsnlp/core/pipeline.py
+++ b/edsnlp/core/pipeline.py
@@ -1,6 +1,7 @@
 # ruff: noqa: E501
 import copy
 import functools
+import inspect
 import os
 import shutil
 import time
@@ -35,6 +36,7 @@
 from spacy.util import get_lang_class
 from spacy.vocab import Vocab, create_vocab
 from tqdm import tqdm
+from typing_extensions import NotRequired, TypedDict
 
 from edsnlp.core.registry import PIPE_META, CurriedFactory, FactoryMeta
 from edsnlp.utils.collections import (
@@ -62,6 +64,16 @@ class CacheEnum(str, Enum):
 
 Pipe = TypeVar("Pipe", bound=Callable[[Doc], Doc])
 
+ScorerType = Union[
+    Callable[[Iterable[Doc]], Dict[str, Any]],
+    Callable[[Iterable[Doc], float], Dict[str, Any]],
+]
+
+
+class ScoringConfig(TypedDict):
+    pipes: NotRequired[Union[str, List[str]]]
+    scorer: ScorerType
+
 
 class Pipeline:
     """
@@ -80,6 +92,7 @@ def __init__(
         batch_size: Optional[int] = 4,
         vocab_config: Type[BaseDefaults] = BaseDefaults,
         meta: Dict[str, Any] = None,
+        scorers: Dict[str, Union[ScoringConfig, ScorerType]] = None,
     ):
         """
         Parameters
@@ -124,7 +137,9 @@ def __init__(
         self._path: Optional[Path] = None
         self.meta = dict(meta) if meta is not None else {}
         self.lang: str = lang
+        self.scorers = scorers or {}
         self._cache: Optional[Dict] = None
+        self._cache_is_writeonly = False
 
     @property
     def pipeline(self) -> List[Tuple[str, Pipe]]:
@@ -412,9 +427,12 @@ def cache(self):
         """
         Enable caching for all (trainable) components in the pipeline
         """
-        self._cache = {}
+        was_not_cached = not self._cache
+        if was_not_cached:
+            self._cache = {}
         yield
-        self._cache = None
+        if was_not_cached:
+            self._cache = None
 
     def torch_components(
         self, disable: Sequence[str] = ()
@@ -493,18 +511,19 @@ def from_config(
             config["nlp"]["components"] = Reference("components")
             config = config["nlp"]
 
-        config = Config(config).resolve(root=root_config)
+        config = dict(Config(config).resolve(root=root_config))
+        components = config.pop("components", {})
+        pipeline = config.pop("pipeline", ())
+        tokenizer = config.pop("tokenizer", None)
+        disable = (config.pop("disabled", ()), disable)
 
         nlp = Pipeline(
             vocab=vocab,
-            create_tokenizer=config.get("tokenizer"),
-            lang=config["lang"],
+            create_tokenizer=tokenizer,
             meta=meta,
+            **config,
         )
 
-        components = config.get("components", {})
-        pipeline = config.get("pipeline", ())
-
         # Since components are actually resolved as curried factories,
         # we need to instantiate them here
         for name, component in components.items():
@@ -722,46 +741,81 @@ def score(self, docs: Sequence[Doc], batch_size: int = None) -> Dict[str, Any]:
         import torch
         from spacy.training import Example
 
-        inputs: Sequence[Doc] = copy.deepcopy(docs)
-        golds: Sequence[Doc] = docs
-
-        scored_components = {}
-
         # Predicting intermediate steps
         preds = defaultdict(lambda: [])
         if batch_size is None:
             batch_size = self.batch_size
-        total_duration = 0
+
+        scorers_by_pipes = defaultdict(lambda: {})
+        for scorer_name, scorer in self.scorers.items():
+            if isinstance(scorer, dict) and "scorer" in scorer:
+                pipe_names = scorer.get("pipes", self.pipe_names)
+                actual_scorer = scorer["scorer"]
+                if isinstance(pipe_names, str):
+                    pipe_names = [pipe_names]
+                if pipe_names is None:
+                    pipe_names = self.pipe_names
+            else:
+                pipe_names = self.pipe_names
+                actual_scorer = scorer
+            scorers_by_pipes[tuple(pipe_names)][scorer_name] = actual_scorer
+
+        speed_metric_names = {
+            name
+            for _, scorers_group in scorers_by_pipes.items()
+            for name, scorer in scorers_group.items()
+            if "duration" in inspect.signature(scorer).parameters
+        }
+        pipes_to_duration = {
+            pipe_names: 0.0
+            for pipe_names in scorers_by_pipes.keys()
+            if speed_metric_names & set(scorers_by_pipes[pipe_names])
+        }
+
         with self.train(False), torch.no_grad():  # type: ignore
-            for batch in batchify(
-                tqdm(inputs, "Scoring components"), batch_size=batch_size
+            for gold_batch in batchify(
+                tqdm(docs, "Scoring components"), batch_size=batch_size
             ):
                 with self.cache():
-                    for name, pipe in self.pipeline[::-1]:
-                        if hasattr(pipe, "clean_gold_for_evaluation"):
-                            batch = [
-                                pipe.clean_gold_for_evaluation(doc) for doc in batch
-                            ]
+                    for pipe_names in scorers_by_pipes.keys():
+                        timed = speed_metric_names & set(scorers_by_pipes[pipe_names])
+
+                        if timed:
+                            self._cache_is_writeonly = True
+
+                        batch = copy.deepcopy(gold_batch)
+
                         t0 = time.time()
-                        if hasattr(pipe, "batch_process"):
-                            batch = pipe.batch_process(batch)
-                        else:
-                            batch = [pipe(doc) for doc in batch]
-                        total_duration += time.time() - t0
-                        if getattr(pipe, "score", None) is not None:
-                            scored_components[name] = pipe
-                            preds[name].extend(copy.deepcopy(batch))
+                        for pipe_name in pipe_names:
+                            pipe = self.get_pipe(pipe_name)
+                            if hasattr(pipe, "batch_process"):
+                                batch = pipe.batch_process(batch)
+                            else:
+                                batch = [pipe(doc) for doc in batch]
 
-        metrics: Dict[str, Any] = {
-            "speed": len(inputs) / total_duration,
-        }
-        for name, pipe in scored_components.items():
-            metrics[name] = pipe.score(
-                [Example(p, g) for p, g in zip(preds[name], golds)]
-            )
+                        t1 = time.time()
+
+                        if timed:
+                            pipes_to_duration[pipe_names] += t1 - t0
+                            self._cache_is_writeonly = False
+
+                        preds[pipe_names].extend(batch)
+
+        results: Dict[str, Any] = {}
+        for pipe_names, preds in preds.items():
+            for scorer_name, scorer in scorers_by_pipes[pipe_names].items():
+                if scorer_name in speed_metric_names:
+                    results[scorer_name] = scorer(
+                        [Example(p, g) for p, g in zip(preds, docs)],
+                        duration=pipes_to_duration[pipe_names],
+                    )
+                else:
+                    results[scorer_name] = scorer(
+                        [Example(p, g) for p, g in zip(preds, docs)],
+                    )
 
-        return metrics
+        return results
 
     def to_disk(
         self, path: Union[str, Path], *, exclude: Sequence[str] = FrozenList()
diff --git a/edsnlp/core/registry.py b/edsnlp/core/registry.py
index 2353502b2..4b15d21fc 100644
--- a/edsnlp/core/registry.py
+++ b/edsnlp/core/registry.py
@@ -55,7 +55,6 @@ class CurriedFactory:
     def __init__(self, func, kwargs):
         self.kwargs = kwargs
         self.factory = func
-        # self.factory_name = factory_name
         self.instantiated = None
 
     def instantiate(
@@ -85,16 +84,6 @@ def instantiate(
                     **kwargs,
                 }
             )
-            # Config._store_resolved(
-            #     obj.instantiated,
-            #     Config(
-            #         {
-            #             "@factory": obj.factory_name,
-            #             **kwargs,
-            #         }
-            #     ),
-            # )
-            # PIPE_META[obj.instantiated] = obj.meta
             return obj.instantiated
         elif isinstance(obj, dict):
             return {
diff --git a/edsnlp/pipelines/trainable/ner/__init__.py b/edsnlp/pipelines/trainable/ner/__init__.py
index e023a2304..549d2fc77 100644
--- a/edsnlp/pipelines/trainable/ner/__init__.py
+++ b/edsnlp/pipelines/trainable/ner/__init__.py
@@ -1 +1 @@
-from .factory import create_component, create_ner_exact_scorer
+from .factory import create_component
diff --git a/edsnlp/pipelines/trainable/ner/factory.py b/edsnlp/pipelines/trainable/ner/factory.py
index b62973deb..470159413 100644
--- a/edsnlp/pipelines/trainable/ner/factory.py
+++ b/edsnlp/pipelines/trainable/ner/factory.py
@@ -1,38 +1,19 @@
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Dict, Optional, Union
 
-from confit import Config
-from spacy.tokens import Doc, Span
-from spacy.training import Example
+from pydantic.types import StrictStr
+from typing_extensions import Literal
 
 from edsnlp import registry
 from edsnlp.core import PipelineProtocol
 from edsnlp.core.component import BatchInput, TorchComponent
+from edsnlp.utils.span_getters import ListStr, SpanGetter
 
 from ..embeddings.typing import WordEmbeddingBatchOutput
-from .ner import CRFMode, TrainableNER, make_span_getter, nested_ner_exact_scorer
-
-ner_default_config = """
-[root]
-mode = "joint"
-
-[root.span_getter]
-@misc = "span_getter"
-
-[root.scorer]
-@scorers = "eds.ner_exact_scorer"
-"""
-
-NER_DEFAULTS = Config.from_str(ner_default_config)["root"]
-
-
-@registry.scorers.register("eds.ner_exact_scorer")
-def create_ner_exact_scorer():
-    return nested_ner_exact_scorer
+from .ner import INFER, TrainableNER
 
 
 @registry.factory.register(
     "eds.ner",
-    default_config=NER_DEFAULTS,
     requires=["doc.ents", "doc.spans"],
     assigns=["doc.ents", "doc.spans"],
     default_score_weights={
@@ -42,13 +23,15 @@ def create_ner_exact_scorer():
     },
 )
 def create_component(
-    nlp: PipelineProtocol,
-    name: str,
+    nlp: Optional[PipelineProtocol] = None,
+    name: Optional[str] = None,
+    *,
     embedding: TorchComponent[WordEmbeddingBatchOutput, BatchInput],
-    labels: List[str] = [],
-    span_getter: Callable[[Doc], Iterable[Span]] = make_span_getter(),
-    mode: CRFMode = CRFMode.joint,
-    scorer: Callable[[Iterable[Example]], Dict[str, Any]] = create_ner_exact_scorer(),
+    to_ents: Union[bool, ListStr] = INFER,
+    to_span_groups: Union[StrictStr, Dict[str, Union[bool, ListStr]]] = INFER,
+    labels: Optional[ListStr] = INFER,
+    target_span_getter: SpanGetter = {"ents": True},
+    mode: Literal["independent", "joint", "marginal"] = "joint",
 ):
     """
     Initialize a general named entity recognizer (with or without nested or
@@ -62,22 +45,31 @@ def create_component(
         Name of the component
     embedding: TorchComponent[WordEmbeddingBatchOutput, BatchInput]
         The word embedding component
+    target_span_getter: Callable[[Doc], Iterable[Span]]
+        Method to call to get the gold spans from a document, for scoring or training.
+        By default, takes all entities in `doc.ents`, but we recommend you specify
+        a given span group name instead.
     labels: List[str]
         The labels to predict. The labels can also be inferred from the data
        during `nlp.post_init(...)`
-    span_getter: Callable[[Doc], Iterable[Span]]
-        Method to call to get the gold spans from a document, for scoring or training
-    mode: CRFMode
+    to_ents: ListStrOrBool
+        Whether to put predictions in `doc.ents`. `to_ents` can be:
+        - a boolean to put all or no predictions in `doc.ents`
+        - a list of str to filter predictions by label
+    to_span_groups: Union[str, Dict[str, ListStrOrBool]]
+        If and how to put predictions in `doc.spans`. `to_span_groups` can be:
+        - a string to put all predictions to a given span group (e.g. "ner-preds")
+        - a dict mapping group names to a list of str to filter predictions by label
+    mode: Literal["independent", "joint", "marginal"]
        The CRF mode to use: independent, joint or marginal
-    scorer: Optional[Callable[[Iterable[Example]], Dict[str, Any]]]
-        Method to call to score predictions
     """
     return TrainableNER(
         nlp=nlp,
         name=name,
         embedding=embedding,
+        to_ents=to_ents,
+        to_span_groups=to_span_groups,
         labels=labels,
-        span_getter=span_getter,
+        target_span_getter=target_span_getter,
         mode=mode,
-        scorer=scorer,
     )
diff --git a/edsnlp/pipelines/trainable/ner/ner.py b/edsnlp/pipelines/trainable/ner/ner.py
index 2e53a08d8..b224d944b 100644
--- a/edsnlp/pipelines/trainable/ner/ner.py
+++ b/edsnlp/pipelines/trainable/ner/ner.py
@@ -1,16 +1,23 @@
 from __future__ import annotations
 
+import warnings
+from collections import defaultdict
 from enum import Enum
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
 
 import torch
+from pydantic import StrictStr
 from spacy.tokens import Doc, Span
-from spacy.training import Example
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import Literal, NotRequired, TypedDict
 
+from edsnlp import Pipeline
 from edsnlp.core.component import TorchComponent
-from edsnlp.core.registry import registry
 from edsnlp.utils.filter import filter_spans
+from edsnlp.pipelines.base import (
+    SpanGetter,
+    SpanGetterMapping,
+    get_spans,
+)
 
 from ..embeddings.typing import BatchInput, WordEmbeddingBatchOutput
 from ..layers.crf import MultiLabelBIOULDecoder
@@ -27,7 +34,6 @@
     {
         "loss": Optional[torch.Tensor],
         "tags": Optional[torch.Tensor],
-        "mask": Optional[torch.Tensor],
     },
 )
 
@@ -37,97 +43,26 @@ class CRFMode(str, Enum):
     joint = "joint"
     marginal = "marginal"
 
-
-# separate from make_span_getter to be picklable
-def span_getter(doc):
-    return doc.ents
-
-
-@registry.misc.register("span_getter")
-def make_span_getter():
-    return span_getter
-
-
-def nested_ner_exact_scorer(examples: Iterable[Example], **cfg) -> Dict[str, Any]:
-    """
-    Scores the extracted entities that may be overlapping or nested
-    by looking in `doc.ents`, and `doc.spans`.
-
-    Parameters
-    ----------
-    examples: Iterable[Example]
-    cfg: Dict[str]
-        - labels: Iterable[str] labels to take into account
-        - spans_labels: Iterable[str] span group names to look into for entities
-
-    Returns
-    -------
-    Dict[str, Any]
-    """
-    labels = set(cfg["labels"]) if "labels" in cfg is not None else None
-    spans_labels = cfg.get("spans_labels", None)
-
-    pred_spans = set()
-    gold_spans = set()
-    for eg_idx, eg in enumerate(examples):
-        for span in (
-            *eg.predicted.ents,
-            *(
-                span
-                for name in (
-                    spans_labels if spans_labels is not None else eg.reference.spans
-                )
-                for span in eg.predicted.spans.get(name, ())
-            ),
-        ):
-            if labels is None or span.label_ in labels:
-                pred_spans.add((eg_idx, span.start, span.end, span.label_))
-
-        for span in (
-            *eg.reference.ents,
-            *(
-                span
-                for name in (
-                    spans_labels if spans_labels is not None else eg.reference.spans
-                )
-                for span in eg.reference.spans.get(name, ())
-            ),
-        ):
-            if labels is None or span.label_ in labels:
-                gold_spans.add((eg_idx, span.start, span.end, span.label_))
-
-    tp = len(pred_spans & gold_spans)
-
-    return {
-        "ents_p": tp / len(pred_spans) if pred_spans else float(tp == len(pred_spans)),
-        "ents_r": tp / len(gold_spans) if gold_spans else float(tp == len(gold_spans)),
-        "ents_f": 2 * tp / (len(pred_spans) + len(gold_spans))
-        if pred_spans or gold_spans
-        else float(len(pred_spans) == len(gold_spans)),
-        "support": len(gold_spans),
-    }
-
-
-@registry.factory.register("eds.ner")
 class TrainableNER(TorchComponent[NERBatchOutput, NERBatchInput]):
     def __init__(
         self,
-        nlp,
-        name: str,
+        nlp: Pipeline = None,
+        name: Optional[str] = None,
+        *,
         embedding: TorchComponent[WordEmbeddingBatchOutput, BatchInput],
-        labels: List[str],
-        span_getter: Callable[[Doc], Iterable[Span]],
-        mode: CRFMode,
-        scorer: Callable[[Iterable[Example]], Dict[str, Any]],
+        target_span_getter: SpanGetter,
+        labels: Optional[List[str]] = None,
+        to_ents: Union[bool, List[str]] = None,
+        to_span_groups: Union[StrictStr, Dict[str, Union[bool, List[str]]]] = None,
+        mode: Literal["independent", "joint", "marginal"],
     ):
         super().__init__(nlp, name)
-        self.name = name
+
         self.embedding = embedding
-        self.span_getter = span_getter
-        self.labels = list(labels)
+        self.labels = labels
         self.linear = torch.nn.Linear(
             self.embedding.output_size,
-            len(labels) * 5,
+            0 if labels is None else (len(labels) * 5),
         )
         self.crf = MultiLabelBIOULDecoder(
             1,
@@ -135,28 +70,113 @@ def __init__(
             learnable_transitions=False,
         )
         self.mode = mode
-        self.scorer = scorer
+        self.to_ents = to_ents
+        self.to_span_groups: Dict[str, Union[bool, List[str]]] = (
+            {to_span_groups: True}
+            if isinstance(to_span_groups, str)
+            else to_span_groups
+        )
+        if callable(target_span_getter) and (
+            self.to_ents is None or self.to_span_groups is None
+        ):
+            raise ValueError(
+                "If `target_span_getter` is callable, `to_ents` or `to_span_groups` "
+                "cannot be inferred and must both be set manually"
+            )
+        if (
+            isinstance(target_span_getter, dict) and "labels" in target_span_getter
+        ) and self.labels is not None:
+            raise ValueError(
+                "You cannot set both the `labels` key of the `target_span_getter` "
+                "parameter and the `labels` parameter."
+            )
 
-    def score(self, examples: Sequence[Example]):
-        return self.scorer(examples)
+        if isinstance(target_span_getter, list):
+            target_span_getter = {"span_groups": target_span_getter}
+
+        self.target_span_getter: Union[
+            SpanGetterMapping,
+            Callable[[Doc], Iterable[Span]],
+        ] = target_span_getter
 
     def post_init(self, docs: Iterable[Doc]):
-        # TODO, make span_getter default accessible from here
-        labels = dict.fromkeys(self.labels)
+        """
+        Update the labels based on the data and the span getter,
+        and fills in the to_ents and to_span_groups if necessary
+
+        Parameters
+        ----------
+        docs
+
+        Returns
+        -------
+
+        """
+        if (
+            self.labels is not None
+            and self.to_ents is not None
+            and self.to_span_groups is not None
+        ):
+            return
+
+        inferred_labels = set()
+
+        to_ents = []
+        to_span_groups = defaultdict(lambda: [])
+
         for doc in docs:
-            for ent in self.span_getter(doc):
-                labels[ent.label_] = None
-        self.update_labels(list(labels.keys()))
+            if callable(self.target_span_getter):
+                for ent in self.target_span_getter(doc):
+                    inferred_labels.add(ent.label_)
+            else:
+                if "span_groups" in self.target_span_getter:
+                    for group in self.target_span_getter["span_groups"]:
+                        for span in doc.spans.get(group, ()):
+                            if self.labels is None or span.label_ in self.labels:
+                                inferred_labels.add(span.label_)
+                                to_span_groups[group].append(span.label_)
+                elif "ents" in self.target_span_getter:
+                    for span in doc.ents:
+                        if self.labels is None or span.label_ in self.labels:
+                            inferred_labels.add(span.label_)
+                            to_ents.append(span.label_)
+        if self.labels is not None:
+            assert inferred_labels <= set(self.labels), (
+                "Some inferred labels are not present in the labels "
+                f"passed to the component: {inferred_labels - set(self.labels)}"
+            )
+            if inferred_labels < set(self.labels):
+                warnings.warn(
+                    "Some labels passed to the trainable NER component are not "
+                    "present in the inferred labels list: "
+                    f"{set(self.labels) - inferred_labels}"
+                )
+        else:
+            self.update_labels(sorted(inferred_labels))
+
+        if not self.labels:
+            raise ValueError(
+                "No labels were inferred from the data. Please check your data and "
+                "the `target_span_getter` parameter."
+            )
+
+        if self.to_ents is None:
+            self.to_ents = to_ents
+            self.cfg["to_ents"] = self.to_ents
+        if self.to_span_groups is None:
+            self.to_span_groups = dict(to_span_groups)
+            self.cfg["to_span_groups"] = self.to_span_groups
 
     def update_labels(self, labels: Sequence[str]):
-        n_old = len(self.labels)
+        original_labels = self.labels if self.labels is not None else ()
+        n_old = len(original_labels)
         label_indices = dict(
             (
-                *zip(self.labels, range(n_old)),
+                *zip(original_labels, range(n_old)),
                 *zip(labels, range(n_old, n_old + len(labels))),
             )
         )
-        old_index = [label_indices[label] for label in self.labels]
+        old_index = [label_indices[label] for label in original_labels]
         new_linear = torch.nn.Linear(
             self.embedding.output_size,
             len(labels) * 5,
@@ -179,7 +199,7 @@ def preprocess(self, doc):
 
     def preprocess_supervised(self, doc):
         targets = [[0] * len(self.labels) for _ in doc]
-        for ent in self.span_getter(doc):
+        for ent in self.get_target_spans(doc):
             label_idx = self.labels.index(ent.label_)
             if ent.start == ent.end - 1:
                 targets[ent.start][label_idx] = 4
@@ -244,28 +264,46 @@ def forward(self, batch: NERBatchInput) -> NERBatchOutput:
                 reduction="sum",
             )
         else:
-            tags = self.crf.decode(scores, mask)
+            tags = self.crf.decode(
+                scores, mask
+            )  # tags = scores.argmax(-1).masked_fill(~mask.unsqueeze(-1), 0)
         return {
             "loss": loss,
             "tags": tags,
-            "mask": mask,
         }
 
+    def get_target_spans(self, doc):
+        return (
+            self.target_span_getter(doc)
+            if callable(self.target_span_getter)
+            else get_spans(doc, self.target_span_getter)
+        )
+
     def postprocess(self, docs: List[Doc], batch: NERBatchOutput):
         spans = self.crf.tags_to_spans(batch["tags"].cpu()).tolist()
         ents = [[] for _ in docs]
-        span_groups = [{label: [] for label in self.labels} for _ in docs]
+        if self.to_span_groups is None or self.to_ents is None:
+            raise ValueError(
+                f"The {self.__class__.__name__} component still has to infer the "
+                f"`to_ents` and `to_span_groups` parameters. Please call "
+                f"`nlp.post_init(...)` before running it on some new data, or set "
+                f"both parameters manually."
+            )
+        span_groups = [{label: [] for label in self.to_span_groups} for _ in docs]
         for doc_idx, start, end, label_idx in spans:
             label = self.labels[label_idx]
             span = Span(docs[doc_idx], start, end, label)
-            ents[doc_idx].append(span)
-            span_groups[doc_idx][label].append(span)
+            if self.to_ents is True or label in self.to_ents:
+                ents[doc_idx].append(span)
+            for group_name, group_spans in span_groups[doc_idx].items():
+                if (
+                    self.to_span_groups[group_name] is True
+                    or label in self.to_span_groups[group_name]
+                ):
+                    span_groups[doc_idx][group_name].append(span)
         for doc, doc_ents, doc_span_groups in zip(docs, ents, span_groups):
-            doc.ents = filter_spans(doc_ents)
-            doc.spans.update(doc_span_groups)
+            if doc_ents:
+                doc.ents = filter_spans((*doc.ents, *doc_ents))
+            if self.to_span_groups:
+                doc.spans.update(doc_span_groups)
         return docs
-
-    def clean_gold_for_evaluation(self, gold: Doc) -> Doc:
-        gold.ents = []
-        gold.spans.clear()
-        return gold
diff --git a/edsnlp/scorers/__init__.py b/edsnlp/scorers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/edsnlp/scorers/ner.py b/edsnlp/scorers/ner.py
new file mode 100644
index 000000000..7c8613a6b
--- /dev/null
+++ b/edsnlp/scorers/ner.py
@@ -0,0 +1,116 @@
+from typing import Any, Dict, Iterable
+
+from spacy.training import Example
+
+from edsnlp import registry
+from edsnlp.utils.span_getters import SpanGetter, get_spans
+
+
+def ner_exact_scorer(
+    examples: Iterable[Example], span_getter: SpanGetter
+) -> Dict[str, Any]:
+    """
+    Scores the extracted entities that may be overlapping or nested
+    by looking in the spans returned by a given SpanGetter object.
+
+    Parameters
+    ----------
+    examples: Iterable[Example]
+    span_getter: SpanGetter
+
+    Returns
+    -------
+    Dict[str, Any]
+    """
+    pred_spans = set()
+    gold_spans = set()
+    for eg_idx, eg in enumerate(examples):
+        for span in (
+            span_getter(eg.predicted)
+            if callable(span_getter)
+            else get_spans(eg.predicted, span_getter)
+        ):
+            pred_spans.add((eg_idx, span.start, span.end, span.label_))
+
+        for span in (
+            span_getter(eg.reference)
+            if callable(span_getter)
+            else get_spans(eg.reference, span_getter)
+        ):
+            gold_spans.add((eg_idx, span.start, span.end, span.label_))
+
+    tp = len(pred_spans & gold_spans)
+
+    return {
+        "ents_p": tp / len(pred_spans) if pred_spans else float(tp == len(pred_spans)),
+        "ents_r": tp / len(gold_spans) if gold_spans else float(tp == len(gold_spans)),
+        "ents_f": 2 * tp / (len(pred_spans) + len(gold_spans))
+        if pred_spans or gold_spans
+        else float(len(pred_spans) == len(gold_spans)),
+        "support": len(gold_spans),
+    }
+
+
+def ner_token_scorer(
+    examples: Iterable[Example], span_getter: SpanGetter
+) -> Dict[str, Any]:
+    """
+    Scores the extracted entities that may be overlapping or nested
+    by looking in `doc.ents`, and `doc.spans`, and comparing the predicted
+    and gold entities at the TOKEN level.
+
+    Parameters
+    ----------
+    examples: Iterable[Example]
+    span_getter: SpanGetter
+
+    Returns
+    -------
+    Dict[str, Any]
+    """
+    set(span_getter["labels"]) if "labels" in span_getter is not None else None
+    span_getter.get("spans_labels", None)
+
+    pred_spans = set()
+    gold_spans = set()
+    for eg_idx, eg in enumerate(examples):
+        for span in (
+            span_getter(eg.predicted)
+            if callable(span_getter)
+            else get_spans(eg.predicted, span_getter)
+        ):
+            for i in range(span.start, span.end):
+                pred_spans.add((eg_idx, i, span.label_))
+
+        for span in (
+            span_getter(eg.reference)
+            if callable(span_getter)
+            else get_spans(eg.reference, span_getter)
+        ):
+            for i in range(span.start, span.end):
+                gold_spans.add((eg_idx, i, span.label_))
+
+    tp = len(pred_spans & gold_spans)
+
+    return {
+        "ents_p": tp / len(pred_spans) if pred_spans else float(tp == len(pred_spans)),
+        "ents_r": tp / len(gold_spans) if gold_spans else float(tp == len(gold_spans)),
+        "ents_f": 2 * tp / (len(pred_spans) + len(gold_spans))
+        if pred_spans or gold_spans
+        else float(len(pred_spans) == len(gold_spans)),
+        "support": len(gold_spans),
+    }
+
+
+@registry.scorers.register("eds.ner_exact_scorer")
+def create_ner_exact_scorer(
+    span_getter: SpanGetter,
+):
+    return lambda examples: ner_exact_scorer(examples, span_getter)
+
+
+@registry.scorers.register("eds.ner_token_scorer")
+def create_ner_token_scorer(
+    span_getter: SpanGetter,
+):
+    return lambda examples: ner_token_scorer(examples, span_getter)
diff --git a/edsnlp/scorers/speed.py b/edsnlp/scorers/speed.py
new file mode 100644
index 000000000..ea50046f7
--- /dev/null
+++ b/edsnlp/scorers/speed.py
@@ -0,0 +1,23 @@
+from typing import Any, Dict, Iterable
+
+from spacy.training import Example
+
+from edsnlp import registry
+
+
+def speed_scorer(
+    examples: Iterable[Example], duration: float, cfg=None
+) -> Dict[str, Any]:
+    words_count = [len(eg.predicted) for eg in examples]
+    num_words = sum(words_count)
+    num_docs = len(words_count)
+
+    return {
+        "wps": num_words / duration,
+        "dps": num_docs / duration,
+    }
+
+
+@registry.scorers.register("speed")
+def create_speed_scorer():
+    return speed_scorer
diff --git a/tests/pipelines/trainable/test_ner.py b/tests/pipelines/trainable/test_ner.py
index f53a89112..ac396e4ad 100644
--- a/tests/pipelines/trainable/test_ner.py
+++ b/tests/pipelines/trainable/test_ner.py
@@ -22,6 +22,8 @@ def test_ner(ner_mode):
         config=dict(
             embedding=nlp.get_pipe("eds.transformer"),
             mode=ner_mode,
+            to_ents=True,
+            to_span_groups="ner-preds",
         ),
     )
     doc = nlp(
diff --git a/tests/training/config.cfg b/tests/training/config.cfg
index 36691a822..afcee585f 100644
--- a/tests/training/config.cfg
+++ b/tests/training/config.cfg
@@ -17,7 +17,7 @@ kernel_sizes = [3]
 
 [components.embedding.embedding]
 @factory = "eds.transformer"
-model = "prajjwal1/bert-tiny"
+model = "hf-internal-testing/tiny-bert"
 window = 128
 stride = 96
 
@@ -25,16 +25,26 @@ stride = 96
 @factory = "eds.ner"
 mode = "joint"
 embedding = ${components.embedding}
+target_span_getter = {"span_groups": ${vars.ml_span_groups}}
+to_ents = true
 
-[paths]
+[nlp.scorers.speed]
+@scorers = "eds.speed"
+
+[nlp.scorers.ner]
+@scorers = "eds.ner_exact_scorer"
+span_getter = ${components.ner.target_span_getter}
+
+[vars]
 train = "dataset"
 dev = "dataset"
+ml_span_groups = ["ENTITY", "OTHER"]
 
 [train]
 nlp = ${nlp}
-train_data = {"@misc": "brat_dataset", "path": ${paths.train}}
-val_data = {"@misc": "brat_dataset", "path": ${paths.dev}}
+train_data = {"@misc": "brat_dataset", "path": ${vars.train}}
+val_data = {"@misc": "brat_dataset", "path": ${vars.dev}}
 max_steps = 20
-validation_interval = 4
+validation_interval = 1
 batch_size = 4
 lr = 3e-3
diff --git a/tests/training/test_train.py b/tests/training/test_train.py
index 73de30a12..98dd1eeb1 100644
--- a/tests/training/test_train.py
+++ b/tests/training/test_train.py
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from itertools import chain, count, repeat
 from pathlib import Path
-from typing import Callable, Iterable, List, Optional, Union
+from typing import Callable, Iterable, List, Optional
 
 import torch
 from confit import Config
@@ -14,6 +14,7 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+import edsnlp
 from edsnlp.connectors.brat import BratConnector
 from edsnlp.core.pipeline import Pipeline
 from edsnlp.core.registry import registry
@@ -53,43 +54,8 @@ def sample_len(idx):
             yield from buffer
 
 
-@registry.misc.register("deft_span_getter")
-def make_span_getter():
-    def span_getter(doclike: Union[Doc, Span]) -> List[Span]:
-        """
-        Get the spans of a span group that are contained inside a doclike object.
-        Parameters
-        ----------
-        doclike : Union[Doc, Span]
-            Doclike object to act as a mask.
-        group : str
-            Group name from which to get the spans.
-        Returns
-        -------
-        List[Span]
-            List of spans.
-        """
-        if isinstance(doclike, Doc):
-            # return [
-            #     ent
-            #     for group in doclike.doc.spans
-            #     for ent in doclike.spans.get(group, ())
-            # ]
-            return doclike.ents
-        else:
-            # return [
-            #     span
-            #     for group in doclike.doc.spans
-            #     for span in doclike.doc.spans.get(group, ())
-            #     if span.start >= doclike.start and span.end <= doclike.end
-            # ]
-            return doclike.ents
-
-    return span_getter
-
-
 @registry.misc.register("brat_dataset")
-def brat_dataset(path, limit: Optional[int] = None, span_getter=make_span_getter()):
+def brat_dataset(path, limit: Optional[int] = None):
     def load(nlp):
         raw_data = BratConnector(path).load_brat()
         assert len(raw_data) > 0, "No data found in {}".format(path)
@@ -105,6 +71,8 @@ def load(nlp):
         sentencizer = nlp.get_pipe("sentencizer")
         docs = [sentencizer(doc) for doc in docs]
 
+        ner = nlp.get_pipe("ner")
+
         # Annotate entities from the raw data
         for doc, raw in zip(docs, raw_data):
             ents = []
@@ -124,20 +92,20 @@ def load(nlp):
         new_docs = []
         for doc in docs:
             for sent in doc.sents:
-                if len(span_getter(sent)):
-                    new_doc = sent.as_doc(copy_user_data=True)
-                    for group in doc.spans:
-                        new_doc.spans[group] = [
-                            Span(
-                                new_doc,
-                                span.start - sent.start,
-                                span.end - sent.start,
-                                span.label_,
-                            )
-                            for span in doc.spans.get(group, ())
-                            if span.start >= sent.start and span.end <= sent.end
-                        ]
-                    new_docs.append(new_doc)
+                new_doc = sent.as_doc(copy_user_data=True)
+                for group in doc.spans:
+                    new_doc.spans[group] = [
+                        Span(
+                            new_doc,
+                            span.start - sent.start,
+                            span.end - sent.start,
+                            span.label_,
+                        )
+                        for span in doc.spans.get(group, ())
+                        if span.start >= sent.start and span.end <= sent.end
+                    ]
+                if len(ner.get_target_spans(new_doc)):
+                    new_docs.append(new_doc)
         return new_docs
 
     return load
@@ -164,14 +132,9 @@ def train(
     val_docs = list(val_data(nlp))
 
     # Taking the first `initialization_subset` samples to initialize the model
-
    nlp.post_init(iter(train_docs))  # iter just to show it's possible
     nlp.batch_size = batch_size
 
-    # assert nlp.get_pipe('ner').embedding is nlp.get_pipe('embedding')
-
-    # print(nlp.config.to_str())
-
     # Preprocessing the training dataset into a dataloader
     preprocessed = list(nlp.preprocess_many(train_docs, supervision=True))
     dataloader = DataLoader(
@@ -214,9 +177,6 @@ def train(
     # We will loop over the dataloader
     iterator = iter(dataloader)
 
-    # for name, pipe in nlp.torch_components():
-    #     pipe.train(False)
-
     nlp.to(device)
 
     acc_loss = 0
@@ -266,7 +226,7 @@ def train(
 
     assert Path(output_path / "last-model").exists()
 
-    nlp.from_disk(output_path / "last-model")
+    nlp = edsnlp.load(output_path / "last-model")
     list(nlp.pipe(val_data(nlp)))
 
 
@@ -277,7 +237,6 @@ def test_train(run_in_test_dir, tmp_path):
     set_seed(42)
     config = Config.from_disk("config.cfg")
     shutil.rmtree(tmp_path, ignore_errors=True)
-    print(config.to_str())
     train(
         **config["train"].resolve(registry=registry, root=config),
         output_path=tmp_path,
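
As a usage sketch (not part of the commit itself), the snippet below shows how the redesigned scorers could be wired up programmatically rather than through a config file. It only relies on behaviour visible in this patch (the `Pipeline.scorers` mapping, `Pipeline.score(...)`, `create_ner_exact_scorer` and `speed_scorer`); the `score_pipeline` helper, the pipe name `"ner"`, the gold-document list and the `{"ents": True}` span getter are assumptions made for illustration.

# Hypothetical usage sketch for the scorer redesign above (not part of the commit).
from edsnlp.scorers.ner import create_ner_exact_scorer
from edsnlp.scorers.speed import speed_scorer


def score_pipeline(nlp, gold_docs):
    # Each entry is either a bare scorer callable, or a mapping whose optional
    # "pipes" key restricts which pipes are re-run (and timed) for that metric.
    nlp.scorers = {
        "speed": speed_scorer,  # exposes a `duration` parameter, so it is timed
        "ner": {
            "pipes": "ner",  # assumes the trainable NER pipe is named "ner"
            "scorer": create_ner_exact_scorer({"ents": True}),  # hypothetical getter
        },
    }
    # `Pipeline.score` deep-copies the gold docs, re-runs the selected pipes and
    # returns e.g. {"speed": {"wps": ..., "dps": ...}, "ner": {"ents_p": ..., ...}}.
    return nlp.score(gold_docs)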