feat: new eds.extractive_qa component

percevalw committed May 12, 2024
1 parent 14d1728 commit c53515e
Showing 10 changed files with 349 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog.md
@@ -5,6 +5,7 @@
### Added

- `edsnlp.data.read_parquet` now accepts a `work_unit="fragment"` option to split tasks between workers by parquet fragment instead of by row. When this is enabled, workers no longer read every fragment while skipping 1 row in n; instead, each worker reads all rows of 1/n of the fragments, which should be faster (a usage sketch follows this list).
- New `eds.extractive_qa` component to perform extractive question answering using questions as prompts to tag entities instead of a list of predefined labels as in `eds.ner_crf`.
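
For illustration, a minimal sketch of the new fragment-level work unit (the path and converter below are placeholder assumptions, not part of this commit):

```python
import edsnlp

# Dispatch whole parquet fragments (row groups) to workers instead of
# interleaving rows across them.
stream = edsnlp.data.read_parquet(
    "data/notes.parquet",  # hypothetical path
    converter="omop",      # hypothetical converter
    work_unit="fragment",
)
docs = list(stream)  # materialize the documents
```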

### Fixed

8 changes: 8 additions & 0 deletions docs/pipes/trainable/extractive-qa.md
@@ -0,0 +1,8 @@
# Extractive Question Answering {: #edsnlp.pipes.trainable.extractive_qa.factory.create_component }

::: edsnlp.pipes.trainable.extractive_qa.factory.create_component
options:
heading_level: 2
show_bases: false
show_source: false
only_class_level: true
1 change: 1 addition & 0 deletions docs/pipes/trainable/index.md
@@ -14,6 +14,7 @@ All trainable components implement the [`TorchComponent`][edsnlp.core.torch_comp
| `eds.text_cnn` | Contextualize embeddings with a CNN |
| `eds.span_pooler` | A span embedding component that aggregates word embeddings |
| `eds.ner_crf` | A trainable component to extract entities |
| `eds.extractive_qa` | A trainable component for extractive question answering |
| `eds.span_classifier` | A trainable component for multi-class multi-label span classification |
| `eds.span_linker` | A trainable entity linker (i.e. to a list of concepts) |

1 change: 1 addition & 0 deletions edsnlp/pipes/__init__.py
@@ -75,6 +75,7 @@
from .qualifiers.reported_speech.factory import create_component as reported_speech
from .qualifiers.reported_speech.factory import create_component as rspeech
from .trainable.ner_crf.factory import create_component as ner_crf
from .trainable.extractive_qa.factory import create_component as extractive_qa
from .trainable.span_classifier.factory import create_component as span_classifier
from .trainable.span_linker.factory import create_component as span_linker
from .trainable.embeddings.span_pooler.factory import create_component as span_pooler
Empty file.
248 changes: 248 additions & 0 deletions edsnlp/pipes/trainable/extractive_qa/extractive_qa.py
@@ -0,0 +1,248 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set

from spacy.tokens import Doc, Span
from typing_extensions import Literal

from edsnlp.core.pipeline import Pipeline
from edsnlp.pipes.trainable.embeddings.typing import (
WordEmbeddingComponent,
)
from edsnlp.pipes.trainable.ner_crf.ner_crf import NERBatchOutput, TrainableNerCrf
from edsnlp.utils.filter import align_spans, filter_spans
from edsnlp.utils.span_getters import (
SpanGetterArg,
SpanSetterArg,
get_spans,
)
from edsnlp.utils.typing import AsList


class TrainableExtractiveQA(TrainableNerCrf):
"""
The `eds.extractive_qa` component is a trainable extractive question answering
component. It can be seen as a Named Entity Recognition (NER) component where the
types of entities predicted by the model are not pre-defined during training
but are provided as prompts (i.e., questions) at inference time.

The `eds.extractive_qa` component is very similar to the `eds.ner_crf`
component, and most of its arguments are therefore the same.

!!! note "Extractive vs Abstractive Question Answering"
Extractive Question Answering differs from Abstractive Question Answering in
that the answer is extracted from the text, rather than generated (à la
ChatGPT) from scratch. To normalize the answers, you can use the
`eds.span_linker` component in `synonym` mode and search for the closest
`synonym` in a predefined list.

Examples
--------
```python
import edsnlp, edsnlp.pipes as eds
nlp = edsnlp.blank("eds")
nlp.add_pipe(
eds.extractive_qa(
embedding=eds.transformer(
model="prajjwal1/bert-tiny",
window=128,
stride=96,
),
mode="joint",
target_span_getter="ner-gold",
span_setter="ents",
questions={
"disease": "What disease does the patient have?",
"drug": "What drug is the patient taking?",
},
),
name="ner",
)
```

To train the model, refer to the [Training](/tutorials/make-a-training-script)
tutorial.

Parameters
----------
name : str
Name of the component
embedding : WordEmbeddingComponent
The word embedding component
questions : Dict[str, AsList[str]]
The questions to ask, as a mapping between the entity type and the list of
questions to ask for this entity type (or single string if only one question).
questions_attribute : Optional[str]
The attribute to use to get the questions dynamically from the Doc or Span
objects (as returned by the `context_getter` argument). If None, the questions
will be fixed and only taken from the `questions` argument.
context_getter : Optional[SpanGetterArg]
What context to use when computing the span embeddings (defaults to the whole
document). For example `{"section": "conclusion"}` to only extract the
entities from the conclusion.
target_span_getter : SpanGetterArg
Method to call to get the gold spans from a document, for scoring or training.
By default, takes all entities in `doc.ents`, but we recommend you specify
a given span group name instead.
span_setter : Optional[SpanSetterArg]
The span setter to use to set the predicted spans on the Doc object. If None,
the component will infer the span setter from the target_span_getter config.
infer_span_setter : Optional[bool]
Whether to complete the span setter from the target_span_getter config.
False by default, unless the span_setter is None.
mode : Literal["independent", "joint", "marginal"]
The CRF mode to use: `independent`, `joint`, or `marginal`
window : int
The window size to use for the CRF. If 0, will use the whole document, at
the cost of a longer computation time. If 1, this is equivalent to assuming
that the tags are independent; the component will be faster, but with
degraded performance. Empirically, we found that a window size of 10 or 20
works well.
stride : Optional[int]
The stride to use for the CRF windows. Defaults to `window // 2`.
"""

def __init__(
self,
nlp: Optional[Pipeline] = None,
name: Optional[str] = "extractive_qa",
*,
embedding: WordEmbeddingComponent,
questions: Dict[str, AsList[str]] = {},
questions_attribute: str = "questions",
context_getter: Optional[SpanGetterArg] = None,
target_span_getter: Optional[SpanGetterArg] = None,
span_setter: Optional[SpanSetterArg] = None,
infer_span_setter: Optional[bool] = None,
mode: Literal["independent", "joint", "marginal"] = "joint",
window: int = 40,
stride: Optional[int] = None,
):
self.questions_attribute: Optional[str] = questions_attribute
self.questions = questions
super().__init__(
nlp=nlp,
name=name,
embedding=embedding,
context_getter=context_getter,
span_setter=span_setter,
target_span_getter=target_span_getter,
mode=mode,
window=window,
stride=stride,
infer_span_setter=infer_span_setter,
)
self.update_labels(["answer"])
self.labels_to_idx = defaultdict(lambda: 0)

def set_extensions(self):
super().set_extensions()
if self.questions_attribute:
if not Doc.has_extension(self.questions_attribute):
Doc.set_extension(self.questions_attribute, default=None)
if not Span.has_extension(self.questions_attribute):
Span.set_extension(self.questions_attribute, default=None)

def post_init(self, docs: Iterable[Doc], exclude: Set[str]):
pass

@property
def cfg(self):
cfg = dict(super().cfg)
cfg.pop("labels")
return cfg

def preprocess(self, doc, **kwargs):
contexts = (
list(get_spans(doc, self.context_getter))
if self.context_getter
else [doc[:]]
)
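# Build every (question, label, context) combination: static questions from
# `self.questions`, plus questions read dynamically from the Doc (and from each
# context span) through the `questions_attribute` extension. The set removes
# duplicates and sorting keeps the order deterministic.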
prompt_contexts_and_labels = sorted(
{
(prompt, label, context)
for context in contexts
for label, questions in (
*self.questions.items(),
*(getattr(doc._, self.questions_attribute) or {}).items(),
*(
(getattr(context._, self.questions_attribute) or {}).items()
if context is not doc
else ()
),
)
for prompt in questions
}
)
questions = [x[0] for x in prompt_contexts_and_labels]
labels = [x[1] for x in prompt_contexts_and_labels]
ctxs = [x[2] for x in prompt_contexts_and_labels]
return {
"lengths": [len(ctx) for ctx in ctxs],
"$labels": labels,
"$contexts": ctxs,
"embedding": self.embedding.preprocess(
doc,
contexts=ctxs,
prompts=questions,
**kwargs,
),
}

def preprocess_supervised(self, doc, **kwargs):
prep = self.preprocess(doc, **kwargs)
contexts = prep["$contexts"]
labels = prep["$labels"]
tags = []

for context, label, target_ents in zip(
contexts,
labels,
align_spans(
list(get_spans(doc, self.target_span_getter)),
contexts,
),
):
span_tags = [[0] * len(self.labels) for _ in range(len(context))]
start = context.start
target_ents = [ent for ent in target_ents if ent.label_ == label]

# TODO: move this to the LinearChainCRF class
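# Tag encoding used below: 0 = outside, 2 = entity start, 1 = inside,
# 3 = entity end, 4 = single-token entity.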
for ent in filter_spans(target_ents):
label_idx = self.labels_to_idx[ent.label_]
if ent.start == ent.end - 1:
span_tags[ent.start - start][label_idx] = 4
else:
span_tags[ent.start - start][label_idx] = 2
span_tags[ent.end - 1 - start][label_idx] = 3
for i in range(ent.start + 1 - start, ent.end - 1 - start):
span_tags[i][label_idx] = 1
tags.append(span_tags)

return {
**prep,
"targets": tags,
}

def postprocess(
self,
docs: List[Doc],
results: NERBatchOutput,
inputs: List[Dict[str, Any]],
):
spans: Dict[Doc, List[Span]] = defaultdict(list)
contexts = [ctx for sample in inputs for ctx in sample["$contexts"]]
labels = [label for sample in inputs for label in sample["$labels"]]
tags = results["tags"].cpu()
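# Decode the CRF output back into spans: each decoded row gives a context
# index, an unused field, and start/end token offsets within that context.
# The span is then relabeled with the entity type whose question produced it.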
for context_idx, _, start, end in self.crf.tags_to_spans(tags).tolist():
span = contexts[context_idx][start:end]
label = labels[context_idx]
span.label_ = label
spans[span.doc].append(span)
for doc in docs:
self.set_spans(doc, spans.get(doc, []))
return docs
14 changes: 14 additions & 0 deletions edsnlp/pipes/trainable/extractive_qa/factory.py
@@ -0,0 +1,14 @@
from typing import TYPE_CHECKING

from edsnlp import registry

from .extractive_qa import TrainableExtractiveQA

create_component = registry.factory.register(
"eds.extractive_qa",
assigns=[],
deprecated=[],
)(TrainableExtractiveQA)

if TYPE_CHECKING:
create_component = TrainableExtractiveQA
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -128,6 +128,7 @@ nav:
- 'NER': pipes/trainable/ner.md
- 'Span Classifier': pipes/trainable/span-classifier.md
- 'Span Linker': pipes/trainable/span-linker.md
- 'Extractive QA': pipes/trainable/extractive-qa.md
- tokenizers.md
- Data Connectors:
- data/index.md
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -267,8 +267,9 @@ where = ["."]
"eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component"
"eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component"
"eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component"
"eds.extractive_qa" = "edsnlp.pipes.trainable.extractive_qa.factory:create_component"
"eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component"
"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component"
"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component"
"eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component"
"eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component"

73 changes: 73 additions & 0 deletions tests/pipelines/trainable/test_extractive_qa.py
@@ -0,0 +1,73 @@
from spacy.tokens import Span

import edsnlp
import edsnlp.pipes as eds


def test_ner():
nlp = edsnlp.blank("eds")
nlp.add_pipe(
eds.extractive_qa(
embedding=eds.transformer(
model="prajjwal1/bert-tiny",
window=20,
stride=10,
),
# During training, where do we get the gold entities from?
target_span_getter=["ner-gold"],
# Which prompts for each label?
questions={
"PERSON": "Quels sont les personnages ?",
"GIFT": "Quels sont les cadeaux ?",
},
questions_attribute="question",
# During prediction, where do we set the predicted entities?
span_setter="ents",
),
)

doc = nlp(
"L'aîné eut le Moulin, le second eut l'âne, et le plus jeune n'eut que le Chat."
)
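# A question attached dynamically to this document through the `question`
# extension (see `questions_attribute` above); it introduces a label
# ("FAVORITE") that is not in the static `questions` mapping.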
doc._.question = {
"FAVORITE": ["Qui a eu de l'argent ?"],
}
# doc[0:2], doc[4:5], doc[6:8], doc[9:11], doc[13:16], doc[20:21]
doc.spans["ner-gold"] = [
Span(doc, 0, 2, "PERSON"), # L'aîné
Span(doc, 4, 5, "GIFT"), # Moulin
Span(doc, 6, 8, "PERSON"), # le second
Span(doc, 9, 11, "GIFT"), # l'âne
Span(doc, 13, 16, "PERSON"), # le plus jeune
Span(doc, 20, 21, "GIFT"), # Chat
]
nlp.post_init([doc])

ner = nlp.pipes.extractive_qa
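# prepare_batch collates the preprocessed supervised features for the doc;
# module_forward runs the underlying torch module and returns outputs that
# include the training loss.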
batch = ner.prepare_batch([doc], supervision=True)
results = ner.module_forward(batch)

pred = list(ner.pipe([doc]))[0]
print(pred.ents)

assert results["loss"] is not None
trf_inputs = [
seq.replace(" [PAD]", "")
for seq in ner.embedding.tokenizer.batch_decode(batch["embedding"]["input_ids"])
]
assert trf_inputs == [
"[CLS] quels sont les cadeaux? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501
"[CLS] quels sont les cadeaux? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501
"[CLS] quels sont les cadeaux? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501
"[CLS] quels sont les personnages? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501
"[CLS] quels sont les personnages? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501
"[CLS] quels sont les personnages? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501
"[CLS] qui a eu de l'argent? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501
"[CLS] qui a eu de l'argent? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501
"[CLS] qui a eu de l'argent? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501
]
assert batch["targets"].squeeze(2).tolist() == [
[0, 0, 0, 0, 4, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0],
[2, 3, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 2, 1, 3, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
