Skip to content

Commit

Permalink
refactor: redesign pipeline scorers, add input and output spans params…
Browse files Browse the repository at this point in the history
… to trainable_ner #203
  • Loading branch information
percevalw committed Dec 4, 2023
1 parent 7743dfc commit 1dd425f
Showing 7 changed files with 110 additions and 34 deletions.
9 changes: 9 additions & 0 deletions edsnlp/scorers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Any, Callable, Dict, Iterable, Union

from spacy.tokens import Doc
from spacy.training import Example

# Type alias for pipeline scoring callbacks. A scorer is either:
#   - a two-argument callable taking the gold docs and the predicted docs, or
#   - a one-argument callable taking pre-built spacy ``Example`` objects,
# and returns a mapping of metric name -> value.
Scorer = Union[
    Callable[[Iterable[Doc], Iterable[Doc]], Dict[str, Any]],
    Callable[[Iterable[Example]], Dict[str, Any]],
]
18 changes: 16 additions & 2 deletions edsnlp/scorers/ner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict, Iterable

import spacy.training
from spacy.training import Example

from edsnlp import registry
@@ -103,11 +104,24 @@ def ner_token_scorer(
def create_ner_exact_scorer(
    span_getter: SpanGetterArg,
):
    """
    Build an exact-match NER scorer bound to ``span_getter``.

    The returned callable accepts either a single list of spacy ``Example``
    objects or two parallel lists of (gold, predicted) docs, normalized via
    ``make_examples``.

    Note: the previous version accepted ``**kwargs`` in the lambda but
    forwarded them to ``make_examples``, which only takes positional
    arguments — any keyword argument raised TypeError. Dropping ``**kwargs``
    also keeps the signature consistent with ``create_ner_token_scorer``.
    """
    return lambda *args: ner_exact_scorer(make_examples(*args), span_getter)


@registry.scorers.register("eds.ner_token_scorer")
def create_ner_token_scorer(
    span_getter: SpanGetterArg,
):
    """
    Build a token-level NER scorer bound to ``span_getter``.

    The returned callable accepts either a single list of spacy ``Example``
    objects or two parallel lists of (gold, predicted) docs, normalized via
    ``make_examples``.
    """

    def scorer(*args):
        return ner_token_scorer(make_examples(*args), span_getter)

    return scorer


def make_examples(*args):
    """
    Normalize scorer inputs into a list of spacy ``Example`` objects.

    Accepts either:
      - one argument: an existing collection of ``Example`` objects,
        returned unchanged, or
      - two arguments: parallel iterables of (gold, predicted) docs,
        zipped into ``Example`` objects.

    Raises
    ------
    ValueError
        If called with zero or more than two positional arguments.
    """
    if len(args) == 2:
        gold_docs, pred_docs = args
        return [
            spacy.training.Example(reference=g, predicted=p)
            for g, p in zip(gold_docs, pred_docs)
        ]
    # BUG FIX: the previous version only handled len(args) == 2 and raised
    # for the documented single-argument ("list of examples") case; its inner
    # `if len(args) == 2 ... else args[0]` branch was unreachable.
    if len(args) == 1:
        return args[0]
    raise ValueError("Expected either a list of examples or two lists of spans")
23 changes: 0 additions & 23 deletions edsnlp/scorers/speed.py

This file was deleted.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -202,6 +202,10 @@ where = ["."]
"tables" = "edsnlp.pipelines.misc.tables.factory:create_component"
"terminology" = "edsnlp.pipelines.core.terminology.factory:create_component"

[project.entry-points."spacy_scorers"]
"eds.ner_exact_scorer" = "edsnlp.scorers.ner:create_ner_exact_scorer"
"eds.ner_token_scorer" = "edsnlp.scorers.ner:create_ner_token_scorer"

[project.entry-points."edsnlp_accelerator"]
"simple" = "edsnlp.accelerators.simple:SimpleAccelerator"
"multiprocessing" = "edsnlp.accelerators.multiprocessing:MultiprocessingAccelerator"
47 changes: 47 additions & 0 deletions tests/test_scorers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
from spacy.tokens import Span

import edsnlp
from edsnlp.scorers.ner import create_ner_exact_scorer, create_ner_token_scorer


@pytest.fixture(scope="session")
def gold_and_pred():
    """Build two small (gold, predicted) doc pairs for NER scoring tests."""
    nlp = edsnlp.blank("eds")

    def doc_with_ents(text, spans):
        # Create a doc and attach entity spans given as (start, end, label).
        doc = nlp.make_doc(text)
        doc.ents = [Span(doc, start, end, label=label) for start, end, label in spans]
        return doc

    gold_docs = [
        doc_with_ents("Le patient a le covid 19.", [(4, 6, "covid")]),
        doc_with_ents(
            "Corona: positif. Le cvid est une maladie.",
            [(0, 1, "covid"), (5, 6, "covid")],
        ),
    ]
    pred_docs = [
        doc_with_ents("Le patient a le covid 19.", [(4, 6, "covid")]),
        doc_with_ents("Corona: positif. Le cvid est une maladie.", [(0, 2, "covid")]),
    ]
    return gold_docs, pred_docs


def test_exact_ner_scorer(gold_and_pred):
    """Exact scorer: 1 of 2 predictions matches 1 of 3 gold spans exactly."""
    gold, pred = gold_and_pred
    scores = create_ner_exact_scorer("ents")(gold, pred)
    expected = {
        "ents_p": 0.5,
        "ents_r": 1 / 3,
        "ents_f": 0.4,
        "support": 3,
    }
    assert scores == expected


def test_token_ner_scorer(gold_and_pred):
    """Token scorer: partial overlaps are credited at the token level."""
    gold, pred = gold_and_pred
    scores = create_ner_token_scorer("ents")(gold, pred)
    expected = {
        "ents_f": 0.75,
        "ents_p": 0.75,
        "ents_r": 0.75,
        "support": 4,
    }
    assert scores == expected
8 changes: 3 additions & 5 deletions tests/training/config.cfg
Original file line number Diff line number Diff line change
@@ -28,11 +28,8 @@ embedding = ${components.embedding}
target_span_getter = ${vars.ml_span_groups}
infer_span_setter = true

[nlp.scorers.speed]
@scorers = "eds.speed"

[nlp.scorers.ner]
@scorers = "eds.ner_exact_scorer"
[scorer.ner]
@scorers = "eds.ner_exact_scorer"
span_getter = ${components.ner.target_span_getter}

[vars]
@@ -48,3 +45,4 @@ max_steps = 20
validation_interval = 1
batch_size = 4
lr = 3e-3
scorer = ${scorer}
35 changes: 31 additions & 4 deletions tests/training/test_train.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import math
import random
import shutil
import time
from collections import defaultdict
from itertools import chain, count, repeat
from pathlib import Path
from typing import Callable, Iterable, List, Optional

import torch
from confit import Config
from confit.registry import validate_arguments
from confit.utils.random import set_seed
from spacy.tokens import Doc, Span
from torch.utils.data import DataLoader
from tqdm import tqdm

import edsnlp
@@ -20,6 +19,7 @@
from edsnlp.core.registry import registry
from edsnlp.optimization import LinearSchedule, ScheduledOptimizer
from edsnlp.pipelines.trainable.ner.ner import TrainableNER
from edsnlp.scorers import Scorer
from edsnlp.utils.collections import batchify
from edsnlp.utils.filter import filter_spans

@@ -112,6 +112,30 @@ def load(nlp):
return load


@validate_arguments
class TestScorer:
    """Aggregate several named scorers and report pipeline speed statistics."""

    def __init__(self, **scorers: Scorer):
        # Mapping of score-section name -> scoring callable.
        self.scorers = scorers

    def __call__(self, nlp, docs):
        """Run ``nlp`` on annotation-stripped copies of ``docs`` and score."""
        # Strip gold annotations so the pipeline predicts from scratch.
        stripped = []
        for doc in docs:
            clean = doc.copy()
            clean.ents = []
            clean.spans.clear()
            stripped.append(clean)

        start = time.time()
        preds = list(nlp.pipe(stripped))
        elapsed = time.time() - start

        results = {name: score(docs, preds) for name, score in self.scorers.items()}
        # Throughput: words per second and docs per second over the whole run.
        results["speed"] = dict(
            wps=sum(len(doc) for doc in docs) / elapsed,
            dps=len(docs) / elapsed,
        )
        return results


@validate_arguments
def train(
output_path: Path,
@@ -124,7 +148,10 @@ def train(
lr: float = 8e-5,
validation_interval: int = 10,
device: str = "cpu",
scorer: TestScorer = TestScorer(),
):
import torch

device = torch.device(device)
set_seed(seed)

@@ -138,7 +165,7 @@ def train(

# Preprocessing the training dataset into a dataloader
preprocessed = list(nlp.preprocess_many(train_docs, supervision=True))
dataloader = DataLoader(
dataloader = torch.utils.data.DataLoader(
preprocessed,
batch_sampler=LengthSortedBatchSampler(preprocessed, batch_size),
collate_fn=nlp.collate,
@@ -189,7 +216,7 @@ def train(
print(acc_loss / max(acc_steps, 1))
acc_loss = 0
acc_steps = 0
last_scores = nlp.score(val_docs)
last_scores = scorer(nlp, val_docs)
print(last_scores, "lr", optimizer.param_groups[0]["lr"])
if step == max_steps:
break

0 comments on commit 1dd425f

Please sign in to comment.