From 54a41cdc3287290d328936fd43739b10b50cf7c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Fri, 19 Jul 2024 18:43:05 +0200 Subject: [PATCH] fix: support context_words=0 in span_context_getter --- changelog.md | 2 +- edsnlp/utils/span_getters.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/changelog.md b/changelog.md index a2de6a87c..bbe0f8e00 100644 --- a/changelog.md +++ b/changelog.md @@ -25,7 +25,7 @@ - Support mixed precision in `eds.text_cnn` and `eds.ner_crf` components - Support pre-quantization (<4.30) transformers versions - Verify that all batches are non empty -- Fix `span_context_getter` for `context_sents` > 2 and support assymetric contexts +- Fix `span_context_getter` for `context_words` = 0, `context_sents` > 2 and support asymmetric contexts ## v0.12.3 diff --git a/edsnlp/utils/span_getters.py b/edsnlp/utils/span_getters.py index aa2e54429..ce07acc61 100644 --- a/edsnlp/utils/span_getters.py +++ b/edsnlp/utils/span_getters.py @@ -307,11 +307,13 @@ def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]: n_words_left = self.n_words_left n_words_right = self.n_words_right - start = span.start - n_words_left - end = span.end + n_words_right + start = max(0, span.start - n_words_left) + end = min(len(span.doc), span.end + n_words_right) n_sents_max = max(n_sents_left, n_sents_right) if n_sents_max > 0: + min_start_sent = start + max_end_sent = end if n_sents_left == 1: sent = span.sent min_start_sent = sent.start @@ -325,10 +327,7 @@ def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]: max_end_sent = sents[ min(len(sents) - 1, sent_i + n_sents_right - 1) ].end - start = max(0, min(start, min_start_sent)) - end = min(len(span.doc), max(end, max_end_sent)) - else: - start = max(0, start) - end = min(len(span.doc), end) + start = min(start, min_start_sent) + end = max(end, max_end_sent) return span.doc[start:end]