Skip to content

Commit

Permalink
feat: refactor make_sentence_span_getter -> make_span_context_getter
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Feb 23, 2024
1 parent 2f1c53e commit b68096e
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 17 deletions.
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
annotate every time.
- Default `eds.ner_crf` window is now set to 40 and stride set to 20, as it doesn't
affect throughput (compared to before, window set to 20) and improves accuracy.
- New default `overlap_policy='merge'` option and parameter renaming in
`eds.span_context_getter` (which replaces `eds.span_sentence_getter`)

### Fixed

Expand Down
122 changes: 106 additions & 16 deletions edsnlp/utils/span_getters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Sequence, Union

from rich.text import Span
from spacy.tokens import Doc
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
)

from pydantic import NonNegativeInt
from spacy.tokens import Doc, Span
from typing_extensions import Literal

from edsnlp import registry
from edsnlp.utils.filter import filter_spans
Expand Down Expand Up @@ -239,21 +250,100 @@ def validate(cls, value, config=None) -> SpanSetter:
]


@registry.misc.register("eds.span_sentence_getter")
class make_span_sentence_getter:
def merge_spans(spans: Iterable[Span]) -> List[Span]:
    """
    Merge overlapping or touching spans into single spans.

    Spans are visited in increasing `(start, end)` order. A span that starts
    at or before the end of the last kept span is fused with it: the kept span
    is extended to the furthest end, and keeps the label of the earlier span.
    Spans fully contained in the last kept span are dropped.

    Parameters
    ----------
    spans : Iterable[Span]
        Spans to merge. Offsets are compared directly, so the spans are
        assumed to come from the same document.

    Returns
    -------
    List[Span]
        Merged spans, sorted by start offset.
    """
    merged: List[Span] = []
    for span in sorted(spans, key=lambda s: (s.start, s.end)):
        if merged and span.start <= merged[-1].end:
            last = merged[-1]
            # Only rebuild the span when it actually grows; a span nested
            # inside the previous one adds nothing.
            if span.end > last.end:
                merged[-1] = Span(span.doc, last.start, span.end, last.label_)
        else:
            merged.append(span)
    return merged


@registry.misc.register("eds.span_context_getter")
class make_span_context_getter:
    """
    Create a span context getter.

    Calling the getter on a document returns, for each span selected by
    `span_getter`, a wider span made of the surrounding words and/or
    sentences. Overlapping contexts are then either filtered or merged.

    Parameters
    ----------
    span_getter : SpanGetterArg
        Span getter, i.e. for which spans to get the context.
    min_context_words : int
        Stored on the instance but not read when computing the contexts.
        NOTE(review): looks like the previous name of `context_words` —
        confirm whether it can be dropped.
    context_words : NonNegativeInt
        Minimum number of words to include on each side of the span.
    context_sents : Optional[NonNegativeInt]
        Minimum number of sentences to include on each side of the span:

        - 0: don't use sentences to build the context.
        - 1: include the sentence of the span.
        - n: include n sentences on each side of the span.

        By default, 0 if the document has no sentence annotations, 1 otherwise.
    overlap_policy : Literal["filter", "merge"]
        How to handle overlapping spans:

        - "filter": remove overlapping spans.
        - "merge": merge overlapping spans
    """

    def __init__(
        self,
        span_getter: SpanGetterArg,
        min_context_words: int = 0,
        context_words: NonNegativeInt = 0,
        context_sents: Optional[NonNegativeInt] = None,
        overlap_policy: Literal["filter", "merge"] = "merge",
    ):
        self.min_context_words = min_context_words
        self.context_words = context_words
        self.context_sents = context_sents
        self.overlap_policy = overlap_policy
        self.span_getter = span_getter

    def __call__(self, doc: Doc) -> List[Span]:
        """
        Compute the context spans of `doc`.

        Parameters
        ----------
        doc : Doc
            Document to extract the spans and their contexts from.

        Returns
        -------
        List[Span]
            Context spans, after applying `overlap_policy`.
        """
        n_sents = self.context_sents
        if n_sents is None:
            n_sents = 1 if doc.has_annotation("SENT_START") else 0
        n_words = self.context_words

        # The sentence list is only needed to look up neighbouring sentences.
        sents = list(doc.sents) if n_sents > 1 else []

        contexts = []
        for e in get_spans(doc, self.span_getter):
            # Word window first; sentences can only widen it.
            start = e.start - n_words
            end = e.end + n_words

            if n_sents == 1:
                sent = e.sent
                start = min(start, sent.start)
                end = max(end, sent.end)
            elif n_sents > 1:
                sent_i = sents.index(e.sent)
                first_sent = sents[max(0, sent_i - n_sents)]
                last_sent = sents[min(len(sents) - 1, sent_i + n_sents)]
                start = min(start, first_sent.start)
                end = max(end, last_sent.end)
            # n_sents == 0: sentences are ignored entirely, so documents
            # without sentence boundaries are never asked for `e.sent`.

            # Clamp to the document bounds.
            contexts.append(doc[max(0, start) : min(len(doc), end)])

        if self.overlap_policy == "filter":
            return filter_spans(contexts)
        return merge_spans(contexts)
2 changes: 1 addition & 1 deletion tests/training/qlf_config.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ kernel_sizes = [3]
model = "hf-internal-testing/tiny-bert"
window = 128
stride = 96
span_getter = { "@misc": "eds.span_sentence_getter", "span_getter": ${["ents", *vars.ml_span_groups]} }
span_getter = { "@misc": "eds.span_context_getter", "span_getter": ${["ents", *vars.ml_span_groups]} }

[components.qualifier]
@factory = "eds.span_qualifier"
Expand Down
64 changes: 64 additions & 0 deletions tests/utils/test_span_getters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import edsnlp
from edsnlp.utils.span_getters import make_span_context_getter


def test_span_sentence_getter(lang):
    """Check context spans for word/sentence windows under both overlap policies."""
    nlp = edsnlp.blank("eds")
    nlp.add_pipe("eds.normalizer")
    nlp.add_pipe("eds.sentences")
    nlp.add_pipe("eds.matcher", config={"terms": {"sentence": "sentence"}})
    doc = nlp(
        "This is a sentence. "
        "This is another sentence. "
        "This is a third one. "
        "Last sentence."
    )

    def context_texts(**options):
        # Build a getter over the matched entities and return the context texts.
        getter = make_span_context_getter(span_getter=["ents"], **options)
        return [context.text for context in getter(doc)]

    # Two-word window, overlapping contexts fused together.
    merged_words = context_texts(context_words=2, overlap_policy="merge")
    assert merged_words == [
        "This is a sentence. This is another sentence. This",
        ". Last sentence.",
    ]

    # Two-word window, overlapping contexts filtered out.
    filtered_words = context_texts(context_words=2, overlap_policy="filter")
    assert filtered_words == [
        "This is a sentence. This",
        ". Last sentence.",
    ]

    # Sentence of each entity only.
    filtered_sents = context_texts(
        context_words=0,
        context_sents=1,
        overlap_policy="filter",
    )
    assert filtered_sents == [
        "This is a sentence.",
        "This is another sentence.",
        "Last sentence.",
    ]

    # Two sentences on each side: everything collapses into one context.
    merged_sents = context_texts(
        context_words=0,
        context_sents=2,
        overlap_policy="merge",
    )
    assert merged_sents == [
        (
            "This is a sentence. This is another sentence. "
            "This is a third one. Last sentence."
        )
    ]

0 comments on commit b68096e

Please sign in to comment.