Skip to content

Commit

Permalink
feat: new eds.span_linker component
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Mar 29, 2024
1 parent 2563722 commit d5dc0d8
Show file tree
Hide file tree
Showing 12 changed files with 857 additions and 6 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
```

*The previous way of adding pipes is still supported.*
- New `eds.span_linker` deep-learning component to match entities with their concepts in a knowledge base, in synonym-similarity or concept-similarity mode.

### Changed

Expand Down
Binary file added docs/assets/images/class_span_linker.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/images/synonym_span_linker.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions docs/pipes/trainable/span-linker.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Trainable Span Linker {: #edsnlp.pipes.trainable.span_linker.factory.create_component }

::: edsnlp.pipes.trainable.span_linker.factory.create_component
options:
heading_level: 2
show_bases: false
show_source: false
only_class_level: true
10 changes: 10 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ @inproceedings{dalloux2017ESSAI
year = {2017}
}

@article{wajsburt2021medical,
title={Medical concept normalization in French using multilingual terminologies and contextual embeddings},
author={Wajsbürt, Perceval and Sarfati, Arnaud and Tannier, Xavier},
journal={Journal of Biomedical Informatics},
volume={114},
pages={103684},
year={2021},
url = {https://doi.org/10.1016/j.jbi.2021.103684},
publisher={Elsevier}
}

@phdthesis{wajsburt:tel-03624928,
TITLE = {{Extraction and normalization of simple and structured entities in medical documents}},
Expand Down
1 change: 1 addition & 0 deletions edsnlp/pipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .qualifiers.reported_speech.factory import create_component as rspeech
from .trainable.ner_crf.factory import create_component as ner_crf
from .trainable.span_qualifier.factory import create_component as span_qualifier
from .trainable.span_linker.factory import create_component as span_linker
from .trainable.embeddings.span_pooler.factory import create_component as span_pooler
from .trainable.embeddings.transformer.factory import create_component as transformer
from .trainable.embeddings.text_cnn.factory import create_component as text_cnn
102 changes: 102 additions & 0 deletions edsnlp/pipes/trainable/layers/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Optional

import torch
import torch.nn.functional as F
from typing_extensions import Literal


class Metric(torch.nn.Module):
"""
Metric layer, used for computing similarities between two sets of vectors. A typical
use case is to compute the similarity between a set of query vectors (input
embeddings) and a set of concept vectors (output embeddings).
Parameters
----------
in_features : int
Size of the input embeddings
out_features : int
Size of the output embeddings
num_groups : int
Number of groups for the output embeddings, that can be used to filter out
certain concepts that are not relevant for a given query (e.g. do not compare
a drug with concepts for diseases)
metric : Literal["cosine", "dot"]
Whether to compute the cosine similarity between the input and output embeddings
or the dot product.
rescale: Optional[float]
Rescale the output cosine similarities by a constant factor.
"""

def __init__(
self,
in_features: int,
out_features: int,
num_groups: int = 0,
metric: Literal["cosine", "dot"] = "cosine",
rescale: Optional[float] = None,
bias: bool = True,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features))
self.register_buffer(
"groups", torch.zeros(num_groups, out_features, dtype=torch.bool)
)
self.rescale: float = (
rescale if rescale is not None else 20.0 if metric == "cosine" else 1.0
)
self.metric = metric
self.register_parameter(
"bias",
torch.nn.Parameter(torch.tensor(-0.65 if metric == "cosine" else 0.0))
if bias
else None,
)
self.reset_parameters()

self._last_version = None
self._normalized_weight = None

def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.weight)

def normalized_weight(self):
if (
(self.weight._version, id(self.weight)) == self._last_version
and not self.training
and self._normalized_weight is not None
):
return self._normalized_weight
normalized_weight = self.normalize_embedding(self.weight)
if not self.training and normalized_weight is not self.weight:
self._normalized_weight = normalized_weight
self._last_version = (self.weight._version, id(self.weight))
return normalized_weight

def normalize_embedding(self, inputs):
if self.metric == "cosine":
inputs = F.normalize(inputs, dim=-1)
return inputs

def forward(self, inputs, group_indices=None, **kwargs):
x = F.linear(
self.normalize_embedding(inputs),
self.normalized_weight(),
)
if self.bias is not None:
x += self.bias
if self.rescale != 1.0:
x *= self.rescale
if group_indices is not None and len(self.groups):
x = x.masked_fill(~self.groups[group_indices], -10000)
return x

def extra_repr(self):
return "in_features={}, out_features={}, rescale={}, groups={}".format(
self.in_features,
self.out_features,
float(self.rescale or 1.0),
self.groups.shape[0] if self.groups is not None else None,
)
1 change: 1 addition & 0 deletions edsnlp/pipes/trainable/span_linker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .factory import create_component
14 changes: 14 additions & 0 deletions edsnlp/pipes/trainable/span_linker/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import TYPE_CHECKING

from edsnlp import registry

from .span_linker import TrainableSpanLinker

create_component = registry.factory.register(
"eds.span_linker",
assigns=[],
deprecated=[],
)(TrainableSpanLinker)

if TYPE_CHECKING:
create_component = TrainableSpanLinker
Loading

0 comments on commit d5dc0d8

Please sign in to comment.