-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
41 changed files
with
3,541 additions
and
267 deletions.
There are no files selected for viewing
File renamed without changes.
63 changes: 63 additions & 0 deletions
63
bpm-ai-core/bpm_ai_core/classification/transformers_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import logging | ||
|
||
from bpm_ai_core.classification.zero_shot_classifier import ZeroShotClassifier, ClassificationResult | ||
|
||
try: | ||
from transformers import pipeline, AutoTokenizer | ||
has_transformers = True | ||
except ImportError: | ||
has_transformers = False | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
DEFAULT_MODEL_EN = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33" | ||
DEFAULT_MODEL_MULTI = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli" | ||
|
||
|
||
class TransformersClassifier(ZeroShotClassifier): | ||
""" | ||
Local zero-shot classification model based on Huggingface transformers library. | ||
To use, you should have the ``transformers`` python package installed. | ||
""" | ||
|
||
def __init__(self, model: str = DEFAULT_MODEL_EN): | ||
if not has_transformers: | ||
raise ImportError('transformers is not installed') | ||
self.model = model | ||
|
||
def classify_with_metadata( | ||
self, | ||
text: str, | ||
classes: list[str], | ||
hypothesis_template: str | None = None | ||
) -> ClassificationResult: | ||
zeroshot_classifier = pipeline("zero-shot-classification", model=self.model) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(self.model) | ||
input_tokens = len(tokenizer.encode(text)) | ||
max_tokens = tokenizer.model_max_length | ||
logger.debug(f"Input tokens: {input_tokens}") | ||
if input_tokens > max_tokens: | ||
logger.warning( | ||
f"Input tokens exceed max model context size: {input_tokens} > {max_tokens}. Input will be truncated." | ||
) | ||
|
||
prediction = zeroshot_classifier( | ||
text, | ||
classes, | ||
hypothesis_template=hypothesis_template or "This example is about {}", | ||
multi_label=False | ||
) | ||
# Zip the labels and scores together and find the label with the max score | ||
labels_scores = list(zip(prediction['labels'], prediction['scores'])) | ||
max_label, max_score = max(labels_scores, key=lambda x: x[1]) | ||
|
||
return ClassificationResult( | ||
max_label=max_label, | ||
max_score=max_score, | ||
labels_scores=labels_scores | ||
) | ||
|
||
|
||
|
51 changes: 51 additions & 0 deletions
51
bpm-ai-core/bpm_ai_core/classification/zero_shot_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Tuple | ||
|
||
from pydantic import BaseModel | ||
|
||
from bpm_ai_core.tracing.tracing import Tracing | ||
|
||
|
||
class ClassificationResult(BaseModel): | ||
max_label: str | ||
max_score: float | ||
labels_scores: list[Tuple[str, float]] | ||
|
||
|
||
class ZeroShotClassifier(ABC): | ||
""" | ||
Zero Shot Classification Model | ||
""" | ||
|
||
@abstractmethod | ||
def classify_with_metadata( | ||
self, | ||
text: str, | ||
classes: list[str], | ||
hypothesis_template: str | None = None | ||
) -> ClassificationResult: | ||
pass | ||
|
||
def classify( | ||
self, | ||
text: str, | ||
classes: list[str], | ||
confidence_threshold: float | None = None, | ||
hypothesis_template: str | None = None | ||
) -> str: | ||
Tracing.tracers().start_span("classification", inputs={ | ||
"text": text, | ||
"classes": classes, | ||
"confidence_threshold": confidence_threshold, | ||
"hypothesis_template": hypothesis_template | ||
}) | ||
result = self.classify_with_metadata( | ||
text=text, | ||
classes=classes, | ||
hypothesis_template=hypothesis_template | ||
) | ||
Tracing.tracers().end_span(outputs={"result": result.model_dump()}) | ||
# Only return the label if the score is above the threshold (if given) | ||
return result.max_label \ | ||
if not confidence_threshold or result.max_score > confidence_threshold \ | ||
else None |
File renamed without changes.
47 changes: 47 additions & 0 deletions
47
bpm-ai-core/bpm_ai_core/extractive_qa/question_answering.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from pydantic import BaseModel | ||
|
||
from bpm_ai_core.tracing.tracing import Tracing | ||
|
||
|
||
class QAResult(BaseModel): | ||
answer: str | ||
score: float | ||
start_index: int | ||
end_index: int | ||
|
||
|
||
class ExtractiveQA(ABC): | ||
""" | ||
Extractive Question Answering Model | ||
""" | ||
|
||
@abstractmethod | ||
def answer_with_metadata( | ||
self, | ||
context: str, | ||
question: str | ||
) -> QAResult: | ||
pass | ||
|
||
def answer( | ||
self, | ||
context: str, | ||
question: str, | ||
confidence_threshold: float | None = 0.1 | ||
) -> str: | ||
Tracing.tracers().start_span("extractive_qa", inputs={ | ||
"context": context, | ||
"question": question, | ||
"confidence_threshold": confidence_threshold | ||
}) | ||
result = self.answer_with_metadata( | ||
context=context, | ||
question=question | ||
) | ||
Tracing.tracers().end_span(outputs={"result": result.model_dump()}) | ||
# Only return the answer if the score is above the threshold (if given) | ||
return result.answer \ | ||
if not confidence_threshold or result.score > confidence_threshold \ | ||
else None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import logging | ||
|
||
from bpm_ai_core.extractive_qa.question_answering import ExtractiveQA, QAResult | ||
|
||
try: | ||
from transformers import pipeline, AutoTokenizer | ||
has_transformers = True | ||
except ImportError: | ||
has_transformers = False | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TransformersExtractiveQA(ExtractiveQA): | ||
""" | ||
Local extractive question answering model based on Huggingface transformers library. | ||
To use, you should have the ``transformers`` python package installed. | ||
""" | ||
|
||
def __init__(self, model: str = "deepset/deberta-v3-large-squad2"): | ||
if not has_transformers: | ||
raise ImportError('transformers is not installed') | ||
self.model = model | ||
|
||
def answer_with_metadata( | ||
self, | ||
context: str, | ||
question: str | ||
) -> QAResult: | ||
qa_model = pipeline("question-answering", model=self.model) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(self.model) | ||
tokens = tokenizer.encode(context + question) | ||
logger.debug(f"Input tokens: {len(tokens)}") | ||
|
||
prediction = qa_model( | ||
question=question, | ||
context=context | ||
) | ||
logger.debug(f"prediction: {prediction}") | ||
|
||
return QAResult( | ||
answer=prediction['answer'], | ||
score=prediction['score'], | ||
start_index=prediction['start'], | ||
end_index=prediction['end'], | ||
) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Tuple | ||
|
||
|
||
class POSTagger(ABC): | ||
""" | ||
Part-of-Speech Tagging Model | ||
""" | ||
|
||
@abstractmethod | ||
def tag(self, text: str) -> list[Tuple[str, str]]: | ||
""" | ||
Returns a list of tuples (token, tag). Example: | ||
[('I', 'PRON'), ('am', 'AUX'), ('30', 'NUM'), ('years', 'NOUN'), ('old', 'ADJ'), ('.', 'PUNCT')] | ||
""" | ||
pass |
Oops, something went wrong.