From 7fd284a79a32e9e891c6603045514fc816106d5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Tue, 19 Nov 2024 17:42:12 +0100 Subject: [PATCH 1/2] feat: added conll format --- changelog.md | 1 + docs/assets/stylesheets/extra.css | 1 + docs/data/conll.md | 38 ++++ edsnlp/data/__init__.py | 1 + edsnlp/data/conll.py | 256 +++++++++++++++++++++++++ edsnlp/data/converters.py | 50 +++++ edsnlp/utils/file_system.py | 2 + mkdocs.yml | 1 + pyproject.toml | 2 + tests/data/test_conll.py | 38 ++++ tests/training/rhapsodie_sample.conllu | 67 +++++++ 11 files changed, 457 insertions(+) create mode 100644 docs/data/conll.md create mode 100644 edsnlp/data/conll.py create mode 100644 tests/data/test_conll.py create mode 100644 tests/training/rhapsodie_sample.conllu diff --git a/changelog.md b/changelog.md index 4e2f7f5f5..2d1998193 100644 --- a/changelog.md +++ b/changelog.md @@ -9,6 +9,7 @@ - Log the training config at the beginning of the trainings - Support a specific model output dir path for trainings (`output_model_dir`), and whether to save the model or not (`save_model`) - Specify whether to log the validation results or not (`logger=False`) +- Added support for the CoNLL format with `edsnlp.data.read_conll` and with a specific `eds.conll_dict2doc` converter ### Fixed diff --git a/docs/assets/stylesheets/extra.css b/docs/assets/stylesheets/extra.css index 41fbec3ed..ef6c443e4 100644 --- a/docs/assets/stylesheets/extra.css +++ b/docs/assets/stylesheets/extra.css @@ -188,4 +188,5 @@ a.discrete-link { .sourced-heading > a { font-size: 1rem; + align-content: center; } diff --git a/docs/data/conll.md b/docs/data/conll.md new file mode 100644 index 000000000..063943c26 --- /dev/null +++ b/docs/data/conll.md @@ -0,0 +1,38 @@ +# CoNLL + +??? abstract "TLDR" + + ```{ .python .no-check } + import edsnlp + + stream = edsnlp.data.read_conll(path) + stream = stream.map_pipeline(nlp) + ``` + +You can easily integrate CoNLL formatted files into your project by using EDS-NLP's CoNLL reader. + +There are many CoNLL formats corresponding to different shared tasks, but one of the most common is the CoNLL-U format, which is used for dependency parsing. In CoNLL files, each line corresponds to a token and contains various columns with information about the token, such as its index, form, lemma, POS tag, and dependency relation. + +EDS-NLP lets you specify the name of the `columns` if they are different from the default CoNLL-U format. If the `columns` parameter is unset, the reader looks for a comment containing `# global.columns` to infer the column names. Otherwise, the columns are + +``` +ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC +``` + +A typical CoNLL file looks like this: + +```{ title="sample.conllu" } +1 euh euh INTJ _ _ 5 discourse _ SpaceAfter=No +2 , , PUNCT _ _ 1 punct _ _ +3 il lui PRON _ Gender=Masc|Number=Sing|Person=3|PronType=Prs 5 expl:subj _ _ +... +``` + +## Reading CoNLL files {: #edsnlp.data.conll.read_conll } + +::: edsnlp.data.conll.read_conll + options: + heading_level: 3 + show_source: false + show_toc: false + show_bases: false diff --git a/edsnlp/data/__init__.py b/edsnlp/data/__init__.py index 44eb620fd..c6cbc0593 100644 --- a/edsnlp/data/__init__.py +++ b/edsnlp/data/__init__.py @@ -7,6 +7,7 @@ from .base import from_iterable, to_iterable from .standoff import read_standoff, write_standoff from .brat import read_brat, write_brat + from .conll import read_conll from .json import read_json, write_json from .parquet import read_parquet, write_parquet from .spark import from_spark, to_spark diff --git a/edsnlp/data/conll.py b/edsnlp/data/conll.py new file mode 100644 index 000000000..3cf5b8d90 --- /dev/null +++ b/edsnlp/data/conll.py @@ -0,0 +1,256 @@ +import os +import random +import warnings +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + +from fsspec import filesystem as fsspec +from loguru import logger +from typing_extensions import Literal + +from edsnlp import registry +from edsnlp.core.stream import Stream +from edsnlp.data.base import FileBasedReader +from edsnlp.data.converters import FILENAME, get_dict2doc_converter +from edsnlp.utils.collections import shuffle +from edsnlp.utils.file_system import FileSystem, normalize_fs_path, walk_match +from edsnlp.utils.stream_sentinels import DatasetEndSentinel +from edsnlp.utils.typing import AsList + +LOCAL_FS = fsspec("file") + +DEFAULT_COLUMNS = [ + "ID", + "FORM", + "LEMMA", + "UPOS", + "XPOS", + "FEATS", + "HEAD", + "DEPREL", + "DEPS", + "MISC", +] + + +def parse_conll( + path: str, + cols: Optional[List[str]] = None, + fs: FileSystem = LOCAL_FS, +) -> Iterable[Dict]: + """ + Load a .conll file and return a dictionary with the text, words, and entities. + This expects the file to contain multiple sentences, split into words, each one + described in a line. Each sentence is separated by an empty line. + + If possible, looks for a `#global.columns` comment at the start of the file to + extract the column names. + + Examples: + + ```text + ... + 11 jeune jeune ADJ _ Number=Sing 12 amod _ _ + 12 fille fille NOUN _ Gender=Fem|Number=Sing 5 obj _ _ + 13 qui qui PRON _ PronType=Rel 14 nsubj _ _ + ... + ``` + + Parameters + ---------- + path: str + Path or glob path of the brat text file (.txt, not .ann) + cols: Optional[List[str]] + List of column names to use. If None, the first line of the file will be used + fs: FileSystem + Filesystem to use + + Returns + ------- + Iterator[Dict] + """ + with fs.open(path, "r", encoding="utf-8") as f: + lines = f.readlines() + + if cols is None: + try: + cols = next( + line.split("=")[1].strip().split() + for line in lines + if line.strip("# ").startswith("global.columns") + ) + except StopIteration: + cols = DEFAULT_COLUMNS + warnings.warn( + f"No #global.columns comment found in the CoNLL file. " + f"Using default {cols}" + ) + + doc = {"words": []} + for line in lines: + line = line.strip() + if not line: + if doc["words"]: + yield doc + doc = {"words": []} + continue + if line.startswith("#"): + continue + parts = line.split("\t") + word = {k: v for k, v in zip(cols, parts) if v != "_"} + doc["words"].append(word) + + if doc["words"]: + yield doc + + +class ConllReader(FileBasedReader): + DATA_FIELDS = () + + def __init__( + self, + path: Union[str, Path], + *, + columns: Optional[List[str]] = None, + filesystem: Optional[FileSystem] = None, + loop: bool = False, + shuffle: Literal["dataset", False] = False, + seed: Optional[int] = None, + ): + super().__init__() + self.shuffle = shuffle + self.emitted_sentinels = {"dataset"} + self.rng = random.Random(seed) + self.loop = loop + self.fs, self.path = normalize_fs_path(filesystem, path) + self.columns = columns + + files = walk_match(self.fs, self.path, ".*[.]conllu?") + self.files = sorted(files) + assert len(self.files), f"No .conll files found in the directory {self.path}" + logger.info(f"The directory contains {len(self.files)} .conll files.") + + def read_records(self) -> Iterable[Any]: + while True: + files = self.files + if self.shuffle: + files = shuffle(files, self.rng) + for item in files: + for anns in parse_conll(item, cols=self.columns, fs=self.fs): + anns[FILENAME] = os.path.relpath(item, self.path).rsplit(".", 1)[0] + anns["doc_id"] = anns[FILENAME] + yield anns + yield DatasetEndSentinel() + if not self.loop: + break + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"path={self.path!r}, " + f"shuffle={self.shuffle}, " + f"loop={self.loop})" + ) + + +# No writer for CoNLL format yet + + +@registry.readers.register("conll") +def read_conll( + path: Union[str, Path], + *, + columns: Optional[List[str]] = None, + converter: Optional[AsList[Union[str, Callable]]] = ["conll"], + filesystem: Optional[FileSystem] = None, + shuffle: Literal["dataset", False] = False, + seed: Optional[int] = None, + loop: bool = False, + **kwargs, +) -> Stream: + """ + The ConllReader (or `edsnlp.data.read_conll`) reads a file or directory of CoNLL + files and yields documents. + + The raw output (i.e., by setting `converter=None`) will be in the following form + for a single doc: + + ``` + { + "words": [ + {"ID": "1", "FORM": ...}, + ... + ], + } + ``` + + Example + ------- + ```{ .python .no-check } + + import edsnlp + + nlp = edsnlp.blank("eds") + nlp.add_pipe(...) + doc_iterator = edsnlp.data.read_conll("path/to/conll/file/or/directory") + annotated_docs = nlp.pipe(doc_iterator) + ``` + + !!! note "Generator vs list" + + `edsnlp.data.read_conll` returns a + [Stream][edsnlp.core.stream.Stream]. + To iterate over the documents multiple times efficiently or to access them by + index, you must convert it to a list : + + ```{ .python .no-check } + docs = list(edsnlp.data.read_conll("path/to/conll/file/or/directory")) + ``` + + Parameters + ---------- + path : Union[str, Path] + Path to the directory containing the CoNLL files (will recursively look for + files in subdirectories). + columns: Optional[List[str]] + List of column names to use. If None, will try to extract to look for a + `#global.columns` comment at the start of the file to extract the column names. + shuffle: Literal["dataset", False] + Whether to shuffle the data. If "dataset", the whole dataset will be shuffled + before starting iterating on it (at the start of every epoch if looping). + seed: Optional[int] + The seed to use for shuffling. + loop: bool + Whether to loop over the data indefinitely. + nlp : Optional[PipelineProtocol] + The pipeline object (optional and likely not needed, prefer to use the + `tokenizer` directly argument instead). + tokenizer : Optional[spacy.tokenizer.Tokenizer] + The tokenizer instance used to tokenize the documents. Likely not needed since + by default it uses the current context tokenizer : + + - the tokenizer of the next pipeline run by `.map_pipeline` in a + [Stream][edsnlp.core.stream.Stream]. + - or the `eds` tokenizer by default. + converter : Optional[AsList[Union[str, Callable]]] + Converter to use to convert the documents to dictionary objects. + filesystem: Optional[FileSystem] = None, + The filesystem to use to write the files. If None, the filesystem will be + inferred from the path (e.g. `s3://` will use S3). + """ + + data = Stream( + reader=ConllReader( + path, + columns=columns, + filesystem=filesystem, + loop=loop, + shuffle=shuffle, + seed=seed, + ) + ) + if converter: + for conv in converter: + conv, kwargs = get_dict2doc_converter(conv, kwargs) + data = data.map(conv, kwargs=kwargs) + return data diff --git a/edsnlp/data/converters.py b/edsnlp/data/converters.py index 1bf1e6d2b..2059d3c8c 100644 --- a/edsnlp/data/converters.py +++ b/edsnlp/data/converters.py @@ -5,6 +5,7 @@ """ import inspect +import warnings from copy import copy from types import FunctionType from typing import ( @@ -19,6 +20,7 @@ ) import pydantic +import spacy from confit.registry import ValidatedFunction from spacy.tokenizer import Tokenizer from spacy.tokens import Doc, Span @@ -379,6 +381,54 @@ def __call__(self, doc): return obj +@registry.factory.register("eds.conll_dict2doc", spacy_compatible=False) +class ConllDict2DocConverter: + """ + TODO + """ + + def __init__( + self, + *, + tokenizer: Optional[Tokenizer] = None, + ): + self.tokenizer = tokenizer + + def __call__(self, obj, tokenizer=None): + tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer + vocab = tok.vocab + words_data = [word for word in obj["words"] if "-" not in word["ID"]] + words = [word["FORM"] for word in words_data] + spaces = ["SpaceAfter=No" not in w.get("MISC", "") for w in words_data] + doc = Doc(vocab, words=words, spaces=spaces) + + id_to_word = {word["ID"]: i for i, word in enumerate(words_data)} + for word_data, word in zip(words_data, doc): + for key, value in word_data.items(): + if key in ("ID", "FORM", "MISC"): + pass + elif key == "LEMMA": + word.lemma_ = value + elif key == "UPOS": + word.pos_ = value + elif key == "XPOS": + word.tag_ = value + elif key == "FEATS": + word.morph = spacy.tokens.morphanalysis.MorphAnalysis( + tok.vocab, + dict(feat.split("=") for feat in value.split("|")), + ) + elif key == "HEAD": + if value != "0": + word.head = doc[id_to_word[value]] + elif key == "DEPREL": + word.dep_ = value + else: + warnings.warn(f"Unused key {key} in CoNLL dict, ignoring it.") + + return doc + + @registry.factory.register("eds.omop_dict2doc", spacy_compatible=False) class OmopDict2DocConverter: """ diff --git a/edsnlp/utils/file_system.py b/edsnlp/utils/file_system.py index b93911e8e..29179cf2f 100644 --- a/edsnlp/utils/file_system.py +++ b/edsnlp/utils/file_system.py @@ -24,6 +24,8 @@ def walk_match( root: str, file_pattern: str, ) -> list: + if fs.isfile(root): + return [root] return [ os.path.join(dirpath, f) for dirpath, dirnames, files in fs.walk(root) diff --git a/mkdocs.yml b/mkdocs.yml index 023a92c4f..786e26d55 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -132,6 +132,7 @@ nav: - Data Connectors: - data/index.md - data/standoff.md + - data/conll.md - data/json.md - data/parquet.md - data/pandas.md diff --git a/pyproject.toml b/pyproject.toml index eb8e7dfb7..43043b865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,6 +209,7 @@ where = ["."] "eds.split" = "edsnlp.pipes.misc.split.split:Split" "eds.standoff_dict2doc" = "edsnlp.data.converters:StandoffDict2DocConverter" "eds.standoff_doc2dict" = "edsnlp.data.converters:StandoffDoc2DictConverter" +"eds.conll_dict2doc" = "edsnlp.data.converters:ConllDict2DocConverter" "eds.omop_dict2doc" = "edsnlp.data.converters:OmopDict2DocConverter" "eds.omop_doc2dict" = "edsnlp.data.converters:OmopDoc2DictConverter" "eds.ents_doc2dict" = "edsnlp.data.converters:EntsDoc2DictConverter" @@ -295,6 +296,7 @@ where = ["."] "parquet" = "edsnlp.data:read_parquet" "standoff" = "edsnlp.data:read_standoff" "brat" = "edsnlp.data:read_brat" # alias for standoff +"conll" = "edsnlp.data:read_conll" [project.entry-points."edsnlp_writers"] "spark" = "edsnlp.data:to_spark" diff --git a/tests/data/test_conll.py b/tests/data/test_conll.py new file mode 100644 index 000000000..049806a18 --- /dev/null +++ b/tests/data/test_conll.py @@ -0,0 +1,38 @@ +from itertools import islice +from pathlib import Path + +import pytest +from typing_extensions import Literal + +import edsnlp + + +@pytest.mark.parametrize("num_cpu_workers", [0, 2]) +@pytest.mark.parametrize("shuffle", ["dataset"]) +def test_read_shuffle_loop( + num_cpu_workers: int, + shuffle: Literal["dataset", "fragment"], +): + input_file = ( + Path(__file__).parent.parent.resolve() / "training" / "rhapsodie_sample.conllu" + ) + notes = edsnlp.data.read_conll( + input_file, + shuffle=shuffle, + seed=42, + loop=True, + ).set_processing(num_cpu_workers=num_cpu_workers) + notes = list(islice(notes, 6)) + assert len(notes) == 6 + # 32 ce ce PRON _ Gender=Masc|Number=Sing|Person=3|PronType=Dem 30 obl:arg _ _ # noqa: E501 + word_attrs = { + "text": "ce", + "lemma_": "ce", + "pos_": "PRON", + "dep_": "obl:arg", + "morph": "Gender=Masc|Number=Sing|Person=3|PronType=Dem", + "head": "profité", + } + word = notes[0][31] + for attr, val in word_attrs.items(): + assert str(getattr(word, attr)) == val diff --git a/tests/training/rhapsodie_sample.conllu b/tests/training/rhapsodie_sample.conllu new file mode 100644 index 000000000..76b0e42d2 --- /dev/null +++ b/tests/training/rhapsodie_sample.conllu @@ -0,0 +1,67 @@ +# global.columns = ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC +# macrosyntax = "euh" il y avait ( donc ) une "euh" jeune fille { qui regardait dans { une boutique | apparemment une pâtisserie } | qui semblait avoir faim | qui { a profité de ce que le livreur s'éloigne pour "euh" voler { un | une } baguette "euh" | a rencontré ( donc ) Charlot à ce moment-là | lui est rentrée dedans } } // +# sent_id = Rhap_M0024-1 +# text = euh, il y avait donc une, euh, jeune fille qui regardait dans une boutique, apparemment une pâtisserie, qui semblait avoir faim, qui a profité de ce que le livreur s'éloigne pour, euh, voler un, une baguette, euh, a rencontré donc Charlot à ce moment-là, lui est rentrée dedans. +1 euh euh INTJ _ _ 5 discourse _ SpaceAfter=No +2 , , PUNCT _ _ 1 punct _ _ +3 il lui PRON _ Gender=Masc|Number=Sing|Person=3|PronType=Prs 5 expl:subj _ _ +4 y y PRON _ Person=3|PronType=Prs 5 expl:comp _ _ +5 avait avoir VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Imp|VerbForm=Fin 0 root _ _ +6 donc donc ADV _ _ 5 discourse _ _ +7 une un DET _ Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 12 det _ SpaceAfter=No +8 , , PUNCT _ _ 9 punct _ _ +9 euh euh INTJ _ _ 7 discourse _ SpaceAfter=No +10 , , PUNCT _ _ 7 punct _ _ +11 jeune jeune ADJ _ Number=Sing 12 amod _ _ +12 fille fille NOUN _ Gender=Fem|Number=Sing 5 obj _ _ +13 qui qui PRON _ PronType=Rel 14 nsubj _ _ +14 regardait regarder VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Imp|VerbForm=Fin 12 acl:relcl _ _ +15 dans dans ADP _ _ 17 case _ _ +16 une un DET _ Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 17 det _ _ +17 boutique boutique NOUN _ Gender=Fem|Number=Sing 14 obl:arg _ SpaceAfter=No +18 , , PUNCT _ _ 21 punct _ _ +19 apparemment apparemment ADV _ _ 21 advmod _ _ +20 une un DET _ Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 21 det _ _ +21 pâtisserie pâtisserie NOUN _ Gender=Fem|Number=Sing 17 appos _ SpaceAfter=No +22 , , PUNCT _ _ 24 punct _ _ +23 qui qui PRON _ PronType=Rel 24 nsubj _ _ +24 semblait sembler VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Imp|VerbForm=Fin 14 conj _ _ +25 avoir avoir VERB _ VerbForm=Inf 24 xcomp _ Subject=SubjRaising +26 faim faim NOUN _ Gender=Fem|Number=Sing 25 obj _ SpaceAfter=No +27 , , PUNCT _ _ 30 punct _ _ +28 qui qui PRON _ PronType=Rel 30 nsubj _ _ +29 a avoir AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 30 aux:tense _ _ +30 profité profiter VERB _ Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 14 conj _ _ +31 de de ADP _ _ 32 case _ _ +32 ce ce PRON _ Gender=Masc|Number=Sing|Person=3|PronType=Dem 30 obl:arg _ _ +33 que que SCONJ _ _ 37 mark _ _ +34 le le DET _ Definite=Def|Gender=Masc|Number=Sing|PronType=Art 35 det _ _ +35 livreur livreur NOUN _ Gender=Masc|Number=Sing 37 nsubj _ _ +36 s' soi PRON _ Person=3|PronType=Prs 37 obj _ SpaceAfter=No +37 éloigne éloigner VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 32 acl _ _ +38 pour pour ADP _ _ 42 mark _ SpaceAfter=No +39 , , PUNCT _ _ 42 punct _ _ +40 euh euh INTJ _ _ 42 discourse _ SpaceAfter=No +41 , , PUNCT _ _ 40 punct _ _ +42 voler voler VERB _ VerbForm=Inf 30 advcl _ Subject=NoRaising +43 un un DET _ Definite=Ind|Gender=Masc|Number=Sing|PronType=Art 45 reparandum _ SpaceAfter=No +44 , , PUNCT _ _ 43 punct _ _ +45 une un DET _ Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 46 det _ _ +46 baguette baguette NOUN _ Gender=Fem|Number=Sing 42 obj _ SpaceAfter=No +47 , , PUNCT _ _ 48 punct _ _ +48 euh euh INTJ _ _ 46 discourse _ SpaceAfter=No +49 , , PUNCT _ _ 51 punct _ _ +50 a avoir AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 51 aux:tense _ _ +51 rencontré rencontrer VERB _ Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 30 conj _ _ +52 donc donc ADV _ _ 53 discourse _ _ +53 Charlot Charlot PROPN _ _ 51 obj _ _ +54 à à ADP _ _ 56 case _ _ +55 ce ce DET _ Gender=Masc|Number=Sing|PronType=Dem 56 det _ _ +56 moment moment NOUN _ Gender=Masc|Number=Sing 51 obl:mod _ SpaceAfter=No +57 -là là ADV _ _ 56 advmod _ SpaceAfter=No +58 , , PUNCT _ _ 61 punct _ _ +59 lui lui PRON _ Gender=Masc|Number=Sing|Person=3|PronType=Prs 61 iobj _ _ +60 est être AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 61 aux:tense _ _ +61 rentrée rentrer VERB _ Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 51 conj _ _ +62 dedans dedans ADV _ _ 61 obj _ SpaceAfter=No +63 . . PUNCT _ _ 5 punct _ _ From 5dcd4db684b3a68b2d11950dcbb41ca819d559f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Tue, 19 Nov 2024 17:46:56 +0100 Subject: [PATCH 2/2] feat: added trainable biaffine dependency parser and metrics --- changelog.md | 1 + .../trainable/biaffine-dependency-parser.md | 8 + docs/pipes/trainable/index.md | 17 +- docs/references.bib | 23 + edsnlp/metrics/dep_parsing.py | 55 ++ edsnlp/pipes/__init__.py | 1 + .../trainable/biaffine_dep_parser/__init__.py | 1 + .../biaffine_dep_parser.py | 672 ++++++++++++++++++ .../trainable/biaffine_dep_parser/factory.py | 8 + edsnlp/pipes/trainable/ner_crf/ner_crf.py | 9 - edsnlp/training/trainer.py | 12 + mkdocs.yml | 1 + pyproject.toml | 26 +- tests/training/dep_parser_config.yml | 59 ++ tests/training/test_train.py | 27 +- 15 files changed, 890 insertions(+), 30 deletions(-) create mode 100644 docs/pipes/trainable/biaffine-dependency-parser.md create mode 100644 edsnlp/metrics/dep_parsing.py create mode 100644 edsnlp/pipes/trainable/biaffine_dep_parser/__init__.py create mode 100644 edsnlp/pipes/trainable/biaffine_dep_parser/biaffine_dep_parser.py create mode 100644 edsnlp/pipes/trainable/biaffine_dep_parser/factory.py create mode 100644 tests/training/dep_parser_config.yml diff --git a/changelog.md b/changelog.md index 2d1998193..52e66d11e 100644 --- a/changelog.md +++ b/changelog.md @@ -10,6 +10,7 @@ - Support a specific model output dir path for trainings (`output_model_dir`), and whether to save the model or not (`save_model`) - Specify whether to log the validation results or not (`logger=False`) - Added support for the CoNLL format with `edsnlp.data.read_conll` and with a specific `eds.conll_dict2doc` converter +- Added a Trainable Biaffine Dependency Parser (`eds.biaffine_dep_parser`) component and metrics ### Fixed diff --git a/docs/pipes/trainable/biaffine-dependency-parser.md b/docs/pipes/trainable/biaffine-dependency-parser.md new file mode 100644 index 000000000..828649a90 --- /dev/null +++ b/docs/pipes/trainable/biaffine-dependency-parser.md @@ -0,0 +1,8 @@ +# Trainable Biaffine Dependency Parser {: #edsnlp.pipes.trainable.biaffine_dep_parser.factory.create_component } + +::: edsnlp.pipes.trainable.biaffine_dep_parser.factory.create_component + options: + heading_level: 2 + show_bases: false + show_source: false + only_class_level: true diff --git a/docs/pipes/trainable/index.md b/docs/pipes/trainable/index.md index a77a109cd..3444be280 100644 --- a/docs/pipes/trainable/index.md +++ b/docs/pipes/trainable/index.md @@ -8,13 +8,14 @@ All trainable components implement the [`TorchComponent`][edsnlp.core.torch_comp -| Name | Description | -|-----------------------|-----------------------------------------------------------------------| -| `eds.transformer` | Embed text with a transformer model | -| `eds.text_cnn` | Contextualize embeddings with a CNN | -| `eds.span_pooler` | A span embedding component that aggregates word embeddings | -| `eds.ner_crf` | A trainable component to extract entities | -| `eds.span_classifier` | A trainable component for multi-class multi-label span classification | -| `eds.span_linker` | A trainable entity linker (i.e. to a list of concepts) | +| Name | Description | +|---------------------------|-----------------------------------------------------------------------| +| `eds.transformer` | Embed text with a transformer model | +| `eds.text_cnn` | Contextualize embeddings with a CNN | +| `eds.span_pooler` | A span embedding component that aggregates word embeddings | +| `eds.ner_crf` | A trainable component to extract entities | +| `eds.span_classifier` | A trainable component for multi-class multi-label span classification | +| `eds.span_linker` | A trainable entity linker (i.e. to a list of concepts) | +| `eds.biaffine_dep_parser` | A trainable biaffine dependency parser | diff --git a/docs/references.bib b/docs/references.bib index 00d5094e5..0e2060517 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -161,3 +161,26 @@ @article{petitjean_2024 url = {https://doi.org/10.1093/jamia/ocae069}, eprint = {https://academic.oup.com/jamia/article-pdf/31/6/1280/57769016/ocae069.pdf}, } + +@misc{dozat2017deepbiaffineattentionneural, + title={Deep Biaffine Attention for Neural Dependency Parsing}, + author={Timothy Dozat and Christopher D. Manning}, + year={2017}, + eprint={1611.01734}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/1611.01734}, +} + +@inproceedings{grobol:hal-03223424, + title = {{Analyse en dépendances du français avec des plongements contextualisés}}, + author = {Grobol, Loïc and Crabbé, Benoît}, + url = {https://hal.archives-ouvertes.fr/hal-03223424}, + year = {2021}, + booktitle = {{Actes de la 28ème Conférence sur le Traitement Automatique des Langues Naturelles}}, + eventtitle = {{TALN-RÉCITAL 2021}}, + venue = {Lille, France}, + pdf = {https://hal.archives-ouvertes.fr/hal-03223424/file/HOPS_final.pdf}, + hal_id = {hal-03223424}, + hal_version = {v1}, +} diff --git a/edsnlp/metrics/dep_parsing.py b/edsnlp/metrics/dep_parsing.py new file mode 100644 index 000000000..5247a483f --- /dev/null +++ b/edsnlp/metrics/dep_parsing.py @@ -0,0 +1,55 @@ +from typing import Any, Optional + +from edsnlp import registry +from edsnlp.metrics import Examples, make_examples, prf + + +def dependency_parsing_metric( + examples: Examples, + filter_expr: Optional[str] = None, +): + """ + Compute the UAS and LAS scores for dependency parsing. + + Parameters + ---------- + examples : Examples + The examples to score, either a tuple of (golds, preds) or a list of + spacy.training.Example objects + filter_expr : Optional[str] + The filter expression to use to filter the documents + + Returns + ------- + Dict[str, float] + """ + items = { + "uas": (set(), set()), + "las": (set(), set()), + } + examples = make_examples(examples) + if filter_expr is not None: + filter_fn = eval(f"lambda doc: {filter_expr}") + examples = [eg for eg in examples if filter_fn(eg.reference)] + + for eg_idx, eg in enumerate(examples): + for token in eg.reference: + items["uas"][0].add((eg_idx, token.i, token.head.i)) + items["las"][0].add((eg_idx, token.i, token.head.i, token.dep_)) + + for token in eg.predicted: + items["uas"][1].add((eg_idx, token.i, token.head.i)) + items["las"][1].add((eg_idx, token.i, token.head.i, token.dep_)) + + return {name: prf(pred, gold)["f"] for name, (pred, gold) in items.items()} + + +@registry.metrics.register("eds.dep_parsing") +class DependencyParsingMetric: + def __init__(self, filter_expr: Optional[str] = None): + self.filter_expr = filter_expr + + __init__.__doc__ = dependency_parsing_metric.__doc__ + + def __call__(self, *examples: Any): + return dependency_parsing_metric(examples, self.filter_expr) diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index 8ccdc1d3a..02a2a0489 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -75,6 +75,7 @@ from .qualifiers.reported_speech.factory import create_component as reported_speech from .qualifiers.reported_speech.factory import create_component as rspeech from .trainable.ner_crf.factory import create_component as ner_crf + from .trainable.biaffine_dep_parser.factory import create_component as biaffine_dep_parser from .trainable.span_classifier.factory import create_component as span_classifier from .trainable.span_linker.factory import create_component as span_linker from .trainable.embeddings.span_pooler.factory import create_component as span_pooler diff --git a/edsnlp/pipes/trainable/biaffine_dep_parser/__init__.py b/edsnlp/pipes/trainable/biaffine_dep_parser/__init__.py new file mode 100644 index 000000000..549d2fc77 --- /dev/null +++ b/edsnlp/pipes/trainable/biaffine_dep_parser/__init__.py @@ -0,0 +1 @@ +from .factory import create_component diff --git a/edsnlp/pipes/trainable/biaffine_dep_parser/biaffine_dep_parser.py b/edsnlp/pipes/trainable/biaffine_dep_parser/biaffine_dep_parser.py new file mode 100644 index 000000000..624ce34bf --- /dev/null +++ b/edsnlp/pipes/trainable/biaffine_dep_parser/biaffine_dep_parser.py @@ -0,0 +1,672 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, cast + +import foldedtensor as ft +import numpy as np +import torch +import torch.nn.functional as F +from spacy.tokens import Doc, Span +from typing_extensions import Literal + +from edsnlp.core import PipelineProtocol +from edsnlp.core.torch_component import BatchInput, BatchOutput, TorchComponent +from edsnlp.pipes.trainable.embeddings.typing import WordEmbeddingComponent +from edsnlp.utils.span_getters import SpanGetterArg, get_spans + +logger = logging.getLogger(__name__) + + +# =============================================================== +def chuliu_edmonds_one_root(scores: np.ndarray) -> np.ndarray: + """ + Shamelessly copied from + https://github.com/hopsparser/hopsparser/blob/main/hopsparser/mst.py#L63 + All credits, Loic Grobol at Université Paris Nanterre, France, the + original author of this implementation. Find the license of the hopsparser software + below: + + Copyright 2020 Benoît Crabbé benoit.crabbe@linguist.univ-paris-diderot.fr + + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in the + Software without restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the + Software, and to permit persons to whom the Software is furnished to do so, + subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, + INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A + PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + --- + + Repeatedly Use the Chu‑Liu/Edmonds algorithm to find a maximum spanning + dependency tree from the weight matrix of a rooted weighted directed graph. + + **ATTENTION: this modifies `scores` in place.** + + ## Input + + - `scores`: A 2d numeric array such that `scores[i][j]` is the weight + of the `$j→i$` edge in the graph and the 0-th node is the root. + + ## Output + + - `tree`: A 1d integer array such that `tree[i]` is the head of the `i`-th node + """ + + # FIXME: we don't actually need this in CLE: we only need one critical cycle + def tarjan(tree: np.ndarray) -> List[np.ndarray]: + """Use Tarjan's SCC algorithm to find cycles in a tree + + ## Input + + - `tree`: A 1d integer array such that `tree[i]` is the head of + the `i`-th node + + ## Output + + - `cycles`: A list of 1d bool arrays such that `cycles[i][j]` is + true iff the `j`-th node of + `tree` is in the `i`-th cycle + """ + indices = -np.ones_like(tree) + lowlinks = -np.ones_like(tree) + onstack = np.zeros_like(tree, dtype=bool) + stack = list() + # I think this is in a list to be able to mutate it in the closure, even + # though `nonlocal` exists + _index = [0] + cycles = [] + + def strong_connect(i): + _index[0] += 1 + index = _index[-1] # `_index` is of length 1 so this is also `_index[0]`??? + indices[i] = lowlinks[i] = index - 1 + stack.append(i) + onstack[i] = True + dependents = np.where(np.equal(tree, i))[0] + for j in dependents: + if indices[j] == -1: + strong_connect(j) + lowlinks[i] = min(lowlinks[i], lowlinks[j]) + elif onstack[j]: + lowlinks[i] = min(lowlinks[i], indices[j]) + + # There's a cycle! + if lowlinks[i] == indices[i]: + cycle = np.zeros_like(indices, dtype=bool) + while stack[-1] != i: + j = stack.pop() + onstack[j] = False + cycle[j] = True + stack.pop() + onstack[i] = False + cycle[i] = True + if cycle.sum() > 1: + cycles.append(cycle) + return + + # ------------------------------------------------------------- + for i in range(len(tree)): + if indices[i] == -1: + strong_connect(i) + return cycles + + # TODO: split out a `contraction` function to make this more readable + def chuliu_edmonds(scores: np.ndarray) -> np.ndarray: + """Use the Chu‑Liu/Edmonds algorithm to find a maximum spanning + arborescence from the weight matrix of a rooted weighted directed + graph + + ## Input + + - `scores`: A 2d numeric array such that `scores[i][j]` is the + weight of the `$j→i$` edge in the graph and the 0-th node is the root. + + ## Output + + - `tree`: A 1d integer array such that `tree[i]` is the head of the `i`-th node + """ + np.fill_diagonal(scores, -float("inf")) # prevent self-loops + scores[0] = -float("inf") + scores[0, 0] = 0 + tree = cast(np.ndarray, np.argmax(scores, axis=1)) + cycles = tarjan(tree) + if not cycles: + return tree + else: + # t = len(tree); c = len(cycle); n = len(noncycle) + # locations of cycle; (t) in [0,1] + cycle = cycles.pop() + # indices of cycle in original tree; (c) in t + cycle_locs = np.where(cycle)[0] + # heads of cycle in original tree; (c) in t + cycle_subtree = tree[cycle] + # scores of cycle in original tree; (c) in R + cycle_scores = scores[cycle, cycle_subtree] + # total score of cycle; () in R + total_cycle_score = cycle_scores.sum() + + # locations of noncycle; (t) in [0,1] + noncycle = np.logical_not(cycle) + # indices of noncycle in original tree; (n) in t + noncycle_locs = np.where(noncycle)[0] + + # scores of cycle's potential heads; (c x n) - (c) + () -> (n x c) in R + metanode_head_scores = ( + scores[cycle][:, noncycle] + - cycle_scores[:, np.newaxis] + + total_cycle_score + ) + # scores of cycle's potential dependents; (n x c) in R + metanode_dep_scores = scores[noncycle][:, cycle] + # best noncycle head for each cycle dependent; (n) in c + metanode_heads = np.argmax(metanode_head_scores, axis=0) + # best cycle head for each noncycle dependent; (n) in c + metanode_deps = np.argmax(metanode_dep_scores, axis=1) + + # scores of noncycle graph; (n x n) in R + subscores = scores[noncycle][:, noncycle] + # expand to make space for the metanode (n+1 x n+1) in R + subscores = np.pad(subscores, ((0, 1), (0, 1)), "constant") + # set the contracted graph scores of cycle's potential + # heads; (c x n)[:, (n) in n] in R -> (n) in R + subscores[-1, :-1] = metanode_head_scores[ + metanode_heads, np.arange(len(noncycle_locs)) + ] + # set the contracted graph scores of cycle's potential + # dependents; (n x c)[(n) in n] in R-> (n) in R + subscores[:-1, -1] = metanode_dep_scores[ + np.arange(len(noncycle_locs)), metanode_deps + ] + + # MST with contraction; (n+1) in n+1 + contracted_tree = chuliu_edmonds(subscores) + # head of the cycle; () in n + cycle_head = contracted_tree[-1] + # fixed tree: (n) in n+1 + contracted_tree = contracted_tree[:-1] + # initialize new tree; (t) in 0 + new_tree = -np.ones_like(tree) + # fixed tree with no heads coming from the cycle: (n) in [0,1] + contracted_subtree = contracted_tree < len(contracted_tree) + # add the nodes to the new tree (t)[(n)[(n) in [0,1]] in t] + # in t = (n)[(n)[(n) in [0,1]] in n] in t + new_tree[noncycle_locs[contracted_subtree]] = noncycle_locs[ + contracted_tree[contracted_subtree] + ] + # fixed tree with heads coming from the cycle: (n) in [0,1] + contracted_subtree = np.logical_not(contracted_subtree) + # add the nodes to the tree (t)[(n)[(n) in [0,1]] in t] + # in t = (c)[(n)[(n) in [0,1]] in c] in t + new_tree[noncycle_locs[contracted_subtree]] = cycle_locs[ + metanode_deps[contracted_subtree] + ] + # add the old cycle to the tree; (t)[(c) in t] in t = (t)[(c) in t] in t + new_tree[cycle_locs] = tree[cycle_locs] + # root of the cycle; (n)[() in n] in c = () in c + cycle_root = metanode_heads[cycle_head] + # add the root of the cycle to the new + # tree; (t)[(c)[() in c] in t] = (c)[() in c] + new_tree[cycle_locs[cycle_root]] = noncycle_locs[cycle_head] + return new_tree + + scores = scores.astype(np.float64) + tree = chuliu_edmonds(scores) + roots_to_try = np.where(np.equal(tree[1:], 0))[0] + 1 + + # PW: small change here (<= instead of ==) to avoid crashes in pathological cases + if len(roots_to_try) <= 1: + return tree + + # ------------------------------------------------------------- + def set_root(scores: np.ndarray, root: int) -> Tuple[np.ndarray, np.ndarray]: + """Force the `root`-th node to be the only node under the root by overwriting + the weights of the other children of the root.""" + root_score = scores[root, 0] + scores = np.array(scores) + scores[1:, 0] = -float("inf") + scores[root] = -float("inf") + scores[root, 0] = 0 + return scores, root_score + + # We find the maximum spanning dependency tree by trying every possible root + best_score, best_tree = -np.inf, None # This is what's causing it to crash + for root in roots_to_try: + _scores, root_score = set_root(scores, root) + _tree = chuliu_edmonds(_scores) + tree_probs = _scores[np.arange(len(_scores)), _tree] + tree_score = ( + (tree_probs).sum() + (root_score) + if (tree_probs > -np.inf).all() + else -np.inf + ) + if tree_score > best_score: + best_score = tree_score + best_tree = _tree + + assert best_tree is not None + return best_tree + + +class MLP(torch.nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, dropout_p: float = 0.0 + ): + super().__init__() + self.hidden = torch.nn.Linear(input_dim, hidden_dim) + self.output = torch.nn.Linear(hidden_dim, output_dim) + self.dropout = torch.nn.Dropout(dropout_p) + + def forward(self, x): + x = self.hidden(x) + x = F.relu(x) + x = self.dropout(x) + x = self.output(x) + return x + + +class BiAffine(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features, + ): + super().__init__() + self.bilinear = torch.nn.Bilinear( + in_features, in_features, out_features, bias=True + ) + self.head_linear = torch.nn.Linear(in_features, out_features, bias=False) + self.tail_linear = torch.nn.Linear(in_features, out_features, bias=False) + + def forward(self, u, v): + scores = torch.einsum("bux,bvy,lxy->buvl", u, v, self.bilinear.weight) + scores = scores + self.head_linear(u).unsqueeze(2) + scores = scores + self.tail_linear(v).unsqueeze(1) + scores = scores + self.bilinear.bias + return scores + + +class TrainableBiaffineDependencyParser( + TorchComponent[BatchOutput, BatchInput], +): + """ + The `eds.biaffine_dep_parser` component is a trainable dependency parser + based on a biaffine model ([@dozat2017deepbiaffineattentionneural]). For each + token, the model predicts a score for each possible head in the document, and a + score for each possible label for each head. The results are then decoded either + greedily by picking the best scoring head for each token independently, or + holistically by computing the Maximum Spanning Tree (MST) over the graph of + token → head scores. + + !!! warning "Experimental" + + This component is experimental. In particular, it expects the input to be + sentences and not full documents, as it has not been optimized for memory + efficiency yet and computed the full matrix of scores for all pairs of tokens + in a document. + + At the moment, it is mostly used for benchmarking and research purposes. + + Examples + -------- + ```{ .python } + import edsnlp, edsnlp.pipes as eds + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.biaffine_dep_parser( + embedding=eds.transformer(model="prajjwal1/bert-tiny"), + hidden_size=128, + dropout_p=0.1, + # labels unset, will be inferred from the data in `post_init` + decoding_mode="mst", + ), + name="dep_parser" + ) + ``` + + Dependency parsers are typically trained on CoNLL-formatted + [Universal Dependencies corpora](https://universaldependencies.org/#download), + which you can load using the [`edsnlp.data.read_conll`][edsnlp.data.read_conll] + function. + + To train the model, refer to the [Training tutorial](/tutorials/training). + + Parameters + ---------- + nlp: Optional[PipelineProtocol] + The pipeline object + name: str + Name of the component + embedding: WordEmbeddingComponent + The word embedding component + context_getter: Optional[SpanGetterArg] + What context to use when computing the span embeddings (defaults to the whole + document). For example `{"section": "conclusion"}` to predict dependencies + in the conclusion section of documents. + use_attrs: Optional[List[str]] + The attributes to use as features for the model (ex. `["pos_"]` to use the POS + tag). By default, no attributes are used. + + Note that if you train a model with attributes, you will need to provide the + same attributes during inference, and the model might not work well if the + attributes were not annotated accurately on the test data. + attr_size: int + The size of the attribute embeddings. + hidden_size: int + The size of the hidden layer in the MLP. + dropout_p: float + The dropout probability to use in the MLP. + labels: List[str] + The labels to predict. The labels can also be inferred from the data during + `nlp.post_init(...)`. + decoding_mode: Literal["greedy", "mst"] + Whether to decode the dependencies greedily or using the Maximum Spanning Tree + algorithm. + + Authors and citation + -------------------- + The `eds.biaffine_dep_parser` trainable pipe was developed by + AP-HP's Data Science team, and heavily inspired by the implementation of + [@grobol:hal-03223424]. The biaffine architecture is based on the biaffine parser + of [@dozat2017deepbiaffineattentionneural]. + """ + + def __init__( + self, + nlp: Optional[PipelineProtocol] = None, + name: str = "biaffine_dep_parser", + *, + embedding: WordEmbeddingComponent, + context_getter: Optional[SpanGetterArg] = None, + use_attrs: Optional[List[str]] = None, + attr_size: int = 32, + hidden_size: int = 128, + dropout_p: float = 0.0, + labels: List[str] = ["root"], + decoding_mode: Literal["greedy", "mst"] = "mst", + ): + super().__init__(nlp=nlp, name=name) + self.embedding = embedding + self.use_attrs: List[str] = use_attrs or [] + self.labels = list(labels) or [] + self.labels_to_idx = {label: idx for idx, label in enumerate(self.labels)} + self.context_getter = context_getter + cat_dim = self.embedding.output_size + len(self.use_attrs) * attr_size + self.head_mlp = MLP(cat_dim, hidden_size, hidden_size, dropout_p) + self.tail_mlp = MLP(cat_dim, hidden_size, hidden_size, dropout_p) + self.arc_biaffine = BiAffine(hidden_size, 1) + self.lab_biaffine = BiAffine(hidden_size, len(self.labels)) + self.root_embed = torch.nn.Parameter(torch.randn(cat_dim)[None, None, :]) + self.decoding_mode = decoding_mode + self.attr_vocabs = {attr: [] for attr in self.use_attrs} + self.attr_to_idx = {attr: {} for attr in self.use_attrs} + self.attr_embeddings = torch.nn.ModuleDict( + { + attr: torch.nn.Embedding(len(vocab), attr_size) + for attr, vocab in self.attr_vocabs.items() + } + ) + + def update_labels(self, labels: Sequence[str], attrs: Dict[str, List[str]]): + old_labs = self.labels if self.labels is not None else () + old_index = torch.as_tensor( + [i for i, lab in enumerate(old_labs) if lab in labels], dtype=torch.long + ) + new_index = torch.as_tensor( + [labels.index(lab) for lab in old_labs if lab in labels], dtype=torch.long + ) + new_biaffine = BiAffine(self.arc_biaffine.bilinear.in1_features, len(labels)) + # fmt: off + new_biaffine.bilinear.weight.data[new_index] = self.arc_biaffine.bilinear.weight.data[old_index] # noqa: E501 + new_biaffine.bilinear.bias.data[new_index] = self.arc_biaffine.bilinear.bias.data[old_index] # noqa: E501 + new_biaffine.head_linear.weight.data[new_index] = self.arc_biaffine.head_linear.weight.data[old_index] # noqa: E501 + new_biaffine.tail_linear.weight.data[new_index] = self.arc_biaffine.tail_linear.weight.data[old_index] # noqa: E501 + # fmt: on + self.lab_biaffine.bilinear.weight.data = new_biaffine.bilinear.weight.data + self.lab_biaffine.bilinear.bias.data = new_biaffine.bilinear.bias.data + self.lab_biaffine.head_linear.weight.data = new_biaffine.head_linear.weight.data + self.lab_biaffine.tail_linear.weight.data = new_biaffine.tail_linear.weight.data + self.labels = labels + self.labels_to_idx = {lab: i for i, lab in enumerate(labels)} + + for attr, vals in attrs.items(): + emb = self.attr_embeddings[attr] + old_vals = ( + self.attr_vocabs[attr] if self.attr_vocabs[attr] is not None else () + ) + old_index = torch.as_tensor( + [i for i, val in enumerate(old_vals) if val in vals], dtype=torch.long + ) + new_index = torch.as_tensor( + [vals.index(val) for val in old_vals if val in vals], dtype=torch.long + ) + new_emb = torch.nn.Embedding( + len(vals), self.attr_embeddings[attr].weight.size(1) + ) + new_emb.weight.data[new_index] = emb.weight.data[old_index] + self.attr_embeddings[attr].weight.data = new_emb.weight.data + self.attr_vocabs[attr] = vals + self.attr_to_idx[attr] = {val: i for i, val in enumerate(vals)} + + def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]): + super().post_init(gold_data, exclude=exclude) + labels = dict() + attr_vocabs = {attr: dict() for attr in self.use_attrs} + for doc in gold_data: + ctxs = ( + get_spans(doc, self.context_getter) if self.context_getter else [doc[:]] + ) + for ctx in ctxs: + for token in ctx: + labels[token.dep_] = True + for attr in self.use_attrs: + attr_vocabs[attr][getattr(token, attr)] = True + self.update_labels( + labels=list(labels.keys()), + attrs={attr: list(v.keys()) for attr, v in attr_vocabs.items()}, + ) + + def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]: + ctxs = get_spans(doc, self.context_getter) if self.context_getter else [doc[:]] + prep = { + "embedding": self.embedding.preprocess(doc, contexts=ctxs, **kwargs), + "$contexts": ctxs, + "stats": { + "dep_words": 0, + }, + } + for attr in self.use_attrs: + prep[attr] = [ + [self.attr_to_idx[attr].get(getattr(token, attr), 0) for token in ctx] + for ctx in ctxs + ] + for ctx in ctxs: + prep["stats"]["dep_words"] += len(ctx) + 1 + return prep + + def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]: + preps = self.preprocess(doc) + arc_targets = [] # head idx for each token + lab_targets = [] # arc label idx for each token + for ctx in preps["$contexts"]: + ctx_start = ctx.start + arc_targets.append( + [ + 0, + *( + token.head.i - ctx_start + 1 if token.head.i != token.i else 0 + for token in ctx + ), + ] + ) + lab_targets.append([0, *(self.labels_to_idx[t.dep_] for t in ctx)]) + return { + **preps, + "arc_targets": arc_targets, + "lab_targets": lab_targets, + } + + def collate(self, preps: Dict[str, Any]) -> BatchInput: + collated = {"embedding": self.embedding.collate(preps["embedding"])} + collated["stats"] = { + k: sum(v) for k, v in preps["stats"].items() if not k.startswith("__") + } + for attr in self.use_attrs: + collated[attr] = ft.as_folded_tensor( + preps[attr], + data_dims=("context", "tail"), + full_names=("sample", "context", "tail"), + dtype=torch.long, + ) + if "arc_targets" in preps: + collated["arc_targets"] = ft.as_folded_tensor( + preps["arc_targets"], + data_dims=("context", "tail"), + full_names=("sample", "context", "tail"), + dtype=torch.long, + ) + if "lab_targets" in preps: + collated["lab_targets"] = ft.as_folded_tensor( + preps["lab_targets"], + data_dims=("context", "tail"), + full_names=("sample", "context", "tail"), + dtype=torch.long, + ) + + return collated + + # noinspection SpellCheckingInspection + def forward(self, batch: BatchInput) -> BatchOutput: + embeds = self.embedding(batch["embedding"])["embeddings"] + embeds = embeds.refold("context", "word") + embeds = torch.cat( + [ + embeds, + *(self.attr_embeddings[attr](batch[attr]) for attr in self.use_attrs), + ], + dim=-1, + ) + + embeds_with_root = torch.cat( + [ + self.root_embed.expand(embeds.shape[0], 1, self.root_embed.size(-1)), + embeds, + ], + dim=1, + ) + + tail_embeds = self.tail_mlp(embeds_with_root) # (contexts, tail=words, dim) + head_embeds = self.head_mlp(embeds_with_root) # (contexts, words+root, dim) + + # Scores: (contexts, tail=1+words, head=1+words) + arc_logits = self.arc_biaffine(tail_embeds, head_embeds).squeeze(-1) + # Scores: (contexts, tail=1+words, head=1+words, labels) + lab_logits = self.lab_biaffine(tail_embeds, head_embeds) + + if "arc_targets" in batch: + num_labels = lab_logits.shape[-1] + arc_targets = batch["arc_targets"] # (contexts, tail=1+words) -> head_idx + lab_targets = batch["lab_targets"] + # arc_targets: (contexts, tail=1+words) -> head_idx + flat_arc_logits = ( + # (contexts, tail=1+words, head=1+words) + arc_logits.masked_fill(~arc_targets.mask[:, None, :], -10000)[ + # -> (all_flattened_tails_with_root, head=1+words) + arc_targets.mask + ] + ) + flat_arc_targets = arc_targets[arc_targets.mask] + arc_loss = ( + F.cross_entropy( + # (all_flattened_tails_with_root, head_with_root) + flat_arc_logits, + flat_arc_targets, + reduction="sum", + ) + / batch["stats"]["dep_words"] + ) + flat_lab_logits = ( + # lab_logits: (contexts, tail=1+words, head=1+words, labels) + lab_logits[arc_targets.mask] + # -> (all_flattened_tails_with_root, head=1+words, labels) + .gather(1, flat_arc_targets[:, None, None].expand(-1, 1, num_labels)) + # -> (all_flattened_tails_with_root, 1, labels) + .squeeze(1) + # -> (all_flattened_tails_with_root, labels) + ) + # TODO, directly in collate + flat_lab_targets = ( + lab_targets[ + # (contexts, tail=1+words) -> label + arc_targets.mask + ] + # (all_flattened_tails_with_root) -> label + ) + lab_loss = ( + F.cross_entropy( + flat_lab_logits, + flat_lab_targets, + reduction="sum", + ) + / batch["stats"]["dep_words"] + ) + return { + "arc_loss": arc_loss, + "lab_loss": lab_loss, + "loss": arc_loss + lab_loss, + } + else: + return { + "arc_logits": arc_logits, + "arc_labels": lab_logits.argmax(-1), + } + + def postprocess( + self, + docs: Sequence[Doc], + results: BatchOutput, + inputs: List[Dict[str, Any]], + ) -> Sequence[Doc]: + # Preprocessed docs should still be in the cache + # (context, head=words + 1, tail=words + 1), ie head -> tail + contexts = [ctx for sample in inputs for ctx in sample["$contexts"]] + + for ctx, arc_logits, arc_labels in zip( + contexts, + results["arc_logits"].detach().cpu().numpy(), + results["arc_labels"].detach().cpu(), + ): + ctx: Span + if self.decoding_mode == "greedy": + tail_to_head_idx = arc_logits.argmax(-1) + else: + tail_to_head_idx = chuliu_edmonds_one_root(arc_logits) + tail_to_head_idx = torch.as_tensor(tail_to_head_idx) + + # lab_logits: (tail=words+1, head=words+1, labels) -> prob + # arc_logits: (tail=words+1) -> head_idx + labels = arc_labels[torch.arange(arc_labels.shape[0]), tail_to_head_idx] + + # Set arc and dep rel on the Span context + for tail_idx, (head_idx, label) in enumerate( + zip(tail_to_head_idx.tolist(), labels.tolist()) + ): + if head_idx == 0: + continue + head = ctx[head_idx - 1] + tail = ctx[tail_idx - 1] + tail.head = head + tail.dep_ = self.labels[label] + + return docs diff --git a/edsnlp/pipes/trainable/biaffine_dep_parser/factory.py b/edsnlp/pipes/trainable/biaffine_dep_parser/factory.py new file mode 100644 index 000000000..f40d72a22 --- /dev/null +++ b/edsnlp/pipes/trainable/biaffine_dep_parser/factory.py @@ -0,0 +1,8 @@ +from edsnlp import registry + +from .biaffine_dep_parser import TrainableBiaffineDependencyParser + +create_component = registry.factory.register( + "eds.biaffine_dep_parser", + assigns=["token.head", "token.dep"], +)(TrainableBiaffineDependencyParser) diff --git a/edsnlp/pipes/trainable/ner_crf/ner_crf.py b/edsnlp/pipes/trainable/ner_crf/ner_crf.py index 4407f56df..5c2fda098 100644 --- a/edsnlp/pipes/trainable/ner_crf/ner_crf.py +++ b/edsnlp/pipes/trainable/ner_crf/ner_crf.py @@ -298,15 +298,6 @@ def post_init(self, docs: Iterable[Doc], exclude: Set[str]): def update_labels(self, labels: Sequence[str]): old_labels = self.labels if self.labels is not None else () - n_old = len(old_labels) - dict( - reversed( - ( - *zip(old_labels, range(n_old)), - *zip(labels, range(n_old, n_old + len(labels))), - ) - ) - ) old_index = ( torch.arange(len(old_labels) * 5) .view(-1, 5)[ diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py index 56ac7f0af..ec94a5bb1 100644 --- a/edsnlp/training/trainer.py +++ b/edsnlp/training/trainer.py @@ -55,6 +55,12 @@ "goal_wait": 1, "name": r"\1_\2", }, + "(.*?)/(uas|las)": { + "goal": "higher_is_better", + "format": "{:.2%}", + "goal_wait": 1, + "name": r"\1_\2", + }, } @@ -167,6 +173,12 @@ def __call__(self, nlp: Pipeline, docs: Iterable[Any]): for name, scorer in span_attr_scorers.items(): scores[name] = scorer(docs, qlf_preds) + # Custom scorers + for name, scorer in scorers.items(): + pred_docs = [d.copy() for d in tqdm(docs, desc="Copying docs")] + preds = list(nlp.pipe(tqdm(pred_docs, desc="Predicting"))) + scores[name] = scorer(docs, preds) + return scores diff --git a/mkdocs.yml b/mkdocs.yml index 786e26d55..fcc844502 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -128,6 +128,7 @@ nav: - 'NER': pipes/trainable/ner.md - 'Span Classifier': pipes/trainable/span-classifier.md - 'Span Linker': pipes/trainable/span-linker.md + - 'Biaffine Dependency Parser': pipes/trainable/biaffine-dependency-parser.md - tokenizers.md - Data Connectors: - data/index.md diff --git a/pyproject.toml b/pyproject.toml index 43043b865..855e47759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -264,30 +264,32 @@ where = ["."] # edsnlp will look both in the above dict and in the one below. [project.entry-points."edsnlp_factories"] # Trainable -"eds.transformer" = "edsnlp.pipes.trainable.embeddings.transformer.factory:create_component" -"eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component" -"eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component" -"eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" -"eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" +"eds.transformer" = "edsnlp.pipes.trainable.embeddings.transformer.factory:create_component" +"eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component" +"eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component" +"eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" +"eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" "eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" "eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" -"eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" +"eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" +"eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" [project.entry-points."spacy_scorers"] +"eds.ner_exact" = "edsnlp.metrics.ner:NerExactMetric" +"eds.ner_token" = "edsnlp.metrics.ner:NerTokenMetric" +"eds.ner_overlap" = "edsnlp.metrics.ner:NerOverlapMetric" +"eds.span_attributes" = "edsnlp.metrics.span_attributes:SpanAttributeMetric" +"eds.dep_parsing" = "edsnlp.metrics.dep_parsing:DependencyParsingMetric" + +# Deprecated "eds.ner_exact_metric" = "edsnlp.metrics.ner:NerExactMetric" "eds.ner_token_metric" = "edsnlp.metrics.ner:NerTokenMetric" "eds.ner_overlap_metric" = "edsnlp.metrics.ner:NerOverlapMetric" "eds.span_attributes_metric" = "edsnlp.metrics.span_attributes:SpanAttributeMetric" - -# Deprecated "eds.ner_exact_scorer" = "edsnlp.metrics.ner:NerExactMetric" "eds.ner_token_scorer" = "edsnlp.metrics.ner:NerTokenMetric" "eds.ner_overlap_scorer" = "edsnlp.metrics.ner:NerOverlapMetric" "eds.span_attributes_scorer" = "edsnlp.metrics.span_attributes:SpanAttributeMetric" -"eds.ner_exact" = "edsnlp.metrics.ner:NerExactMetric" -"eds.ner_token" = "edsnlp.metrics.ner:NerTokenMetric" -"eds.ner_overlap" = "edsnlp.metrics.ner:NerOverlapMetric" -"eds.span_attributes" = "edsnlp.metrics.span_attributes:SpanAttributeMetric" [project.entry-points."edsnlp_readers"] "spark" = "edsnlp.data:from_spark" diff --git a/tests/training/dep_parser_config.yml b/tests/training/dep_parser_config.yml new file mode 100644 index 000000000..19bd9b036 --- /dev/null +++ b/tests/training/dep_parser_config.yml @@ -0,0 +1,59 @@ +# 🤖 PIPELINE DEFINITION +nlp: + "@core": pipeline + + lang: fr + + components: + parser: + '@factory': eds.biaffine_dep_parser + hidden_size: 64 + decoding_mode: greedy + dropout_p: 0. + use_attrs: ['pos_'] + + embedding: + '@factory': eds.transformer + model: hf-internal-testing/tiny-bert + window: 512 + stride: 256 + +# 📈 SCORERS +scorer: + speed: false + dep: + '@metrics': "eds.dep_parsing" + +# 🎛️ OPTIMIZER +optimizer: + optim: adamw + module: ${ nlp } + total_steps: ${ train.max_steps } + groups: + ".*": + lr: 1e-3 + +# 📚 DATA +train_data: + data: + "@readers": conll + path: ./rhapsodie_sample.conllu + shuffle: dataset + batch_size: 1 docs + pipe_names: [ "parser" ] + +val_data: + "@readers": conll + path: ./rhapsodie_sample.conllu + +# 🚀 TRAIN SCRIPT OPTIONS +train: + nlp: ${ nlp } + train_data: ${ train_data } + val_data: ${ val_data } + max_steps: 20 + validation_interval: 10 + max_grad_norm: 5.0 + scorer: ${ scorer } + num_workers: 0 + optimizer: ${ optimizer } diff --git a/tests/training/test_train.py b/tests/training/test_train.py index 5b312d5f5..b7e2ac36c 100644 --- a/tests/training/test_train.py +++ b/tests/training/test_train.py @@ -2,6 +2,8 @@ import pytest +from edsnlp.metrics.dep_parsing import DependencyParsingMetric + pytest.importorskip("rich") import shutil @@ -16,7 +18,7 @@ import torch.nn from confit import Config from confit.utils.random import set_seed -from spacy.tokens import Span +from spacy.tokens import Doc, Span from edsnlp.core.registries import registry from edsnlp.data.converters import AttributesMappingArg, get_current_tokenizer @@ -134,6 +136,29 @@ def test_qualif_train(run_in_test_dir, tmp_path): assert last_scores["qual"]["micro"]["f"] >= 0.4 +def test_dep_parser_train(run_in_test_dir, tmp_path): + set_seed(42) + config = Config.from_disk("dep_parser_config.yml") + shutil.rmtree(tmp_path, ignore_errors=True) + kwargs = Config.resolve(config["train"], registry=registry, root=config) + nlp = train(**kwargs, output_dir=tmp_path, cpu=True) + scorer = GenericScorer(**kwargs["scorer"]) + val_data = list(kwargs["val_data"]) + last_scores = scorer(nlp, val_data) + + scorer_bis = GenericScorer(parser=DependencyParsingMetric(filter_expr="False")) + # Just to test what happens if the scores indicate 2 roots + val_data_bis = [Doc.from_docs([val_data[0], val_data[0]])] + nlp.pipes.parser.decoding_mode = "mst" + last_scores_bis = scorer_bis(nlp, val_data_bis) + assert last_scores_bis["parser"]["uas"] == 0.0 + + # Check empty doc + nlp("") + + assert last_scores["dep"]["las"] >= 0.4 + + def test_optimizer(): net = torch.nn.Linear(10, 10) optim = ScheduledOptimizer(