Skip to content

Commit

Permalink
refacto: redesign pipeline scorers, add input and output spans params…
Browse files Browse the repository at this point in the history
… to trainable_ner #203
  • Loading branch information
percevalw committed Aug 25, 2023
1 parent d06dde6 commit 8bb6eb2
Show file tree
Hide file tree
Showing 15 changed files with 596 additions and 371 deletions.
8 changes: 1 addition & 7 deletions docs/pipelines/trainable/ner.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,7 @@ The pipeline component can be configured using the following parameters :

<div markdown="1" class="explicit-col-width">

::: edsnlp.pipelines.trainable.nested_ner.factory.create_component
options:
only_parameters: true

The default model `eds.nested_ner_model.v1` can be configured using the following parameters :

::: edsnlp.pipelines.trainable.nested_ner.stack_crf_ner.create_model
::: edsnlp.pipelines.trainable.ner.factory.create_component
options:
only_parameters: true

Expand Down
28 changes: 6 additions & 22 deletions edsnlp/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def wrapped(self: "TorchComponent", doc: Doc):
if self.nlp._cache is None:
return fn(self, doc)
cache_id = hash((id(self), "preprocess", id(doc)))
if cache_id in self.nlp._cache:
if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
return self.nlp._cache[cache_id]
res = fn(self, doc)
self.nlp._cache[cache_id] = res
Expand All @@ -58,7 +58,9 @@ def wrapped(self: "TorchComponent", doc: Doc):
if self.nlp._cache is None:
return fn(self, doc)
cache_id = hash((id(self), "preprocess_supervised", id(doc)))
if cache_id in self.nlp._cache.setdefault(self, {}):
if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache.setdefault(
self, {}
):
return self.nlp._cache[cache_id]
res = fn(self, doc)
self.nlp._cache[cache_id] = res
Expand All @@ -75,7 +77,7 @@ def wrapped(self: "TorchComponent", batch: Dict, device: torch.device):
cache_id = hash((id(self), "collate", hash_batch(batch)))
if self.nlp._cache is None or cache_id is None:
return fn(self, batch, device)
if cache_id in self.nlp._cache:
if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
return self.nlp._cache[cache_id]
res = fn(self, batch, device)
self.nlp._cache[cache_id] = res
Expand All @@ -92,7 +94,7 @@ def wrapped(self: "TorchComponent", batch):
cache_id = hash((id(self), "collate", hash_batch(batch)))
if self.nlp._cache is None or cache_id is None:
return fn(self, batch)
if cache_id in self.nlp._cache:
if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
return self.nlp._cache[cache_id]
res = fn(self, batch)
self.nlp._cache[cache_id] = res
Expand Down Expand Up @@ -298,24 +300,6 @@ def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
"""
return self.preprocess(doc)

def clean_gold_for_evaluation(self, gold: Doc) -> Doc:
"""
Clean the gold document before evaluation.
Only the attributes that are predicted by the component should be removed.
By default, this is a no-op.
Parameters
----------
gold: Doc
Gold document
Returns
-------
Doc
The document without attributes that should be predicted
"""
return gold

def pipe(self, docs: Iterable[Doc], batch_size=1) -> Iterable[Doc]:
"""
Applies the component on a collection of documents. It is recommended to use
Expand Down
134 changes: 94 additions & 40 deletions edsnlp/core/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: E501
import copy
import functools
import inspect
import os
import shutil
import time
Expand All @@ -25,16 +26,17 @@

import spacy
import srsly
from confit import Config
from confit.utils.collections import join_path, split_path
from confit.utils.xjson import Reference
from spacy.language import BaseDefaults
from spacy.tokenizer import Tokenizer
from spacy.tokens import Doc
from spacy.util import get_lang_class
from spacy.vocab import Vocab, create_vocab
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict

from confit import Config
from confit.utils.collections import join_path, split_path
from confit.utils.xjson import Reference
from edsnlp.core.registry import PIPE_META, CurriedFactory, FactoryMeta
from edsnlp.utils.collections import (
FrozenDict,
Expand All @@ -56,6 +58,16 @@ class CacheEnum(str, Enum):

Pipe = TypeVar("Pipe", bound=Callable[[Doc], Doc])

ScorerType = Union[
Callable[[Iterable[Doc]], Dict[str, Any]],
Callable[[Iterable[Doc], float], Dict[str, Any]],
]


class ScoringConfig(TypedDict):
pipes: NotRequired[Union[str, List[str]]]
scorer: ScorerType


class Pipeline:
"""
Expand All @@ -74,6 +86,7 @@ def __init__(
batch_size: Optional[int] = 4,
vocab_config: Type[BaseDefaults] = BaseDefaults,
meta: Dict[str, Any] = None,
scorers: Dict[str, Union[ScoringConfig, ScorerType]] = None,
):
"""
Parameters
Expand Down Expand Up @@ -118,7 +131,9 @@ def __init__(
self._path: Optional[Path] = None
self.meta = dict(meta) if meta is not None else {}
self.lang: str = lang
self.scorers = scorers or {}
self._cache: Optional[Dict] = None
self._cache_is_writeonly = False

@property
def pipeline(self) -> List[Tuple[str, Pipe]]:
Expand Down Expand Up @@ -406,9 +421,12 @@ def cache(self):
"""
Enable caching for all (trainable) components in the pipeline
"""
self._cache = {}
was_not_cached = not self._cache
if was_not_cached:
self._cache = {}
yield
self._cache = None
if was_not_cached:
self._cache = None

def torch_components(
self, disable: Sequence[str] = ()
Expand Down Expand Up @@ -487,18 +505,19 @@ def from_config(
config["nlp"]["components"] = Reference("components")
config = config["nlp"]

config = Config(config).resolve(root=root_config)
config = dict(Config(config).resolve(root=root_config))
components = config.pop("components", {})
pipeline = config.pop("pipeline", ())
tokenizer = config.pop("tokenizer", None)
disable = (config.pop("disabled", ()), disable)

nlp = Pipeline(
vocab=vocab,
create_tokenizer=config.get("tokenizer"),
lang=config["lang"],
create_tokenizer=tokenizer,
meta=meta,
**config,
)

components = config.get("components", {})
pipeline = config.get("pipeline", ())

# Since components are actually resolved as curried factories,
# we need to instantiate them here
for name, component in components.items():
Expand Down Expand Up @@ -716,46 +735,81 @@ def score(self, docs: Sequence[Doc], batch_size: int = None) -> Dict[str, Any]:
import torch
from spacy.training import Example

inputs: Sequence[Doc] = copy.deepcopy(docs)
golds: Sequence[Doc] = docs

scored_components = {}

# Predicting intermediate steps
preds = defaultdict(lambda: [])
if batch_size is None:
batch_size = self.batch_size
total_duration = 0

scorers_by_pipes = defaultdict(lambda: {})
for scorer_name, scorer in self.scorers.items():
if isinstance(scorer, dict) and "scorer" in scorer:
pipe_names = scorer.get("pipes", self.pipe_names)
actual_scorer = scorer["scorer"]
if isinstance(pipe_names, str):
pipe_names = [pipe_names]
if pipe_names is None:
pipe_names = self.pipe_names
else:
pipe_names = self.pipe_names
actual_scorer = scorer
scorers_by_pipes[tuple(pipe_names)][scorer_name] = actual_scorer

speed_metric_names = {
name
for _, scorers_group in scorers_by_pipes.items()
for name, scorer in scorers_group.items()
if "duration" in inspect.signature(scorer).parameters
}
pipes_to_duration = {
pipe_names: 0.0
for pipe_names in scorers_by_pipes.keys()
if speed_metric_names & set(scorers_by_pipes[pipe_names])
}

with self.train(False), torch.no_grad(): # type: ignore
for batch in batchify(
tqdm(inputs, "Scoring components"), batch_size=batch_size
for gold_batch in batchify(
tqdm(docs, "Scoring components"), batch_size=batch_size
):
with self.cache():
for name, pipe in self.pipeline[::-1]:
if hasattr(pipe, "clean_gold_for_evaluation"):
batch = [
pipe.clean_gold_for_evaluation(doc) for doc in batch
]
for pipe_names in scorers_by_pipes.keys():
timed = speed_metric_names & set(scorers_by_pipes[pipe_names])

if timed:
self._cache_is_writeonly = True

batch = copy.deepcopy(gold_batch)

t0 = time.time()
if hasattr(pipe, "batch_process"):
batch = pipe.batch_process(batch)
else:
batch = [pipe(doc) for doc in batch]
total_duration += time.time() - t0

if getattr(pipe, "score", None) is not None:
scored_components[name] = pipe
preds[name].extend(copy.deepcopy(batch))
for pipe_name in pipe_names:
pipe = self.get_pipe(pipe_name)
if hasattr(pipe, "batch_process"):
batch = pipe.batch_process(batch)
else:
batch = [pipe(doc) for doc in batch]

metrics: Dict[str, Any] = {
"speed": len(inputs) / total_duration,
}
for name, pipe in scored_components.items():
metrics[name] = pipe.score(
[Example(p, g) for p, g in zip(preds[name], golds)]
)
t1 = time.time()

if timed:
pipes_to_duration[pipe_names] += t1 - t0
self._cache_is_writeonly = False

preds[pipe_names].extend(batch)

results: Dict[str, Any] = {}
for pipe_names, preds in preds.items():
for scorer_name, scorer in scorers_by_pipes[pipe_names].items():
if scorer_name in speed_metric_names:
results[scorer_name] = scorer(
[Example(p, g) for p, g in zip(preds, docs)],
duration=pipes_to_duration[pipe_names],
)
else:
results[scorer_name] = scorer(
[Example(p, g) for p, g in zip(preds, docs)],
)

return metrics
return results

def to_disk(
self, path: Union[str, Path], *, exclude: Sequence[str] = FrozenList()
Expand Down
11 changes: 0 additions & 11 deletions edsnlp/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class CurriedFactory:
def __init__(self, func, kwargs):
self.kwargs = kwargs
self.factory = func
# self.factory_name = factory_name
self.instantiated = None

def instantiate(
Expand Down Expand Up @@ -85,16 +84,6 @@ def instantiate(
**kwargs,
}
)
# Config._store_resolved(
# obj.instantiated,
# Config(
# {
# "@factory": obj.factory_name,
# **kwargs,
# }
# ),
# )
# PIPE_META[obj.instantiated] = obj.meta
return obj.instantiated
elif isinstance(obj, dict):
return {
Expand Down
2 changes: 1 addition & 1 deletion edsnlp/pipelines/trainable/ner/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .factory import create_component, create_ner_exact_scorer
from .factory import create_component
Loading

0 comments on commit 8bb6eb2

Please sign in to comment.