]*>([a-zA-Z0-9._-]*)<\/span>
+(?![a-zA-Z0-9._-])
+"""
+
+CITATION_RE = r"(\[@(?:[\w_:-]+)(?: *, *@(?:[\w_:-]+))*\])"
+
+
+class ClickableSnippetsPlugin(BasePlugin):
+ config_scheme: Tuple[Tuple[str, MkType]] = (
+ # ("bibtex_file", MkType(str)), # type: ignore[assignment]
+ # ("order", MkType(str, default="unsorted")), # type: ignore[assignment]
+ )
+
+ @mkdocs.plugins.event_priority(1000)
+ def on_config(self, config: MkDocsConfig):
+ for event_name, events in config.plugins.events.items():
+ for event in list(events):
+ if "autorefs" in str(event):
+ events.remove(event)
+ old_plugin = config["plugins"]["autorefs"]
+ plugin_config = dict(old_plugin.config)
+ plugin = AutorefsPlugin()
+ config.plugins["autorefs"] = plugin
+ config["plugins"]["autorefs"] = plugin
+ plugin.load_config(plugin_config)
+
+ @classmethod
+ def get_ep_namespace(cls, ep, namespace):
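+ # entry_points() returns an object with a ``select`` method on recent Python
+ # versions (importlib.metadata >= 3.10) and a plain {group: [entry points]}
+ # dict on older ones, hence the two branches below.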
+ if hasattr(ep, "select"):
+ return ep.select(group=namespace)
+ else: # dict
+ return ep.get(namespace, [])
+
+ @mkdocs.plugins.event_priority(-1000)
+ def on_post_page(
+ self,
+ output: str,
+ page: mkdocs.structure.pages.Page,
+ config: mkdocs.config.Config,
+ ):
+ """
+ 1. Replace absolute paths with path relative to the rendered page
+ This must be performed after all other plugins have run.
+ 2. Replace component names with links to the component reference
+
+ Parameters
+ ----------
+ output
+     The rendered HTML of the page.
+ page
+     The MkDocs page being rendered.
+ config
+     The MkDocs configuration.
+
+ Returns
+ -------
+ str
+     The modified HTML output.
+ """
+
+ autorefs: AutorefsPlugin = config["plugins"]["autorefs"]
+ ep = entry_points()
+ spacy_factories_entry_points = {
+ ep.name: ep.value
+ for ep in (
+ *self.get_ep_namespace(ep, "spacy_factories"),
+ *self.get_ep_namespace(ep, "edsnlp_factories"),
+ )
+ }
+
+ def replace_component(match):
+ full_group = match.group(0)
+ name = "eds." + match.group(1)
+ ep = spacy_factories_entry_points.get(name)
+ preceding = output[match.start(0) - 50 : match.start(0)]
+ if ep is not None and "DEFAULT:" not in preceding:
+ try:
+ url = autorefs.get_item_url(ep.replace(":", "."))
+ except KeyError:
+ pass
+ else:
+ return f"{name}"
+ return full_group
+
+ def replace_link(match):
+ relative_url = url = match.group(1) or match.group(2) or match.group(3)
+ page_url = os.path.join("/", page.file.url)
+ if url.startswith("/"):
+ relative_url = os.path.relpath(url, page_url)
+ return f'"{relative_url}"'
+
+ output = regex.sub(PIPE_REGEX, replace_component, output)
+ output = regex.sub(HTML_PIPE_REGEX, replace_component, output)
+ output = regex.sub(HREF_REGEX, replace_link, output)
+
+ all_snippets = ""
+ all_offsets = []
+ all_nodes = []
+
+ soups = []
+
+ # Collect multi-line code blocks from the page and convert them to Python snippets
+ for match in regex.finditer(".*?
", output, flags=regex.DOTALL):
+ node = match.group(0)
+ if "\n" in node:
+ soup, snippet, python_offsets, html_nodes = self.convert_html_to_code(
+ node
+ )
+ size = len(all_snippets)
+ all_snippets += snippet + "\n"
+ all_offsets.extend([size + i for i in python_offsets])
+ all_nodes.extend(html_nodes)
+ soups.append((soup, match.start(0), match.end(0)))
+
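+ # Statically analyze the concatenated snippets with jedi so that names
+ # resolving to edsnlp objects can be linked to their API reference below.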
+ interpreter = jedi.Interpreter(all_snippets, [{}])
+ line_lengths = [0]
+ for line in all_snippets.split("\n"):
+ line_lengths.append(len(line) + line_lengths[-1] + 1)
+ line_lengths[-1] -= 1
+
+ # print(all_snippets)
+ # print("----")
+ for name in self.iter_names(interpreter._module_node):
+ try:
+ line, col = name.start_pos
+ offset = line_lengths[line - 1] + col
+ node_idx = bisect_right(all_offsets, offset) - 1
+
+ node = all_nodes[node_idx]
+ goto = (interpreter.goto(line, col, follow_imports=True) or [None])[0]
+ if (
+ goto
+ and goto.full_name
+ and goto.full_name.startswith("edsnlp")
+ and goto.type != "module"
+ ):
+ url = autorefs.get_item_url(goto.full_name)
+ # Check if node has no link in its upstream ancestors
+ if not node.find_parents("a"):
+ node.replace_with(
+ BeautifulSoup(
+ f'{node}',
+ "html5lib",
+ )
+ )
+ except Exception:
+ pass
+ # print("\n\n")
+
+ # Re-insert soups into the output
+ for soup, start, end in reversed(soups):
+ output = output[:start] + str(soup) + output[end:]
+
+ return output
+
+ @classmethod
+ def iter_names(cls, root):
+ if isinstance(root, parso.python.tree.Name):
+ yield root
+ for child in getattr(root, "children", ()):
+ yield from cls.iter_names(child)
+
+ @classmethod
+ def convert_html_to_code(cls, html_content: str) -> Tuple[BeautifulSoup, str, list, list]:
+ pre_html_content = "" + html_content + "
"
+ soup = BeautifulSoup(pre_html_content, "html5lib")
+ code_element = soup.find("code")
+
+ line_lengths = [0]
+ for line in pre_html_content.split("\n"):
+ line_lengths.append(len(line) + line_lengths[-1] + 1)
+ line_lengths[-1] -= 1
+
+ python_code = ""
+ code_offsets = []
+ # html_offsets = [0] #
+ html_nodes = []
+ code_offset = 0
+
+ def extract_text_with_offsets(el):
+ nonlocal python_code, code_offset
+ for content in el.contents:
+ # Recursively process child elements
+ if isinstance(content, str):
+ python_code += content
+ code_offsets.append(code_offset)
+ code_offset += len(content)
+ html_nodes.append(content)
+ continue
+ extract_text_with_offsets(content)
+
+ extract_text_with_offsets(code_element)
+ # html_offsets = html_offsets[1:]
+
+ return soup, python_code, code_offsets, html_nodes
+
+ # print("\nOffset Mapping (Python Index -> HTML Index):")
+ # for mapping in offset_mapping:
+ # print(mapping)
diff --git a/docs/scripts/plugin.py b/docs/scripts/plugin.py
index 13f65ecb7..b34dbb786 100644
--- a/docs/scripts/plugin.py
+++ b/docs/scripts/plugin.py
@@ -1,21 +1,14 @@
import os
from pathlib import Path
+import jedi
import mkdocs.config
import mkdocs.plugins
import mkdocs.structure
import mkdocs.structure.files
import mkdocs.structure.nav
import mkdocs.structure.pages
-import regex
-from mkdocs.config.defaults import MkDocsConfig
-
-from docs.scripts.autorefs.plugin import AutorefsPlugin
-
-try:
- from importlib.metadata import entry_points
-except ImportError:
- from importlib_metadata import entry_points
+from bs4 import BeautifulSoup
def exclude_file(name):
@@ -33,21 +26,6 @@ def exclude_file(name):
"""
-@mkdocs.plugins.event_priority(1000)
-def on_config(config: MkDocsConfig):
- for event_name, events in config.plugins.events.items():
- for event in list(events):
- if "autorefs" in str(event):
- print("REMOVING EVENT", event_name, event)
- events.remove(event)
- old_plugin = config["plugins"]["autorefs"]
- plugin_config = dict(old_plugin.config)
- plugin = AutorefsPlugin()
- config.plugins["autorefs"] = plugin
- config["plugins"]["autorefs"] = plugin
- plugin.load_config(plugin_config)
-
-
def on_files(files: mkdocs.structure.files.Files, config: mkdocs.config.Config):
"""
Recursively build the navigation of the mkdocs config
@@ -147,85 +125,41 @@ def on_page_read_source(page, config):
return None
-HREF_REGEX = (
- r"(?<=<\s*(?:a[^>]*href|img[^>]*src)=)"
- r'(?:"([^"]*)"|\'([^\']*)|[ ]*([^ =>]*)(?![a-z]+=))'
-)
-# Maybe find something less specific ?
-PIPE_REGEX = r"(?]*>eds<\/span>
-]*>[.]<\/span>
-]*>([a-zA-Z0-9._-]*)<\/span>
-(?![a-zA-Z0-9._-])
-"""
-
-
-def get_ep_namespace(ep, namespace):
- if hasattr(ep, "select"):
- return ep.select(group=namespace)
- else: # dict
- return ep.get(namespace, [])
+# Get current git commit
+GIT_COMMIT = os.popen("git rev-parse --short HEAD").read().strip()
-@mkdocs.plugins.event_priority(-1000)
+@mkdocs.plugins.event_priority(-2000)
def on_post_page(
output: str,
page: mkdocs.structure.pages.Page,
config: mkdocs.config.Config,
):
"""
- 1. Replace absolute paths with path relative to the rendered page
- This must be performed after all other plugins have run.
- 2. Replace component names with links to the component reference
-
- Parameters
- ----------
- output
- page
- config
-
- Returns
- -------
-
+ Add github links to the html output
"""
-
- autorefs: AutorefsPlugin = config["plugins"]["autorefs"]
- ep = entry_points()
- spacy_factories_entry_points = {
- ep.name: ep.value
- for ep in (
- *get_ep_namespace(ep, "spacy_factories"),
- *get_ep_namespace(ep, "edsnlp_factories"),
- )
- }
-
- def replace_component(match):
- full_group = match.group(0)
- name = "eds." + match.group(1)
- ep = spacy_factories_entry_points.get(name)
- preceding = output[match.start(0) - 50 : match.start(0)]
- if ep is not None and "DEFAULT:" not in preceding:
- try:
- url = autorefs.get_item_url(ep.replace(":", "."))
- except KeyError:
- pass
- else:
- return f"{name}"
- return full_group
-
- def replace_link(match):
- relative_url = url = match.group(1) or match.group(2) or match.group(3)
- page_url = os.path.join("/", page.file.url)
- if url.startswith("/"):
- relative_url = os.path.relpath(url, page_url)
- return f'"{relative_url}"'
-
- # Replace absolute paths with path relative to the rendered page
- output = regex.sub(PIPE_REGEX, replace_component, output)
- output = regex.sub(HTML_PIPE_REGEX, replace_component, output)
- output = regex.sub(HREF_REGEX, replace_link, output)
-
- return output
+ # Find all the headings (h1, h2, ...) whose id starts with "edsnlp"
+ soup = BeautifulSoup(output, "html.parser")
+ for heading in soup.find_all(["h1", "h2", "h3", "h4", "h5", "h6"]):
+ ref = heading.get("id", "")
+ if ref.startswith("edsnlp.") and "--" not in ref:
+ code = "import edsnlp; " + ref
+ interpreter = jedi.Interpreter(code, namespaces=[{}])
+ goto = interpreter.goto(1, len(code), follow_imports=True)
+ if not goto:
+ print("Could not get source for", ref)
+ continue
+ file = goto[0].module_path.relative_to(Path.cwd())
+ line = goto[0].line
+ # Add a "[source]" span with a link to the source code in a new tab
+ url = f"https://github.com/aphp/edsnlp/blob/{GIT_COMMIT}/{file}#L{line}"
+ heading.append(
+ BeautifulSoup(
+ f''
+ f'[source]',
+ features="html.parser",
+ )
+ )
+ # add "sourced-heading" to heading class
+ heading["class"] = heading.get("class", []) + ["sourced-heading"]
+ return str(soup)
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 8d5c91093..e51e864ba 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -83,20 +83,21 @@ We provide step-by-step guides to get you started. We cover the following use-ca
---
Quickly visualize the results of your pipeline as annotations or tables.
-=== card {: href=/tutorials/training }
+=== card {: href=/tutorials/make-a-training-script }
- :fontawesome-solid-brain:
- **Training**
+ :fontawesome-solid-flask:
+ **Deep learning tutorial**
---
- Learn how to train a deep-learning model, focusing on named entity recognition.
+ Learn how EDS-NLP handles training deep neural networks.
-=== card {: href=/tutorials/make-a-training-script }
+=== card {: href=/tutorials/training }
- :fontawesome-solid-flask:
- **Make a custom training script**
+ :fontawesome-solid-brain:
+ **Training API**
---
- Learn how to write a custom deep-learning training script.
+ Learn how to quickly train a deep-learning model with `edsnlp.train`.
+
diff --git a/docs/tutorials/make-a-training-script.md b/docs/tutorials/make-a-training-script.md
index 82bfe7615..806122872 100644
--- a/docs/tutorials/make-a-training-script.md
+++ b/docs/tutorials/make-a-training-script.md
@@ -1,7 +1,15 @@
-# Custom training script
+# Deep-learning tutorial
In this tutorial, we'll see how we can write our own deep learning model training script with EDS-NLP. We will implement a script to train a named-entity recognition (NER) model.
+If you do not care about the details and just want to train a model, we suggest you use the [training API](/tutorials/training) and move on to the next tutorial.
+
+!!! warning "Hardware requirements"
+
+ Training a modern deep learning model requires a lot of computational resources. We recommend using a machine with a GPU, ideally with at least 16GB of VRAM. If you don't have access to a GPU, you can use a cloud service like [Google Colab](https://colab.research.google.com/), [Kaggle](https://www.kaggle.com/), [Paperspace](https://www.paperspace.com/) or [Vast.ai](https://vast.ai/).
+
+Under the hood, EDS-NLP uses PyTorch to train deep-learning models. EDS-NLP acts as a sidekick to PyTorch, providing a set of tools to perform preprocessing, composition and evaluation. The trainable [`TorchComponents`][edsnlp.core.torch_component.TorchComponent] are actually PyTorch modules with a few extra methods to handle the feature preprocessing and postprocessing. Therefore, EDS-NLP is fully compatible with the PyTorch ecosystem.
+
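+Since trainable pipes are regular PyTorch modules, the usual PyTorch introspection utilities work on them directly. Here is a minimal sketch (assuming a trainable pipeline `nlp` such as the one built in the next section):
+
+```{ .python .no-check }
+import torch
+
+# Every trainable pipe of the pipeline is a torch.nn.Module,
+# so we can inspect or count its parameters with plain PyTorch.
+for name, pipe in nlp.torch_components():
+    assert isinstance(pipe, torch.nn.Module)
+    print(name, sum(p.numel() for p in pipe.parameters()))
+```
+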
## Step-by-step walkthrough
Training a supervised deep-learning model consists in feeding batches of annotated samples taken from a training corpus to a model and optimizing the parameters of the model to decrease its prediction
@@ -9,8 +17,8 @@ error. The process of training a pipeline with EDS-NLP is structured as follows:
### 1. Defining the model
-We first start by seeding the random states and instantiating a new trainable pipeline. The model described here computes text embeddings with a pre-trained transformer followed by a CNN, and performs
-the NER prediction task using a Conditional Random Field (CRF) token classifier. To compose deep-learning modules, we simply compose them using the `eds.___` factories.
+We first start by seeding the random states and instantiating a new trainable pipeline composed of [trainable pipes](/pipes/trainable). The model described here computes text embeddings with a pre-trained transformer followed by a CNN, and performs
+the NER prediction task using a Conditional Random Field (CRF) token classifier.
```python
import edsnlp, edsnlp.pipes as eds
@@ -22,7 +30,7 @@ nlp = edsnlp.blank("eds")
nlp.add_pipe(
eds.ner_crf( # (1)!
mode="joint", # (2)!
- target_span_getter="ml-ner", # (3)!
+ target_span_getter="gold-ner", # (3)!
window=20,
embedding=eds.text_cnn( # (4)!
kernel_sizes=[3],
@@ -39,18 +47,16 @@ nlp.add_pipe(
1. We use the `eds.ner_crf` NER task module, which classifies word embeddings into NER labels (BIOUL scheme) using a CRF.
2. Each component of the pipeline can be configured with a dictionary, using the parameter described in the component's page.
-3. The `target_span_getter` parameter defines the name of the span group used to train the NER model. We will need to make sure the entities from the training dataset are assigned to this span group (next section).
+3. The `target_span_getter` parameter defines the name of the span group used to train the NER model. In this case, the model will look for the entities to train on in `doc.spans["gold-ner"]`. This is important because we might store entities in other span groups with a different purpose (e.g. `doc.spans["sections"]` contains section spans, but we don't want to train on these). We will need to make sure the entities from the training dataset are assigned to this span group (next section).
4. The word embeddings used by the CRF are computed by a CNN, which builds on top of another embedding layer.
5. The base embedding layer is a pretrained transformer, which computes contextualized word embeddings.
6. We chose the `prajjwal1/bert-tiny` model in this tutorial for testing purposes, but we recommend using a larger model like `bert-base-cased` or `camembert-base` (French) for real-world applications.
-### 2. Adapting a dataset
+### 2. Loading the raw dataset and converting it into Doc objects
-To train a pipeline, we must convert our annotated data into documents that will be either used as training samples or a evaluation samples. This is done by applying to function to the raw data to convert entries
-into a list of Doc objects. We will assume the dataset has been annotated using [Brat](https://brat.nlplab.org), but any format can be used.
+To train a pipeline, we must convert our annotated data into `Doc` objects that will be either used as training samples or evaluation samples. We will assume the dataset is in [Standoff format](/data/standoff), usually produced by the [Brat](https://brat.nlplab.org) annotation tool, but any format can be used.
-At this step, we might also want to perform data augmentation, filtering, splitting or any other data transformation. Note that this function will be used to load both the training data and the test
-data. Here we will split on line jumps and filter out empty documents from the training data.
+At this step, we might also want to perform data augmentation, filtering, splitting or any other data transformation. In this tutorial, we will split on line jumps and filter out empty documents from the training data. We will use our [Stream][edsnlp.core.stream.Stream] API to handle the data processing, but you can use any method you like, so long as you end up with a collection of `Doc` objects.
```{ .python .no-check }
import edsnlp
@@ -63,100 +69,105 @@ def skip_empty_docs(batch):
training_data = (
- # Read the data from the brat directory and convert it into Docs,
- edsnlp.data.read_standoff(
+ edsnlp.data.read_standoff( # (1)!
train_data_path,
- # Store spans in default "ents", and "ml-ner" for the training (prev. section)
- span_setter=["ents", "ml-ner"],
- # Tokenize the training docs with the same tokenizer as the trained model
- tokenizer=nlp.tokenizer,
+ tokenizer=nlp.tokenizer, # (2)!
+ span_setter=["ents", "gold-ner"], # (3)!
)
- # Split the documents on line jumps
- .map(eds.split(regex="\n\n"))
- # Filter out empty documents
- .map_batches(skip_empty_docs)
- # Add any other transformation if needed
+ .map(eds.split(regex="\n\n")) # (4)!
+ .map_batches(skip_empty_docs) # (5)!
)
```
-However, we will keep all the documents in the validation data, even empty docs, to obtain representative metrics.
+1. Read the data from the brat directory and convert it into Docs.
+2. Tokenize the training docs with the same tokenizer as the trained model.
+3. Store the annotated Brat entities as spans in `doc.ents`, and `doc.spans["gold-ner"]`.
+4. Split the documents on line jumps.
+5. Filter out empty documents.
+
+As for the validation data, we will keep all the documents, even empty ones, to obtain representative metrics.
```{ .python .no-check }
val_data = edsnlp.data.read_standoff(
val_data_path,
- span_setter=["ents", "ml-ner"],
tokenizer=nlp.tokenizer,
+ span_setter=["ents", "gold-ner"],
)
-val_docs = list(val_data) # execute and convert the stream to a list
+val_docs = list(val_data) # (1)!
```
-### 4. Complete the initialization with the training data
+1. Cache the stream result into a list of `Doc`
+
+### 3. Complete the initialization of the model
-We initialize the missing or incomplete components attributes (such as label vocabularies) with the training dataset
+We initialize the missing or incomplete component attributes (such as label vocabularies) with the training dataset. Indeed, when defining the model, we specified its architecture, but we did not specify the types of named entities that the model will predict. This can be done either
+
+- explicitly by setting the `labels` parameter in `eds.ner_crf` in the [definition](#1-defining-the-model) above,
+- automatically with `post_init`: then `eds.ner_crf` looks in `doc.spans[target_span_getter]` of all docs in `training_data` to infer the labels.
```{ .python .no-check }
nlp.post_init(training_data)
```
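+If the labels are known in advance, the explicit option from the first bullet above could look like the following sketch. The `labels` value is hypothetical and must match the annotation scheme of your dataset; the rest mirrors the definition from step 1:
+
+```{ .python .no-check }
+nlp.add_pipe(
+    eds.ner_crf(
+        mode="joint",
+        target_span_getter="gold-ner",
+        window=20,
+        # Hypothetical label set: with explicit labels, post_init no longer
+        # needs to infer them from the training documents.
+        labels=["drug", "disease"],
+        embedding=eds.text_cnn(
+            kernel_sizes=[3],
+            embedding=eds.transformer(model="prajjwal1/bert-tiny"),
+        ),
+    ),
+    name="ner",
+)
+```
+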
-### 5. Preprocessing the data
+### 4. Making the stream of mini-batches
-The training dataset is then preprocessed into features. The resulting preprocessed dataset is then wrapped into a pytorch DataLoader to be fed to the model during the training loop with the model's own collate method. We will use EDS-NLP's [Streams][edsnlp.core.stream.Stream] to handle the data processing.
+The training dataset of `Doc` objects is then preprocessed into features to be fed to the model during the training loop. We will continue to use EDS-NLP's streams to handle the data processing:
-Loop on the training data (same as `loop=True` in the `read_standoff` method). Note that this will
-loop before shuffling or any further preprocessing step, meaning these operations will be applied every epoch. This is usually a good thing if preprocessing contains randomness to increase the diversity of
-the training samples while avoiding loading multiple versions of a same document in memory. To loop after preprocessing, we can collect the stream into a list and loop on the list (`edsnlp.data.from_iterable(training_data), loop=True`).
+- We first request the training data stream to loop on the input data, since we want each example to be seen multiple times during training until a given number of steps is reached
-```{ .python .no-check }
-batches = training_data.loop()
-```
+ ??? note "Looping in EDS-NLP Streams"
-Apply shuffling to our stream. If our dataset is too large to fit in memory, instead of "dataset" we can set the shuffle batch size to "100 docs" for example, or "fragment" for parquet datasets.
+ Note that in EDS-NLP, looping on a stream is always done on the input data, no matter when `loop()` is called. This means that shuffling or any further preprocessing step will be applied multiple times, each time we loop. This is usually a good thing if preprocessing contains randomness to increase the diversity of the training samples while avoiding loading multiple versions of the same document in memory. To loop after preprocessing, we can collect the stream into a list and loop on the list (`edsnlp.data.from_iterable(list(training_data), loop=True)`).
-```{ .python .no-check }
-batches = batches.shuffle("dataset")
-```
+- We shuffle the data before batching to diversify the samples in each mini-batch
+- We extract the features and labels required by each component (and sub-components) of the pipeline
+- Finally, we group the samples into mini-batches, such that each mini-batch contains at most a given number of tokens (or satisfies any other batching criterion), and assemble (or "collate") the features into tensors
```{ .python .no-check }
-# We can now preprocess the data
-batches = batches.map(
- nlp.preprocess, # (1)!
- kwargs={"supervision": True}
+from edsnlp.utils.batching import stat_batchify
+
+device = "cuda" if torch.cuda.is_available() else "cpu" # (1)!
+batches = (
+ training_data.loop()
+ .shuffle("dataset") # (2)!
+ .map(nlp.preprocess, kwargs={"supervision": True}) # (3)!
+ .batchify(batch_size=32 * 128, batch_by=stat_batchify("tokens")) # (4)!
+ .map(nlp.collate, kwargs={"device": device})
)
```
-1. This will call the `preprocess_supervised` method of the [TorchComponent][edsnlp.core.torch_component.TorchComponent] class and return a nested dictionary containing the required features and labels.
-
-Make batches of at most 8192 tokens and assemble (or "collate") the samples into a batch
-
-```{ .python .no-check }
-from edsnlp.utils.batching import stat_batchify
-batches = batches.batchify(batch_size=8192, batch_by=stat_batchify("tokens") # (1)!
-batches = batches.map(nlp.collate, kwargs={"device": device})
-```
+1. Check if a GPU is available and set the device accordingly.
+2. Apply shuffling to our stream. If our dataset is too large to fit in memory, instead of "dataset" we can set the shuffle batch size to "100 docs" for example, or "fragment" for parquet datasets.
+3. This will call the `preprocess_supervised` method of the [TorchComponent][edsnlp.core.torch_component.TorchComponent] class and return a nested dictionary containing the required features and labels.
+4. Make batches that contain at most 32 * 128 tokens (e.g. 32 samples of 128 tokens, but this accounts for the fact that samples may have different lengths). We use the `stat_batchify` function to look for a key containing `tokens` in the features' `stats` sub-dictionary and add samples to the batch until the sum of the `*tokens*` stats exceeds 32 * 128.
-1. We must make sure that a feature produced by `preprocess` contains the string "tokens".
-and that's it ! We now have a looping stream of batches that we can feed to our model.
-For better efficiency, we can also perform the preprocessing step in parallel in a separate worker by using `num_cpu_workers` option on our stream.
+and that's it! We now have a looping stream of mini-batches that we can feed to our model.
+For better efficiency, we can also perform this in parallel in a separate worker by setting `num_cpu_workers` to 1 or more.
+Note that streams in EDS-NLP are lazy, meaning that the execution has not started yet, and the data is not loaded in memory. This will only happen when we start iterating over the stream in the next section.
```{ .python .no-check }
-batches = batches.set_processing(num_cpu_workers=1, process_start_method="spawn") # (1)!
+batches = batches.set_processing(
+ num_cpu_workers=1,
+ process_start_method="spawn" # (1)!
+)
```
1. Since we use a GPU, we must use the "spawn" method to create the workers. This is because the default multiprocessing "fork" method is not compatible with CUDA.
-### 6. The training loop
+### 5. The training loop
-We instantiate an optimizer and start the training loop
+We instantiate a pytorch optimizer and start the training loop
```{ .python .no-check }
from itertools import chain, repeat
from tqdm import tqdm
+import torch
lr = 3e-4
max_steps = 400
-# Move the model to the GPU if available (device = "cuda")
+# Move the model to the GPU
nlp.to(device)
optimizer = torch.optim.AdamW(
@@ -171,18 +182,15 @@ for step in tqdm(range(max_steps), "Training model", leave=True):
optimizer.zero_grad()
```
-### 7. Optimizing the weights
+### 6. Optimizing the weights
-Inside the training loop, the trainable components are fed the collated batches from the dataloader by calling
-the [`TorchComponent.forward`][edsnlp.core.torch_component.TorchComponent.forward] method (via a simple call) to compute the losses. In the case we train a multi-task model (not in this tutorial), the
-outputs of shared embedding are reused between components, we enable caching by wrapping this step in a cache context. The training loop is otherwise carried in a similar fashion to a standard pytorch
-training loop
+Inside the training loop, the trainable components are fed the collated batches from the dataloader by calling the [`TorchComponent.forward`][edsnlp.core.torch_component.TorchComponent.forward] method (via a simple call) to compute the losses. In the case where we train a multitask model (not in this tutorial) and the outputs of a shared embedding are reused between components, we enable caching by wrapping this step in a cache context. The training loop is otherwise carried out in a similar fashion to a standard PyTorch training loop.
```{ .python .no-check }
with nlp.cache():
loss = torch.zeros((), device=device)
for name, component in nlp.torch_components():
- output = component(batch[name]) # (1)!
+ output = component(batch[name])
if "loss" in output:
loss += output["loss"]
@@ -191,9 +199,9 @@ training loop
optimizer.step()
```
-### 8. Evaluating the model
+### 7. Evaluating the model
-Finally, the model is evaluated on the validation dataset and saved at regular intervals.
+Finally, the model is evaluated on the validation dataset and saved at regular intervals. We will use the `NerExactMetric` to evaluate the NER performance using Precision, Recall and F1 scores. This metric only counts an entity as correct if it matches the label and boundaries of a target entity.
```{ .python .no-check }
from edsnlp.metrics.ner import NerExactMetric
@@ -202,22 +210,20 @@ from copy import deepcopy
metric = NerExactMetric(span_getter=nlp.pipes.ner.target_span_getter)
...
-
if ((step + 1) % 100) == 0:
with nlp.select_pipes(enable=["ner"]): # (1)!
preds = deepcopy(val_docs)
-
- # Clean the documents that our model will annotate
for doc in preds:
- doc.ents = doc.spans["ml-ner"] = []
- preds = nlp.pipe(preds) # (2)!
+ doc.ents = doc.spans["gold-ner"] = [] # (2)!
+ preds = nlp.pipe(preds) # (3)!
print(metric(val_docs, preds))
- nlp.to_disk("model") # (3)!
+ nlp.to_disk("model") #(4)!
```
1. In the case we have multiple pipes in our model, we may want to selectively evaluate each pipe, thus we use the `select_pipes` method to disable every pipe except "ner".
-2. We use the `pipe` method to run the "ner" component on the validation dataset. This method is similar to the `__call__` method of EDS-NLP components, but it is used to run a component on a list of
+2. Clean the documents that our model will annotate
+3. We use the `pipe` method to run the "ner" component on the validation dataset. This method is similar to the `__call__` method of EDS-NLP components, but it is used to run a component on a list of
Docs. This is also equivalent to
```{ .python .no-check }
preds = (
@@ -226,7 +232,7 @@ metric = NerExactMetric(span_getter=nlp.pipes.ner.target_span_getter)
.map_pipeline(nlp)
)
```
-3. We could also have saved the model with `torch.save(model, "model.pt")`, but `nlp.to_disk` avoids pickling and allows to inspect the model's files by saving them into a structured directory.
+4. We could also have saved the model with `torch.save(model, "model.pt")`, but `nlp.to_disk` avoids pickling and allows to inspect the model's files by saving them into a structured directory.
## Full example
@@ -255,13 +261,13 @@ Let's wrap the training code in a function, and make it callable from the comman
nlp: edsnlp.Pipeline,
train_data_path: str,
val_data_path: str,
- batch_size: int = 8192,
+ batch_size: int = 32 * 128,
lr: float = 1e-4,
max_steps: int = 400,
num_preprocessing_workers: int = 1,
evaluation_interval: int = 100,
):
- device = "cuda"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
# Define function to skip empty docs
def skip_empty_docs(batch: Iterator) -> Iterator:
@@ -273,7 +279,7 @@ Let's wrap the training code in a function, and make it callable from the comman
training_data = (
edsnlp.data.read_standoff(
train_data_path,
- span_setter=["ents", "ml-ner"],
+ span_setter=["ents", "gold-ner"],
tokenizer=nlp.tokenizer,
)
.map(eds.split(regex="\n\n"))
@@ -283,7 +289,7 @@ Let's wrap the training code in a function, and make it callable from the comman
# Load validation data
val_data = edsnlp.data.read_standoff(
val_data_path,
- span_setter=["ents", "ml-ner"],
+ span_setter=["ents", "gold-ner"],
tokenizer=nlp.tokenizer,
)
val_docs = list(val_data)
@@ -331,7 +337,7 @@ Let's wrap the training code in a function, and make it callable from the comman
# Clean the documents that our model will annotate
preds = deepcopy(val_docs)
for doc in preds:
- doc.ents = doc.spans["ml-ner"] = []
+ doc.ents = doc.spans["gold-ner"] = []
preds = nlp.pipe(preds)
print(metric(val_docs, preds))
@@ -343,7 +349,7 @@ Let's wrap the training code in a function, and make it callable from the comman
nlp.add_pipe(
eds.ner_crf(
mode="joint",
- target_span_getter="ml-ner",
+ target_span_getter="gold-ner",
window=20,
embedding=eds.text_cnn(
kernel_sizes=[3],
@@ -360,7 +366,7 @@ Let's wrap the training code in a function, and make it callable from the comman
nlp,
train_data_path="my_brat_data/train",
val_data_path="my_brat_data/val",
- batch_size=8192,
+ batch_size=32 * 128,
lr=1e-4,
max_steps=1000,
num_preprocessing_workers=1,
@@ -394,7 +400,7 @@ nlp:
ner:
"@factory": "eds.ner_crf"
mode: "joint"
- target_span_getter: "ml-ner"
+ target_span_getter: "gold-ner"
window: 20
embedding:
@@ -411,7 +417,7 @@ train:
nlp: ${ nlp }
train_data_path: my_brat_data/train
val_data_path: my_brat_data/val
- batch_size: 8192
+ batch_size: ${ 32 * 128 }
lr: 1e-4
max_steps: 400
num_preprocessing_workers: 1
@@ -433,4 +439,8 @@ python train.py --config config.cfg --nlp.components.ner.embedding.embedding.tra
## Going further
-This tutorial gave you a glimpse of the training API of EDS-NLP. To build a custom trainable component, you can refer to the [TorchComponent][edsnlp.core.torch_component.TorchComponent] class or look up the implementation of some of the trainable components on GitHub.
+EDS-NLP also provides a generic training script that follows the same structure as the one we just wrote. You can learn more about it in the [next Training API tutorial](/tutorials/training).
+
+This tutorial gave you a glimpse of the training API of EDS-NLP. To build a custom trainable component, you can refer to the [TorchComponent][edsnlp.core.torch_component.TorchComponent] class or look up the implementation of [some of the trainable components on GitHub](https://github.com/aphp/edsnlp/tree/master/edsnlp/pipes/trainable).
+
+We also recommend looking at an existing project as a reference, such as [eds-pseudo](https://github.com/aphp/eds-pseudo) or [mlg-norm](https://github.com/percevalw/mlg-norm).
diff --git a/docs/tutorials/matching-a-terminology.md b/docs/tutorials/matching-a-terminology.md
index fc8ee2d6c..ea7cbd74d 100644
--- a/docs/tutorials/matching-a-terminology.md
+++ b/docs/tutorials/matching-a-terminology.md
@@ -15,7 +15,7 @@ You should consider reading the [matcher's specific documentation](../pipes/core
EDS-NLP's [`RegexMatcher`][edsnlp.matchers.regex.RegexMatcher] lets the user match entire expressions using regular expressions. To achieve this, the matcher has to get to the text representation, match on it, and get back to spaCy's abstraction.
- The `EDSPhraseMatcher` lets EDS-NLP reuse spaCy's efficient algorithm, while adding the ability to skip pollution tokens (see the [normalisation documentation](../pipes/core/normalisation.md) for detail)
+ The `EDSPhraseMatcher` lets EDS-NLP reuse spaCy's efficient algorithm, while adding the ability to skip pollution tokens (see the [normalizer documentation](../pipes/core/normalizer.md) for details)
## A simple use case : finding COVID19
@@ -135,7 +135,7 @@ It handles:
You can activate it like any other component.
-```python hl_lines="4 10 17 23 24"
+```python hl_lines="4 10 17 22 23"
import edsnlp, edsnlp.pipes as eds
text = (
@@ -197,7 +197,7 @@ Of course, we _could_ write out every imaginable possibility, but this will quic
## Using regular expressions
-Let us redefine the pipeline once again, this time using regular expressions:
+Let us redefine the pipeline once again, this time using regular expressions. Using regular expressions can help define richer patterns using more compact queries.
```python
import edsnlp, edsnlp.pipes as eds
@@ -233,48 +233,4 @@ doc.ents
2. We can mix and match patterns! Here we keep looking for patients using spaCy's term matching.
3. RegExp matching is not limited to the verbatim text! You can choose to use one of spaCy's native attribute, ignore excluded tokens, etc.
-This code is complete, and should run as is.
-
-Using regular expressions can help define richer patterns using more compact queries.
-
-## Visualising matched entities
-
-EDS-NLP is part of the spaCy ecosystem, which means we can benefit from spaCy helper functions.
-For instance, spaCy's visualiser displacy can let us visualise the matched entities:
-
-```python
-# ↑ Omitted code above ↑
-
-from spacy import displacy
-
-colors = {
- "covid": "orange",
- "respiratoire": "steelblue",
-}
-options = {
- "colors": colors,
-}
-
-displacy.render(doc, style="ent", options=options)
-```
-
-If you run this within a notebook, you should get:
-
-
-
Motif de prise en charge : probable pneumopathie a
-
- COVID19
- covid
-
- , sans difficultés
-
- respiratoires
- respiratoire
-
- Le père du patient est
-
- asthmatique
- respiratoire
-
- .
-
+To visualize extracted entities, check out the [Visualization](/tutorials/visualization) tutorial.
diff --git a/docs/tutorials/multiple-texts.md b/docs/tutorials/multiple-texts.md
index 2d6a2bcc8..f22197da1 100644
--- a/docs/tutorials/multiple-texts.md
+++ b/docs/tutorials/multiple-texts.md
@@ -102,17 +102,23 @@ There are a few issues with this approach:
To efficiently perform the same operations on multiple documents at once, EDS-NLP uses [streams][edsnlp.core.stream.Stream], which record the operations to perform on the
documents without actually executing them directly, similar to the way Spark does, or polars with its LazyFrame.
-This allows EDS-NLP to distribute these operations on multiple cores or machines when it is time to execute them. We can configure how the collection operations are run (how many jobs/workers, how
-many gpus, whether to use the spark engine) via the stream [`.set_processing(...)`][edsnlp.core.stream.Stream.set_processing] method.
+This allows EDS-NLP to distribute these operations on multiple cores or machines when it is time to execute them. We can configure how the collection operations are run (how many jobs/workers, how many gpus, whether to use the spark engine) via the stream [`set_processing()`][edsnlp.core.stream.Stream.set_processing] method.
For instance,
```python
docs = edsnlp.data.from_iterable(corpus)
-print(docs)
-#
+print(docs) # (1)!
```
+1. Printed version of the stream:
+ ```
+ Stream(
+ reader=IterableReader(data=),
+ ops=[],
+ writer=None)
+ ```
+
as well as any `edsnlp.data.read_*` or `edsnlp.data.from_*` return a stream, that we can iterate over or complete with more operations. To apply the model on our collection of documents, we
can simply do:
@@ -120,8 +126,28 @@ can simply do:
docs = docs.map_pipeline(nlp)
# or à la spaCy :
# docs = nlp.pipe(docs)
+print(docs) # (1)!
```
+1. Printed version of the stream:
+ ```
+ Stream(
+ reader=IterableReader(data=),
+ ops=[
+ map(_ensure_doc[]),
+ batchify(size=None, fn=None, sentinel_mode=None),
+ map_batches_op(),
+ map_batches_op(),
+ map_batches_op(),
+ map_batches_op(),
+ map_batches_op(),
+ map_batches_op(),
+ map_batches_op(),
+ unbatchify()
+ ],
+ writer=None)
+ ```
+
??? warning "SpaCy vs EDS-NLP"
SpaCy's `nlp.pipe` method is not the same as EDS-NLP's `nlp.pipe` method, and will iterate over anything you pass to it, therefore executing the operations scheduled in our stream.
@@ -174,13 +200,13 @@ df = docs.to_pandas(converter=convert_doc_to_rows)
df = docs.to_pandas(
converter="ents",
span_getter=["ents", "dates"],
- span_attributes={
+ span_attributes=[
# span._.*** name: column name
- "negation": "negation",
- "hypothesis": "hypothesis",
- "family": "family",
- "date.datetime": "datetime",
- },
+ "negation",
+ "hypothesis",
+ "family",
+ "date.datetime",
+ ],
)
```
@@ -270,22 +296,22 @@ note_nlp = docs.to_pandas(
converter="ents",
# Below are the arguments to the converter
span_getter=["ents", "dates"],
- span_attributes={ # (1)
+ span_attributes=[ # (1)
# span._.*** name: column name
- "negation": "negation",
- "hypothesis": "hypothesis",
- "family": "family",
- "date.datetime": "datetime",
+ "negation",
+ "hypothesis",
+ "family",
+ "date.datetime",
# having individual columns for each date part
# can be useful for incomplete dates (eg, "in May")
- "date.day": "date_day",
- "date.month": "date_month",
- "date.year": "date_year",
- },
+ "date.day",
+ "date.month",
+ "date.year",
+ ],
)
```
-1. You can just pass a list if you don't want to rename the attributes.
+1. You can just pass a dict if you want to explicitly rename the attributes.
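+    For instance, `span_attributes={"date.day": "date_day"}` would store the `date.day` span attribute in a column named `date_day` rather than `date.day`.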
The result on the first note:
@@ -297,10 +323,6 @@ The result on the first note:
### Locally, using multiple parallel workers
-!!! warning "Caveat"
-
- Since workers can produce their results in any order, the order of the rows in the resulting DataFrame may not be the same as the order of the input data.
-
```{ .python hl_lines="8" }
# Read from a dataframe & use the omop converter
docs = edsnlp.data.from_pandas(data, converter="omop")
@@ -316,50 +338,65 @@ docs = docs.set_processing(backend="multiprocessing")
note_nlp = docs.to_pandas(
converter="ents",
span_getter=["ents", "dates"],
- span_attributes={
- "negation": "negation",
- "hypothesis": "hypothesis",
- "family": "family",
- "date.datetime": "datetime",
+ span_attributes=[
+ "negation",
+ "hypothesis",
+ "family",
+ "date.datetime",
# having individual columns for each date part
# can be useful for incomplete dates (eg, "in May")
- "date.day": "date_day",
- "date.month": "date_month",
- "date.year": "date_year",
- },
+ "date.day",
+ "date.month",
+ "date.year",
+ ],
)
```
+!!! note "Deterministic processing"
+
+ By default, from version 0.14.0, EDS-NLP dispatches tasks to workers in a round-robin fashion to ensure deterministic processing. This mechanism can be disabled to send documents to workers as soon as they are available, which may result in faster processing but out-of-order results.
+
+ To disable processing determinism, use `set_processing(deterministic=False)`. Note that this parameter is only used when using the `multiprocessing` backend.
+
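+For instance, a minimal sketch reusing the `docs` stream defined above:
+
+```{ .python .no-check }
+# Trade reproducible ordering for speed: workers return results as soon
+# as they are ready, so rows may come back out of order.
+docs = docs.set_processing(backend="multiprocessing", deterministic=False)
+```
+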
### In a distributed fashion with spark
To use the Spark engine to distribute the computation, we create our stream from the Spark dataframe directly and write the result to a new Spark dataframe. EDS-NLP will automatically
distribute the operations on the cluster (setting `backend="spark"` behind the scenes), but you can change the backend (for instance to `multiprocessing` to run locally).
-```{ .python hl_lines="2 9" .no-check }
+!!! warning "Spark backend"
+
+ When processing from AND to a Spark DataFrame, the backend is automatically set to "spark".
+
+ We do NOT recommend using another backend when Spark dataframes are involved, as there may be a discrepancy between the time it takes to process the data locally and the timeout of the Spark job.
+
+```{ .python hl_lines="2 12" .no-check }
# Read from the pyspark dataframe & use the omop converter
docs = edsnlp.data.from_spark(df, converter="omop")
# Add the pipeline to operations that will be run
-docs = docs.map_pipeline(nlp
+docs = docs.map_pipeline(nlp)
+
+# Backend is set by default to "spark"
+# docs = docs.set_processing(backend="spark")
# Convert each doc to a list of dicts (one by entity)
# and store the result in a pyspark DataFrame
note_nlp = docs.to_spark(
converter="ents",
span_getter=["ents", "dates"],
- span_attributes={
- "negation": "negation",
- "hypothesis": "hypothesis",
- "family": "family",
- "date.datetime": "datetime",
+ span_attributes=[
+ "negation",
+ "hypothesis",
+ "family",
+ "date.datetime",
# having individual columns for each date part
# can be useful for incomplete dates (eg, "in May")
- "date.day": "date_day",
- "date.month": "date_month",
- "date.year": "date_year",
- },
+ "date.day",
+ "date.month",
+ "date.year",
+ ],
dtypes=None, # (1)
)
```
diff --git a/docs/tutorials/training.md b/docs/tutorials/training.md
index c2d7daada..c069b366b 100644
--- a/docs/tutorials/training.md
+++ b/docs/tutorials/training.md
@@ -1,13 +1,12 @@
-# Training a Named Entity Recognition model
+# Training API
-In this tutorial, we'll see how we can train a deep learning model with EDS-NLP.
-We also recommend looking at an existing project as a reference, such as [eds-pseudo](https://github.com/aphp/eds-pseudo) or [mlg-norm](https://github.com/percevalw/mlg-norm).
+In this tutorial, we'll see how we can quickly train a deep learning model with EDS-NLP using the `edsnlp.train` function.
!!! warning "Hardware requirements"
Training a modern deep learning model requires a lot of computational resources. We recommend using a machine with a GPU, ideally with at least 16GB of VRAM. If you don't have access to a GPU, you can use a cloud service like [Google Colab](https://colab.research.google.com/), [Kaggle](https://www.kaggle.com/), [Paperspace](https://www.paperspace.com/) or [Vast.ai](https://vast.ai/).
-If you need a high level of control over the training procedure, we suggest you read the next ["Custom training script"](../make-a-training-script) tutorial.
+If you need a high level of control over the training procedure, we suggest you read the previous ["Deep learning tutorial"](./make-a-training-script.md) to understand how to build a training loop from scratch with EDS-NLP.
## Creating a project
@@ -38,7 +37,7 @@ readme = "README.md"
requires-python = ">3.7.1,<4.0"
dependencies = [
- "edsnlp[ml]>=0.13.0",
+ "edsnlp[ml]>=0.14.0",
"sentencepiece>=0.1.96"
]
@@ -117,24 +116,29 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
scorer:
ner:
'@metrics': eds.ner_exact
- span_getter: ${nlp.components.ner.target_span_getter}
+ span_getter: ${ nlp.components.ner.target_span_getter }
# 🎛️ OPTIMIZER
optim:
- '@optimizers': adam
+ cls: adamw
groups:
- "*.transformer.*":
+ # Assign parameters starting with transformer (i.e. the parameters of the transformer component)
+ # to a first group
+ "^transformer":
lr: 5e-5
schedules:
'@schedules': linear
"warmup_rate": 0.1
"start_value": 0
- "*":
+ # And every other parameter to the second group
+ "":
lr: 3e-4
schedules:
'@schedules': linear
"warmup_rate": 0.1
"start_value": 3e-4
+ module: ${ nlp }
+ total_steps: ${ train.max_steps }
# 📚 DATA
train_data:
@@ -142,7 +146,7 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
# In what kind of files (ie. their extensions) is our
# training data stored
'@readers': standoff
- path: ${vars.train}
+ path: ${ vars.train }
converter:
# What schema is used in the data files
- '@factory': eds.standoff_dict2doc
@@ -152,12 +156,13 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
nlp: null
max_length: 2000
regex: '\n\n+'
- batch_size: 2000 words
+ shuffle: dataset
+ batch_size: 4096 tokens # 32 * 128 tokens
pipe_names: [ "ner" ]
val_data:
'@readers': standoff
- path: ${vars.dev}
+ path: ${ vars.dev }
# What schema is used in the data files
converter:
- '@factory': eds.standoff_dict2doc
@@ -183,7 +188,7 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
# 📦 PACKAGE SCRIPT OPTIONS
# -> python -m edsnlp.package --config configs/config.yml
package:
- pipeline: ${train.output_dir}
+ pipeline: ${ train.output_dir }
name: 'my_ner_model'
```
@@ -201,7 +206,7 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
```{ .python .no-check }
import edsnlp
- from edsnlp.training import train, create_optimizer, TrainingData
+ from edsnlp.training import train, ScheduledOptimizer, TrainingData
from edsnlp.metrics.ner import NerExactMetric
import edsnlp.pipes as eds
import torch
@@ -244,30 +249,31 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
)
# 🎛️ OPTIMIZER
- # partial optimizer creation, but we could pass nlp=... and total_steps=...
- # to instantiate the optimizer directly
- optim = create_optimizer(
+ max_steps = 2000
+ optim = ScheduledOptimizer(
optim=torch.optim.Adam,
+ module=nlp,
+ total_steps=max_steps,
groups={
- "*.transformer.*": {
+ "^transformer": {
"lr": 5e-5,
"schedules": {"@schedules": "linear", "warmup_rate": 0.1, "start_value": 0},
},
- "*": {
+ "": {
"lr": 3e-4,
"schedules": {"@schedules": "linear", "warmup_rate": 0.1, "start_value": 3e-4},
},
- }
+ },
)
# 🚀 TRAIN
train(
nlp=nlp,
- max_steps=2000,
- validation_interval=200,
+ max_steps=max_steps,
+ validation_interval=max_steps // 10,
train_data=TrainingData(
data=train_data,
- batch_size="2000 words",
+ batch_size="4096 tokens", # 32 * 128 tokens
pipe_names=["ner"],
shuffle="dataset",
),
diff --git a/docs/tutorials/visualization.md b/docs/tutorials/visualization.md
index 4da82951a..0c681ea71 100644
--- a/docs/tutorials/visualization.md
+++ b/docs/tutorials/visualization.md
@@ -24,7 +24,15 @@ To print a text and highlight the entities in it, you can use `spacy.displacy`.
from spacy import displacy
doc = nlp(txt)
-displacy.render(doc, style="ent")
+options = {
+ # Optional colors
+
+ # "colors": {
+ # "covid": "orange",
+ # "respiratoire": "steelblue",
+ # },
+}
+displacy.render(doc, style="ent", options=options)
```
will render like this:
diff --git a/edsnlp/__init__.py b/edsnlp/__init__.py
index b504ba048..152cfa253 100644
--- a/edsnlp/__init__.py
+++ b/edsnlp/__init__.py
@@ -37,7 +37,7 @@ def find_spec(self, fullname, path, target=None): # pragma: no cover
spec = importlib.util.spec_from_loader(fullname, AliasLoader(new_name))
return spec
if fullname.startswith("edsnlp.optimization"):
- new_name = "edsnlp.training.optimization" + fullname[19:]
+ new_name = "edsnlp.training.optimizer" + fullname[19:]
spec = importlib.util.spec_from_loader(fullname, AliasLoader(new_name))
return spec
if fullname.startswith("edsnlp.scorers"):
diff --git a/edsnlp/core/stream.py b/edsnlp/core/stream.py
index 0b4a57793..cfca4dea4 100644
--- a/edsnlp/core/stream.py
+++ b/edsnlp/core/stream.py
@@ -1041,10 +1041,12 @@ def validate_ops(self, ops, update: bool = False):
def __repr__(self):
ops_str = ",\n".join(textwrap.indent(repr(op), " ") for op in self.ops)
+ if ops_str:
+ ops_str = "\n" + ops_str + "\n "
return (
f"Stream(\n"
f" reader={self.reader},\n"
- f" ops=[\n{ops_str}\n ],\n"
+ f" ops=[{ops_str}],\n"
f" writer={self.writer})\n"
)
diff --git a/edsnlp/data/converters.py b/edsnlp/data/converters.py
index 2551e8a8b..c8c262354 100644
--- a/edsnlp/data/converters.py
+++ b/edsnlp/data/converters.py
@@ -107,12 +107,10 @@ def validate_kwargs(func, kwargs):
class AttributesMappingArg(Validated):
"""
- A mapping from JSON attributes to Span extensions (can be a list too).
+ A span attribute mapping (can also be a list to keep the same names).
For instance:
- - `doc_attributes={"datetime": "note_datetime"}` will map the `datetime` JSON
- attribute to the `note_datetime` extension.
- `doc_attributes="note_datetime"` will map the `note_datetime` JSON attribute to
the `note_datetime` extension.
- `span_attributes=["negation", "family"]` will map the `negation` and `family` JSON
@@ -328,7 +326,7 @@ class StandoffDoc2DictConverter:
# Optional parameters
span_getter={"ents": True},
- span_attributes={"negation": "negated"},
+ span_attributes=["negation"],
)
# or docs.to_standoff(...) if it's already a
# [Stream][edsnlp.core.stream.Stream]
diff --git a/edsnlp/data/standoff.py b/edsnlp/data/standoff.py
index 228dcd1e9..b2ec6bca0 100644
--- a/edsnlp/data/standoff.py
+++ b/edsnlp/data/standoff.py
@@ -446,9 +446,8 @@ def read_standoff(
```{ .python .no-check }
doc_iterator = edsnlp.data.read_standoff(
"path/to/brat/directory",
- # Mapping from 'BRAT attribute name' to 'Doc attribute name'
- span_attributes={"Negation": "negated"},
- bool_attributes=["negated"], # Missing values will be set to False
+ span_attributes=["negation", "family"],
+ bool_attributes=["negation"], # Missing values will be set to False
)
```
diff --git a/edsnlp/pipes/misc/dates/dates.py b/edsnlp/pipes/misc/dates/dates.py
index 345e25c48..3b00719a8 100644
--- a/edsnlp/pipes/misc/dates/dates.py
+++ b/edsnlp/pipes/misc/dates/dates.py
@@ -104,7 +104,7 @@ class DatesMatcher(BaseNERComponent):
docs = docs.map_pipeline(nlp)
docs = docs.to_pandas(
converter="ents",
- span_attributes={"date.datetime": "datetime"},
+ span_attributes=["date.datetime"],
)
print(docs)
# note_id start end label lexical_variant span_type datetime
diff --git a/edsnlp/pipes/misc/quantities/quantities.py b/edsnlp/pipes/misc/quantities/quantities.py
index bc57d4d0c..ee9c1b66f 100644
--- a/edsnlp/pipes/misc/quantities/quantities.py
+++ b/edsnlp/pipes/misc/quantities/quantities.py
@@ -449,7 +449,7 @@ class QuantitiesMatcher(BaseNERComponent):
docs = docs.map_pipeline(nlp)
docs.to_pandas(
converter="ents",
- span_attributes={"value.unit": "original_unit", "value.kg": "kg"},
+ span_attributes=["value.unit", "value.kg"],
)
# note_id start end label lexical_variant span_type original_unit kg
# 0 None 18 27 weight 40000,0 g ents g 40.0
diff --git a/edsnlp/pipes/misc/split/split.py b/edsnlp/pipes/misc/split/split.py
index aaeaf2b4b..beab948f7 100644
--- a/edsnlp/pipes/misc/split/split.py
+++ b/edsnlp/pipes/misc/split/split.py
@@ -70,8 +70,6 @@ def subset_doc(doc: Doc, start: int, end: int) -> Doc:
shifter = make_shifter(start, end, new_doc)
- print(new_doc.user_data)
-
for key, val in list(new_doc.user_data.items()):
new_doc.user_data[key] = shifter(val)
diff --git a/edsnlp/training/__init__.py b/edsnlp/training/__init__.py
index b7713b0be..e87b73df9 100644
--- a/edsnlp/training/__init__.py
+++ b/edsnlp/training/__init__.py
@@ -1,2 +1,2 @@
from .trainer import train, GenericScorer, TrainingData
-from .optimization import ScheduledOptimizer, LinearSchedule, create_optimizer
+from .optimizer import ScheduledOptimizer, LinearSchedule
diff --git a/edsnlp/training/optimization.py b/edsnlp/training/optimization.py
deleted file mode 100644
index 827508c31..000000000
--- a/edsnlp/training/optimization.py
+++ /dev/null
@@ -1,217 +0,0 @@
-import fnmatch
-from collections import defaultdict
-from typing import Any, Dict, Optional
-
-import torch
-from confit import Config
-
-import edsnlp
-from edsnlp.utils.collections import get_deep_attr, set_deep_attr
-
-
-class ScheduledOptimizer(torch.optim.Optimizer):
- def __init__(self, optim, init_schedules: bool = True):
- self.optim = optim
- schedule_to_groups = defaultdict(lambda: [])
- for group in self.optim.param_groups:
- if "schedules" in group:
- group["schedules"] = (
- group["schedules"]
- if isinstance(group["schedules"], list)
- else [group["schedules"]]
- )
- group["schedules"] = list(group["schedules"])
- for schedule in group["schedules"]:
- schedule_to_groups[schedule].append(group)
- if init_schedules:
- schedule.step(group)
-
- def zero_grad(self):
- return self.optim.zero_grad()
-
- @property
- def param_groups(self):
- return self.optim.param_groups
-
- @param_groups.setter
- def param_groups(self, value):
- self.optim.param_groups = value
-
- @property
- def state(self):
- return self.optim.state
-
- @state.setter
- def state(self, value):
- self.optim.state = value
-
- def state_dict(self):
- state = {
- "optim": self.optim.state_dict(),
- "lr": [group.get("lr") for group in self.optim.param_groups],
- "schedules": [
- [schedule.state_dict() for schedule in group.get("schedules", ())]
- for group in self.optim.param_groups
- ],
- }
- for group in state["optim"]["param_groups"]:
- if "schedules" in group:
- del group["schedules"]
- return state
-
- def load_state_dict(self, state):
- optim_schedules = [
- group.get("schedules", ()) for group in self.optim.param_groups
- ]
- self.optim.load_state_dict(state["optim"])
- for group, group_schedule, group_schedules_state, lr in zip(
- self.optim.param_groups, optim_schedules, state["schedules"], state["lr"]
- ):
- group["schedules"] = group_schedule
- for schedule, schedule_state in zip(
- group["schedules"], group_schedules_state
- ):
- schedule.load_state_dict(schedule_state)
- group["lr"] = lr
-
- def step(self, closure=None):
- self.optim.step(closure=closure)
- for group in self.optim.param_groups:
- if "schedules" in group:
- for schedule in group["schedules"]:
- schedule.step(group)
-
- def initialize(self):
- self.reset()
- self.optim.step()
-
- def reset(self):
- self.optim.zero_grad()
- for group in self.optim.param_groups:
- for param in group["params"]:
- if param.requires_grad:
- param.grad = torch.zeros_like(param)
- for group in self.optim.param_groups:
- if "schedules" in group:
- for schedule in group["schedules"]:
- schedule.reset(group)
-
-
-class OptimizerGroupsProxy:
- def __init__(self, groups):
- self.param_groups = groups
-
-
-@edsnlp.registry.schedules.register("linear")
-class LinearSchedule:
- def __init__(
- self,
- total_steps: Optional[int] = None,
- max_value: Optional[Any] = None,
- start_value: float = 0.0,
- path: str = "lr",
- warmup_rate: float = 0.0,
- ):
- self.path = path
- self.start_value = start_value
- self.max_value = max_value
- self.warmup_rate = warmup_rate
- self.total_steps = total_steps
- self.idx = 0
-
- def reset(self, group):
- self.idx = -1
- self.step(group)
-
- def state_dict(self):
- return {
- "idx": self.idx,
- }
-
- def load_state_dict(self, state):
- self.idx = state["idx"]
-
- def step(self, group, closure=None):
- self.idx += 1
- if self.max_value is None:
- self.max_value = get_deep_attr(group, self.path)
- warmup_steps = self.total_steps * self.warmup_rate
- if self.idx < warmup_steps:
- progress = self.idx / warmup_steps
- value = self.start_value + (self.max_value - self.start_value) * progress
- else:
- progress = (self.idx - warmup_steps) / (self.total_steps - warmup_steps)
- value = self.max_value + (0 - self.max_value) * progress
- set_deep_attr(group, self.path, value)
-
- def __repr__(self):
- return (
- f"LinearSchedule(total_steps={self.total_steps}, "
- f"max_value={self.max_value}, "
- f"start_value={self.start_value}, "
- f"path={self.path}, "
- f"warmup_rate={self.warmup_rate})"
- )
-
-
-def create_optimizer(optim, **kwargs):
- def instantiate(nlp, total_steps=None):
- groups = list(nlp.parameters())
- named_parameters = list(nlp.named_parameters())
- if "groups" in kwargs:
- optim_groups = Config.resolve(
- kwargs.pop("groups"), registry=edsnlp.registry
- )
- optim_groups = {
- sel: dict(group) if group else False
- for sel, group in optim_groups.items()
- }
- all_matched_params = set()
- for sel, group in optim_groups.items():
- params = []
- for name, param in named_parameters:
- if param not in all_matched_params and fnmatch.fnmatch(name, sel):
- params.append(param)
- if group:
- group: Dict
- group["selector"] = sel
- group["params"] = params
- all_matched_params |= set(params)
- groups = [group for group in optim_groups.values() if group]
- instance = ScheduledOptimizer(
- optim(groups, **kwargs),
- init_schedules=False,
- )
- for group in instance.param_groups:
- if "schedules" in group:
- for schedule in group["schedules"]:
- if schedule.total_steps is None:
- assert total_steps is not None, (
- "total_steps must be provided to the optimizer "
- "or the schedule"
- )
- schedule.total_steps = total_steps
- schedule.step(group)
- return instance
-
- return (
- instantiate
- if "nlp" not in kwargs
- else instantiate(kwargs.pop("nlp"), kwargs.pop("total_steps", None))
- )
-
-
-for optim_name, optim in vars(torch.optim).items():
- if (
- isinstance(optim, type)
- and issubclass(optim, torch.optim.Optimizer)
- and optim is not torch.optim.Optimizer
- ):
-
- def wrapper(optim):
- def create_specific_optimizer(**kwargs):
- return create_optimizer(optim=optim, **kwargs)
-
- return create_specific_optimizer
-
- edsnlp.registry.optimizers.register(optim_name.lower(), func=wrapper(optim))
diff --git a/edsnlp/training/optimizer.py b/edsnlp/training/optimizer.py
new file mode 100644
index 000000000..900e49c66
--- /dev/null
+++ b/edsnlp/training/optimizer.py
@@ -0,0 +1,345 @@
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
+
+import pydantic
+import regex
+import torch
+import torch.optim
+from confit import Config, validate_arguments
+from typing_extensions import Literal
+
+import edsnlp
+from edsnlp.core import PipelineProtocol
+from edsnlp.utils.collections import get_deep_attr, set_deep_attr
+from edsnlp.utils.typing import AsList
+
+optim_mapping = {
+ "adam": torch.optim.Adam,
+ "adamw": torch.optim.AdamW,
+ "sgd": torch.optim.SGD,
+ "adadelta": torch.optim.Adadelta,
+ "adagrad": torch.optim.Adagrad,
+ "adamax": torch.optim.Adamax,
+ "rmsprop": torch.optim.RMSprop,
+}
+
+
+@validate_arguments
+class Schedule:
+ def step(self, group, closure=None):
+ raise NotImplementedError
+
+ def reset(self, group):
+ raise NotImplementedError()
+
+ def state_dict(self):
+ raise NotImplementedError()
+
+ def load_state_dict(self, state):
+ raise NotImplementedError()
+
+
+@edsnlp.registry.schedules.register("linear")
+class LinearSchedule(Schedule):
+ def __init__(
+ self,
+ total_steps: Optional[int] = None,
+ max_value: Optional[Any] = None,
+ start_value: float = 0.0,
+ path: str = "lr",
+ warmup_rate: float = 0.0,
+ ):
+ """
+        Linear schedule for a parameter group. The schedule linearly increases the
+        value from `start_value` to `max_value` during the first `warmup_rate` fraction
+        of `total_steps`, then linearly decreases it back to `0`.
+
+ Parameters
+ ----------
+ total_steps: Optional[int]
+            The total number of steps. If None, it must be set later (e.g. by the optimizer).
+ max_value: Optional[Any]
+ The maximum value to reach.
+ start_value: float
+ The initial value.
+ path: str
+ The path to the attribute to set.
+ warmup_rate: float
+            The fraction of `total_steps` spent in the warmup phase.
+ """
+ self.path = path
+ self.start_value = start_value
+ self.max_value = max_value
+ self.warmup_rate = warmup_rate
+ self.total_steps = total_steps
+ self.idx = 0
+
+ def reset(self, group):
+ self.idx = -1
+ self.step(group)
+
+ def state_dict(self):
+ return {
+ "idx": self.idx,
+ }
+
+ def load_state_dict(self, state):
+ self.idx = state["idx"]
+
+ def step(self, group, closure=None):
+ self.idx += 1
+ if self.max_value is None:
+ self.max_value = get_deep_attr(group, self.path)
+ warmup_steps = self.total_steps * self.warmup_rate
+ if self.idx < warmup_steps:
+ progress = self.idx / warmup_steps
+ value = self.start_value + (self.max_value - self.start_value) * progress
+ else:
+ progress = (self.idx - warmup_steps) / (self.total_steps - warmup_steps)
+ value = self.max_value + (0 - self.max_value) * progress
+ set_deep_attr(group, self.path, value)
+
+ def __repr__(self):
+ format_string = type(self).__name__ + "(\n"
+ format_string += f" start_value: {self.start_value}\n"
+ format_string += f" max_value: {self.max_value}\n"
+ format_string += f" warmup_rate: {self.warmup_rate}\n"
+ format_string += f" path: {self.path}\n"
+ format_string += f" total_steps: {self.total_steps}\n"
+ format_string += ")"
+ return format_string
+
+
+class Group(pydantic.BaseModel, extra=pydantic.Extra.allow):
+ """
+ Parameter group for the optimizer.
+
+ Parameters
+ ----------
+    schedules : Optional[AsList[Schedule]]
+        The schedules to apply to the group's parameters.
+    lr : Optional[float]
+ The learning rate for the group.
+ **kwargs
+ Additional parameters to pass to the group.
+ """
+
+ schedules: Optional[AsList[Schedule]] = None
+ lr: Optional[float] = None
+
+
+if TYPE_CHECKING:
+ Group = Dict
+
+
+@edsnlp.registry.misc.register("eds.scheduled_optimizer")
+class ScheduledOptimizer(torch.optim.Optimizer):
+ def __init__(
+ self,
+ cls: Union[torch.optim.Optimizer, Type[torch.optim.Optimizer], str],
+ module: Optional[Union[PipelineProtocol, torch.nn.Module]] = None,
+ total_steps: Optional[int] = None,
+ groups: Optional[Dict[str, Union[Group, Literal[False]]]] = None,
+ init_schedules: bool = True,
+ **kwargs,
+ ):
+ """
+        Wrapper optimizer that supports schedules for the parameters, and easy
+        parameter selection using the keys of the `groups` dictionary as regex
+        patterns matched against the parameter names.
+
+ Parameters
+ ----------
+ cls : Union[str, Type[torch.optim.Optimizer], torch.optim.Optimizer]
+ The optimizer to use. If a string (like "adamw") or a type to instantiate,
+            the `module` and `groups` arguments must be provided.
+ module : Optional[Union[PipelineProtocol, torch.nn.Module]]
+ The module to optimize. Usually the `nlp` pipeline object.
+ total_steps : Optional[int]
+ The total number of steps, used for schedules.
+        groups : Optional[Dict[str, Group]]
+            The groups to optimize. Each key is a regex selector matched against the
+            names in `module.named_parameters()`; each value is a dictionary of
+            hyper-parameters (e.g. `lr`, `weight_decay`, `schedules`) or `False`.
+
+            The matching is performed by running `regex.search(selector, name)`, so you
+            do not have to match the full name. Note that the order of the dict keys
+            matters: a parameter is assigned to the first group whose selector matches
+            it, so you can exclude parameters by putting a selector early in the dict
+            and using `False` as its value.
+ """
+ should_instantiate_optim = isinstance(cls, str) or isinstance(cls, type)
+ if should_instantiate_optim and (groups is None or module is None):
+ raise ValueError(
+ "If the optimizer is a string or a type, the module and groups must "
+ "be provided."
+ )
+ elif not should_instantiate_optim and (
+ groups is not None or module is not None
+ ):
+ raise ValueError(
+ "If the optimizer is already instantiated, the module and groups must "
+ "not be provided."
+ )
+
+ if should_instantiate_optim:
+ named_parameters = list(module.named_parameters())
+ groups = Config.resolve(groups, registry=edsnlp.registry)
+ groups = {
+ sel: dict(group) if group else False for sel, group in groups.items()
+ }
+ all_matched_params = set()
+ for sel, group in groups.items():
+ params = []
+ for name, param in named_parameters:
+ if param not in all_matched_params and regex.search(sel, name):
+ params.append(param)
+ if group:
+ tmp_group = dict(group)
+ group.clear()
+ group: Dict
+ group["selector"] = sel
+ group["params"] = params
+ group.update(tmp_group)
+ all_matched_params |= set(params)
+ groups = [
+ {k: v for k, v in group.items() if v is not None}
+ for group in groups.values()
+ if group
+ ]
+
+ if isinstance(cls, str):
+ cls = (
+ optim_mapping[cls.lower()]
+ if cls.lower() in optim_mapping
+ else getattr(torch.optim, cls)
+ )
+ cls = cls(groups, **kwargs)
+
+ self.optim = cls
+ schedule_to_groups = defaultdict(lambda: [])
+ for group in self.optim.param_groups:
+ if "schedules" in group:
+ group["schedules"] = (
+ group["schedules"]
+ if isinstance(group["schedules"], list)
+ else [group["schedules"]]
+ )
+ group["schedules"] = list(group["schedules"])
+ for schedule in group["schedules"]:
+ schedule_to_groups[schedule].append(group)
+ if schedule.total_steps is None:
+ assert total_steps is not None, (
+ "total_steps must be provided to the optimizer "
+ "or the schedule"
+ )
+ schedule.total_steps = total_steps
+ if init_schedules:
+ schedule.step(group)
+
+ def zero_grad(self):
+ return self.optim.zero_grad()
+
+ @property
+ def param_groups(self):
+ return self.optim.param_groups
+
+ @param_groups.setter
+ def param_groups(self, value):
+ self.optim.param_groups = value
+
+ @property
+ def state(self):
+ return self.optim.state
+
+ @state.setter
+ def state(self, value):
+ self.optim.state = value
+
+ def state_dict(self):
+ state = {
+ "optim": self.optim.state_dict(),
+ "lr": [group.get("lr") for group in self.optim.param_groups],
+ "schedules": [
+ [schedule.state_dict() for schedule in group.get("schedules", ())]
+ for group in self.optim.param_groups
+ ],
+ }
+ for group in state["optim"]["param_groups"]:
+ if "schedules" in group:
+ del group["schedules"]
+ return state
+
+ def load_state_dict(self, state):
+ optim_schedules = [
+ group.get("schedules", ()) for group in self.optim.param_groups
+ ]
+ self.optim.load_state_dict(state["optim"])
+ for group, group_schedule, group_schedules_state, lr in zip(
+ self.optim.param_groups, optim_schedules, state["schedules"], state["lr"]
+ ):
+ group["schedules"] = group_schedule
+ for schedule, schedule_state in zip(
+ group["schedules"], group_schedules_state
+ ):
+ schedule.load_state_dict(schedule_state)
+ group["lr"] = lr
+
+ def step(self, closure=None):
+ self.optim.step(closure=closure)
+ for group in self.optim.param_groups:
+ if "schedules" in group:
+ for schedule in group["schedules"]:
+ schedule.step(group)
+
+ def initialize(self):
+ self.reset()
+ self.optim.step()
+
+ def reset(self):
+ self.optim.zero_grad()
+ for group in self.optim.param_groups:
+ for param in group["params"]:
+ if param.requires_grad:
+ param.grad = torch.zeros_like(param)
+ for group in self.optim.param_groups:
+ if "schedules" in group:
+ for schedule in group["schedules"]:
+ schedule.reset(group)
+
+ def __repr__(self):
+ format_string = type(self).__name__ + f"[{type(self.optim).__qualname__}] ("
+ ind = " "
+ for i, group in enumerate(self.param_groups):
+ format_string += "\n"
+ format_string += f"Parameter Group {i}\n"
+ keys = [
+ "selector",
+ "params",
+ "lr",
+ *sorted(set(group.keys()) - {"selector", "params", "lr", "schedules"}),
+ "schedules",
+ ]
+ for key in keys:
+ if key in group:
+ format_string += ind + f"{key}: "
+ if key == "params":
+ num_tensors = len(group["params"])
+ num_params = sum(p.numel() for p in group["params"])
+ format_string += (
+ f"{num_params} parameters ({num_tensors} tensors)\n"
+ )
+ elif isinstance(group[key], list):
+ format_string += "[\n"
+ for i, item in enumerate(group[key]):
+ sub_str = str(item)
+ for line in sub_str.split("\n"):
+ format_string += ind * 2 + line + "\n"
+ format_string += ind + "]\n"
+ else:
+ sub_str = str(group[key])
+ for i, line in enumerate(sub_str.split("\n")):
+ format_string += (ind if i > 0 else "") + line + "\n"
+
+ format_string += ")"
+ return format_string
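To tie the new `optimizer.py` file together, here is a minimal, hedged sketch of the intended usage; the module, the selectors and the hyper-parameter values are illustrative, only the `ScheduledOptimizer` / `LinearSchedule` API comes from the file above.

```python
import torch

from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer

# Illustrative module: named_parameters() yields "0.weight", "0.bias", "1.weight", "1.bias".
net = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.Linear(5, 1))

optim = ScheduledOptimizer(
    cls="adamw",         # looked up in `optim_mapping`, i.e. torch.optim.AdamW
    module=net,
    total_steps=10,      # propagated to schedules whose `total_steps` is None
    groups={
        # Selectors are tried in order with `regex.search`; a parameter goes to
        # the first group that matches it.
        "0[.]": {
            "lr": 1e-3,
            "schedules": LinearSchedule(start_value=0.0, warmup_rate=0.2),
        },
        # Catch-all group for the remaining parameters, without a schedule.
        "": {"lr": 1e-4},
    },
)

optim.initialize()       # zero the grads, reset the schedules, populate the state
for _ in range(10):
    # ... forward / backward pass would go here ...
    optim.step()         # steps AdamW, then each schedule updates its group's "lr"
```

With `warmup_rate=0.2`, the first group's learning rate ramps from `0.0` to `1e-3` over the first two steps, then decays linearly back to `0` over the remaining eight, the same shape as the values asserted in `tests/training/test_optimizer.py` below.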
diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py
index f19ac4265..d1f8c72b3 100644
--- a/edsnlp/training/trainer.py
+++ b/edsnlp/training/trainer.py
@@ -9,7 +9,6 @@
from typing import (
TYPE_CHECKING,
Any,
- Callable,
Collection,
Dict,
Iterable,
@@ -22,7 +21,6 @@
from confit import validate_arguments
from confit.utils.random import set_seed
from rich_logger import RichTablePrinter
-from torch.optim import Optimizer
from tqdm import tqdm, trange
from typing_extensions import Literal
@@ -37,10 +35,9 @@
from edsnlp.utils.span_getters import get_spans
from edsnlp.utils.typing import AsList
-from .optimization import ( # noqa: F401
+from .optimizer import ( # noqa: F401
LinearSchedule,
ScheduledOptimizer,
- create_optimizer,
)
LOGGER_FIELDS = {
@@ -298,7 +295,7 @@ def train(
val_data: AsList[Stream],
seed: int = 42,
max_steps: int = 1000,
- optim: Union[Optimizer, Optional[Callable[[Any, int], Optimizer]]] = None,
+    optim: Optional[Union[ScheduledOptimizer, torch.optim.Optimizer]] = None,
validation_interval: int = 10,
max_grad_norm: float = 5.0,
loss_scales: Dict[str, float] = {},
@@ -336,8 +333,17 @@ def train(
The random seed
max_steps: int
The maximum number of training steps
- optim: Union[Optimizer, Optional[Callable[[Any, int], Optimizer]]]
+    optim: Optional[Union[ScheduledOptimizer, torch.optim.Optimizer]]
The optimizer. If None, a default optimizer will be used.
+
+ ??? note "`ScheduledOptimizer` object/dictionary"
+ ::: edsnlp.training.optimizer.ScheduledOptimizer
+ options:
+ heading_level: 1
+ only_parameters: "no-header"
+ skip_parameters: []
+ show_source: false
+ show_toc: false
validation_interval: int
The number of steps between each evaluation
max_grad_norm: float
@@ -418,14 +424,8 @@ def train(
if k in ("task_lr", "transformer_lr", "warmup_rate")
},
)
- else:
- optim = (
- optim_base
- if isinstance(optim_base, Optimizer)
- else optim_base(nlp, max_steps)
- )
- if hasattr(optim, "reset"):
- optim.reset()
+ if hasattr(optim, "initialize"):
+ optim.initialize()
grad_params = {p for group in optim.param_groups for p in group["params"]}
print(
"Optimizing groups:"
diff --git a/edsnlp/utils/batching.py b/edsnlp/utils/batching.py
index 301e7ce63..8a70941c3 100644
--- a/edsnlp/utils/batching.py
+++ b/edsnlp/utils/batching.py
@@ -370,6 +370,46 @@ def batchify_by_fragment(
def stat_batchify(key):
+ """
+ Create a batching function that uses the value of a specific key in the items to
+ determine the batch size. This function is primarily meant to be used on the
+ flattened outputs of the `preprocess` method of a
+ [Pipeline][edsnlp.core.pipeline.Pipeline] object.
+
+    It expects each item to be a dictionary in which some keys contain the string
+    "/stats/" as well as the `key` pattern. For instance:
+
+ ```python
+ from edsnlp.utils.batching import stat_batchify
+
+ items = [
+ {"text": "first sample", "obj/stats/words": 2, "obj/stats/chars": 12},
+ {"text": "dos", "obj/stats/words": 1, "obj/stats/chars": 3},
+ {"text": "third one !", "obj/stats/words": 3, "obj/stats/chars": 11},
+ ]
+ batcher = stat_batchify("words")
+ assert list(batcher(items, 4)) == [
+ [
+ {"text": "first sample", "obj/stats/words": 2, "obj/stats/chars": 12},
+ {"text": "dos", "obj/stats/words": 1, "obj/stats/chars": 3},
+ ],
+ [
+ {"text": "third one !", "obj/stats/words": 3, "obj/stats/chars": 11},
+ ],
+ ]
+ ```
+
+ Parameters
+ ----------
+ key: str
+ The key pattern to use to determine the actual key to look up in the items.
+
+ Returns
+ -------
+    Callable[[Iterable, int, bool, Literal["drop", "split"]], Iterable]
+ """
+
def rec(
iterable,
batch_size,
diff --git a/mkdocs.yml b/mkdocs.yml
index 8e6593e47..023a92c4f 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -45,15 +45,15 @@ nav:
- tutorials/spacy101.md
- tutorials/matching-a-terminology.md
- tutorials/qualifying-entities.md
+ - tutorials/visualization.md
- tutorials/detecting-dates.md
- - tutorials/multiple-texts.md
- tutorials/reason.md
- tutorials/endlines.md
- tutorials/aggregating-results.md
+ - tutorials/multiple-texts.md
- advanced-tutorials/fastapi.md
- - tutorials/training.md
- tutorials/make-a-training-script.md
- - tutorials/visualization.md
+ - tutorials/training.md
- Pipes:
- Overview: pipes/index.md
- Core Pipelines:
@@ -216,6 +216,7 @@ plugins:
merge_init_into_class: true
- bibtex:
bibtex_file: "docs/references.bib"
+ - clickable_snippets:
# Just uncomment the following lines to enable i18n
# and start creating .fr.md and .en.md files.
# - i18n:
diff --git a/pyproject.toml b/pyproject.toml
index 4b65a2b24..b4f712736 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,8 @@ dev = [
"pathspec>=0.11.1", # required by vendored mkdocs-autorefs PR
"astunparse",
"griffe<0.39",
+ "jedi",
+ "html5lib",
]
docs = [
"mike~=1.1.2",
@@ -89,6 +91,8 @@ docs = [
"pathspec>=0.11.1", # required by vendored mkdocs-autorefs PR
"astunparse",
"griffe<0.39",
+ "jedi",
+ "html5lib",
"torch>=1.13.0",
"foldedtensor>=0.3.2",
@@ -300,6 +304,7 @@ where = ["."]
[project.entry-points."spacy_misc"]
"eds.span_context_getter" = "edsnlp.utils.span_getters:make_span_context_getter"
+"eds.scheduled_optimizer" = "edsnlp.training.optimizer:ScheduledOptimizer"
[project.entry-points."spacy_languages"]
"eds" = "edsnlp.language:EDSLanguage"
@@ -310,6 +315,7 @@ where = ["."]
[project.entry-points."mkdocs.plugins"]
"bibtex" = "docs.scripts.bibtex:BibTexPlugin"
"autorefs" = "docs.scripts.autorefs.plugin:AutorefsPlugin"
+"clickable_snippets" = "docs.scripts.clickable_snippets:ClickableSnippetsPlugin"
[build-system]
requires = [
diff --git a/tests/training/ner_qlf_config.yml b/tests/training/ner_qlf_config.yml
index 2dad6da2d..6b25c7d9f 100644
--- a/tests/training/ner_qlf_config.yml
+++ b/tests/training/ner_qlf_config.yml
@@ -59,10 +59,11 @@ scorer:
# 🎛️ OPTIMIZER
optim:
- '@optimizers': adam
+ cls: AdamW
+ module: ${ nlp }
groups:
- "*.transformer.*": false
- "*":
+ "^transformer": false
+ ".*":
lr: 1e-3
# 📚 DATA
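For readers less familiar with the config syntax, the `optim` block above resolves, roughly, to the following object (a sketch: the real instantiation goes through confit's `Config` resolution, and the `torch.nn.Linear` placeholder stands in for the `${ nlp }` pipeline):

```python
import torch

from edsnlp.training.optimizer import ScheduledOptimizer

nlp_module = torch.nn.Linear(4, 2)  # placeholder for the interpolated `${ nlp }` pipeline

optim = ScheduledOptimizer(
    cls="AdamW",                 # resolved through `optim_mapping` to torch.optim.AdamW
    module=nlp_module,
    groups={
        "^transformer": False,   # exclude (freeze) every parameter under `transformer`
        ".*": {"lr": 1e-3},      # all remaining parameters, trained at lr=1e-3
    },
)
```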
diff --git a/tests/training/test_optimizer.py b/tests/training/test_optimizer.py
new file mode 100644
index 000000000..a5585309d
--- /dev/null
+++ b/tests/training/test_optimizer.py
@@ -0,0 +1,136 @@
+import pytest
+import torch
+
+from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer
+
+
+class Net(torch.nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.fc1 = torch.nn.Linear(10, 1)
+ self.fc2 = torch.nn.Linear(1, 1)
+
+    def forward(self, x):
+        return self.fc2(self.fc1(x))
+
+
+@pytest.fixture(scope="module")
+def net():
+ net = Net()
+ return net
+
+
+def test_parameter_selection(net):
+ optim = ScheduledOptimizer(
+ cls="adamw",
+ module=net,
+ groups={
+ "fc1[.].*": {
+ "lr": 0.1,
+ "weight_decay": 0.01,
+ "schedules": [
+ {"@schedules": "linear", "start_value": 0.0, "warmup_rate": 0.2},
+ ],
+ },
+ "fc2[.]bias": False,
+ "": {
+ "lr": 0.0001,
+ "weight_decay": 0.0,
+ },
+ },
+ total_steps=10,
+ )
+ assert len(optim.state) == 0
+ optim.initialize()
+ assert all([p in optim.state for p in net.fc1.parameters()])
+
+ fc1_group = optim.param_groups[0]
+ assert fc1_group["lr"] == pytest.approx(0.0)
+ assert fc1_group["weight_decay"] == pytest.approx(0.01)
+ assert set(fc1_group["params"]) == {net.fc1.weight, net.fc1.bias}
+
+ fc2_group = optim.param_groups[1]
+ assert fc2_group["lr"] == pytest.approx(0.0001)
+ assert set(fc2_group["params"]) == {net.fc2.weight}
+
+ lr_values = [fc1_group["lr"]]
+
+ for i in range(10):
+ optim.step()
+ lr_values.append(fc1_group["lr"])
+
+ assert lr_values == pytest.approx(
+ [
+ 0.0,
+ 0.05,
+ 0.1,
+ 0.0875,
+ 0.075,
+ 0.0625,
+ 0.05,
+ 0.0375,
+ 0.025,
+ 0.0125,
+ 0.0,
+ ]
+ )
+
+
+def test_serialization(net):
+ optim = ScheduledOptimizer(
+ cls="adamw",
+ module=net,
+ groups={
+ "fc1[.].*": {
+ "lr": 0.1,
+ "weight_decay": 0.01,
+ "schedules": LinearSchedule(start_value=0.0, warmup_rate=0.2),
+ },
+ "fc2[.]bias": False,
+ "": {
+ "lr": 0.0001,
+ "weight_decay": 0.0,
+ },
+ },
+ total_steps=10,
+ )
+ optim.initialize()
+    optim.param_groups = optim.param_groups  # exercise the param_groups setter
+
+ state_dict = None
+ for i in range(10):
+ if i == 5:
+ state_dict = optim.state_dict()
+ optim.step()
+
+ assert optim.param_groups[0]["lr"] == pytest.approx(0.0)
+ optim.load_state_dict(state_dict)
+ assert optim.param_groups[0]["lr"] == pytest.approx(0.0625)
+
+ optim.reset()
+
+
+def test_repr(net):
+ optim = ScheduledOptimizer(
+ cls="adamw",
+ module=net,
+ groups={
+ "fc1[.].*": {
+ "lr": 0.1,
+ "weight_decay": 0.01,
+ "schedules": [
+ LinearSchedule(start_value=0.0, warmup_rate=0.2),
+ LinearSchedule(path="weight_decay"),
+ ],
+ },
+ "fc2[.]bias": False,
+ ".*": {
+ "lr": 0.0001,
+ "weight_decay": 0.0,
+ },
+ },
+ total_steps=10,
+ )
+ optim.initialize()
+
+ assert "ScheduledOptimizer[AdamW]" in repr(optim)
diff --git a/tests/training/test_train.py b/tests/training/test_train.py
index 0873878d6..e2c49033b 100644
--- a/tests/training/test_train.py
+++ b/tests/training/test_train.py
@@ -20,10 +20,7 @@
from edsnlp.core.registries import registry
from edsnlp.data.converters import AttributesMappingArg, get_current_tokenizer
-from edsnlp.training.optimization import (
- LinearSchedule,
- create_optimizer,
-)
+from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer
from edsnlp.training.trainer import GenericScorer, train
from edsnlp.utils.span_getters import SpanSetterArg, set_spans
@@ -122,10 +119,12 @@ def test_qualif_train(run_in_test_dir, tmp_path):
def test_optimizer():
net = torch.nn.Linear(10, 10)
- optim = create_optimizer(
+ optim = ScheduledOptimizer(
torch.optim.AdamW,
+ module=net,
+ total_steps=10,
groups={
- "*": {
+ ".*": {
"lr": 9e-4,
"schedules": LinearSchedule(
warmup_rate=0.1,
@@ -133,7 +132,7 @@ def test_optimizer():
),
}
},
- )(net, total_steps=10)
+ )
for param in net.parameters():
assert "exp_avg" not in optim.optim.state[param]
optim.initialize()