From 6bfa4ccbdb9f3120aadb46d13e640ea6a27f65a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Thu, 22 Feb 2024 19:19:53 +0100 Subject: [PATCH] fix: better warning for overlapping spans in eds.ner_crf --- edsnlp/pipes/trainable/ner_crf/ner_crf.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/edsnlp/pipes/trainable/ner_crf/ner_crf.py b/edsnlp/pipes/trainable/ner_crf/ner_crf.py index 50e967093..f8653f66a 100644 --- a/edsnlp/pipes/trainable/ner_crf/ner_crf.py +++ b/edsnlp/pipes/trainable/ner_crf/ner_crf.py @@ -391,7 +391,7 @@ def preprocess_supervised(self, doc): if discarded: warnings.warn( - f"Some spans in {doc._.note_id} were discarded (" + f"Some spans in were discarded {doc._.note_id} (" f"{', '.join(repr(d.text) for d in discarded)}) because they " f"were overlapping with other spans with the same label." ) @@ -481,16 +481,6 @@ def forward(self, batch: NERBatchInput) -> NERBatchOutput: # tags = scores.argmax(-1).masked_fill(~mask.unsqueeze(-1), 0) if loss is not None and loss.item() > 100000: warnings.warn("The loss is very high, this is likely a tag encoding issue.") - losses = self.crf( - scores, - mask, - batch["targets"].unsqueeze(-1) == torch.arange(5).to(scores.device), - ).view(-1) - print("LOSSES", losses.tolist()) - print( - batch["targets"].transpose(1, 2).reshape(-1, num_words)[losses.argmax()] - ) - print(batch["targets"]) return { "loss": loss, "tags": tags,