-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HuggingFaceInferenceAPI model and associated assets (#194)
* Add an implementation of HF model and an example sentiment analysis asset using it * Address review suggestions * Fix formatting * Fix more formatting * Format summary based on task type * Delete NER task: The returned format does not include the original text, and the dataset ground truth consists of labeled sentences which cannot be recovered using the model output alone. * Rename models and other touches * Add docstring to HuggingFaceInferenceAPI * Fix ArSAS asset * Modify model tests to use only ModelBase derived classes * Add explicit check for api token and env var based config * Removed hardcoded env var from assets * Fix missing/spurious imports * Add tests for HuggingFaceInferenceAPI models * Remove dead code --------- Co-authored-by: mhawasly <[email protected]> Co-authored-by: Fahim Imaduddin Dalvi <[email protected]>
- Loading branch information
1 parent
421839f
commit 07f4bf6
Showing
11 changed files
with
406 additions
and
1 deletion.
There are no files selected for viewing
79 changes: 79 additions & 0 deletions
79
assets/ar/MT/AraBench_Ara2Eng_Helsinki_NLP_Opus_MT_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from llmebench.datasets import AraBenchDataset | ||
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes | ||
from llmebench.tasks import MachineTranslationTask | ||
|
||
|
||
def config():
    """Build one run configuration per AraBench test set.

    Returns:
        A list of dicts. Each dict carries the test set identifier under
        "name" and, under "config", the dataset, task and model classes
        together with their arguments. The source and target file names
        are derived from the test set identifier (``<name>.ar`` and
        ``<name>.en``).
    """
    testset_names = (
        "bible.test.mgr.0.ma",
        "bible.test.mgr.0.tn",
        "bible.test.msa.0.ms",
        "bible.test.msa.1.ms",
        "ldc_web_eg.test.lev.0.jo",
        "ldc_web_eg.test.lev.0.ps",
        "ldc_web_eg.test.lev.0.sy",
        "ldc_web_eg.test.mgr.0.tn",
        "ldc_web_eg.test.msa.0.ms",
        "ldc_web_eg.test.nil.0.eg",
        "ldc_web_lv.test.lev.0.lv",
        "madar.test.glf.0.iq",
        "madar.test.glf.0.om",
        "madar.test.glf.0.qa",
        "madar.test.glf.0.sa",
        "madar.test.glf.0.ye",
        "madar.test.glf.1.iq",
        "madar.test.glf.1.sa",
        "madar.test.glf.2.iq",
        "madar.test.lev.0.jo",
        "madar.test.lev.0.lb",
        "madar.test.lev.0.pa",
        "madar.test.lev.0.sy",
        "madar.test.lev.1.jo",
        "madar.test.lev.1.sy",
        "madar.test.mgr.0.dz",
        "madar.test.mgr.0.ly",
        "madar.test.mgr.0.ma",
        "madar.test.mgr.0.tn",
        "madar.test.mgr.1.ly",
        "madar.test.mgr.1.ma",
        "madar.test.mgr.1.tn",
        "madar.test.msa.0.ms",
        "madar.test.nil.0.eg",
        "madar.test.nil.0.sd",
        "madar.test.nil.1.eg",
        "madar.test.nil.2.eg",
    )

    def _entry_for(testset):
        # Every entry gets its own freshly built nested dicts.
        dataset_args = {
            "src": f"{testset}.ar",
            "tgt": f"{testset}.en",
        }
        model_args = {
            "task_type": HuggingFaceTaskTypes.Translation,
            "inference_api_url": "https://api-inference.huggingface.co/models/Helsinki-NLP/opus-mt-ar-en",
            "max_tries": 5,
        }
        return {
            "name": testset,
            "config": {
                "dataset": AraBenchDataset,
                "dataset_args": dataset_args,
                "task": MachineTranslationTask,
                "task_args": {},
                "model": HuggingFaceInferenceAPIModel,
                "model_args": model_args,
                "general_args": {"data_path": "data/MT/"},
            },
        }

    return [_entry_for(testset) for testset in testset_names]
|
||
|
||
def prompt(input_sample):
    """Wrap the input sample in a dict under the ``inputs`` key."""
    payload = {"inputs": input_sample}
    return payload
|
||
|
||
def post_process(response):
    """Return the ``translation_text`` value of the first response element."""
    first_result = response[0]
    return first_result["translation_text"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from llmebench.datasets import MLQADataset | ||
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes | ||
from llmebench.tasks import QATask | ||
|
||
|
||
def config():
    """Describe the run: dataset, task and model classes plus their arguments.

    Returns:
        A dict with the keys "dataset", "dataset_args", "task", "task_args",
        "model", "model_args" and "general_args".
    """
    model_args = {
        "task_type": HuggingFaceTaskTypes.Question_Answering,
        "inference_api_url": "https://api-inference.huggingface.co/models/timpal0l/mdeberta-v3-base-squad2",
        "max_tries": 5,
    }
    general_args = {
        "data_path": "data/QA/MLQA/test/test-context-ar-question-ar.json"
    }
    return {
        "dataset": MLQADataset,
        "dataset_args": {},
        "task": QATask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": model_args,
        "general_args": general_args,
    }
|
||
|
||
def prompt(input_sample):
    """Build the question-answering payload from a sample.

    Reads the "context" and "question" entries of ``input_sample`` and nests
    them under the ``inputs`` key.
    """
    context = input_sample["context"]
    question = input_sample["question"]
    return {"inputs": {"context": context, "question": question}}
|
||
|
||
def post_process(response):
    """Return the response's "answer" value with surrounding whitespace removed."""
    answer = response["answer"]
    return answer.strip()
38 changes: 38 additions & 0 deletions
38
assets/ar/semantics/STS/Q2QSim_Intfloat_Multilingual_e5_small_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from llmebench.datasets import Q2QSimDataset | ||
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes | ||
from llmebench.tasks import Q2QSimDetectionTask | ||
|
||
|
||
def config():
    """Describe the run: dataset, task and model classes plus their arguments.

    Returns:
        A dict with the keys "dataset", "dataset_args", "task", "task_args",
        "model", "model_args" and "general_args".
    """
    model_args = {
        "task_type": HuggingFaceTaskTypes.Sentence_Similarity,
        "inference_api_url": "https://api-inference.huggingface.co/models/intfloat/multilingual-e5-small",
        "max_tries": 5,
    }
    general_args = {
        "data_path": "data/STS/nsurl-2019-task8/test.tsv",
    }
    return {
        "dataset": Q2QSimDataset,
        "dataset_args": {},
        "task": Q2QSimDetectionTask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": model_args,
        "general_args": general_args,
    }
|
||
|
||
def prompt(input_sample):
    """Turn a tab-separated question pair into a sentence-similarity payload.

    The sample is split on the tab character into exactly two parts (any
    other count makes the unpacking raise ``ValueError``). The first part
    becomes ``source_sentence`` and the second is the sole entry of
    ``sentences``.
    """
    source_question, candidate_question = input_sample.split("\t")
    inputs = {
        "source_sentence": source_question,
        "sentences": [candidate_question],
    }
    return {"inputs": inputs}
|
||
|
||
def post_process(response):
    """Map the first similarity score of the response to a label.

    Returns:
        "1" when the score is above 0.7, "0" when it is below 0.3, and
        ``None`` for anything in between (both bounds included).
    """
    if response[0] > 0.7:
        return "1"
    if response[0] < 0.3:
        return "0"
    # Scores in the closed interval [0.3, 0.7] get no label.
    return None
31 changes: 31 additions & 0 deletions
31
assets/ar/sentiment_emotion_others/sentiment/ArSAS_Camelbert_da_sentiment_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from llmebench.datasets import ArSASDataset | ||
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes | ||
from llmebench.tasks import SentimentTask | ||
|
||
|
||
def config():
    """Describe the run: dataset, task and model classes plus their arguments.

    Returns:
        A dict with the keys "dataset", "dataset_args", "task", "task_args",
        "model", "model_args" and "general_args".
    """
    model_args = {
        "task_type": HuggingFaceTaskTypes.Text_Classification,
        "inference_api_url": "https://api-inference.huggingface.co/models/CAMeL-Lab/bert-base-arabic-camelbert-da-sentiment",
        "max_tries": 5,
    }
    general_args = {
        "data_path": "data/sentiment_emotion_others/sentiment/ArSAS-test.txt"
    }
    return {
        "dataset": ArSASDataset,
        "dataset_args": {},
        "task": SentimentTask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": model_args,
        "general_args": general_args,
    }
|
||
|
||
def prompt(input_sample):
    """Wrap the input sample in a dict under the ``inputs`` key."""
    payload = {"inputs": input_sample}
    return payload
|
||
|
||
def post_process(response):
    """Pick the highest-scoring label and return it capitalised.

    ``response[0]`` is read as a list of dicts with "label" and "score"
    entries. The label with the largest score wins; its first character is
    upper-cased and the remainder lower-cased.
    """
    ranked = [(candidate["label"], candidate["score"]) for candidate in response[0]]
    # Stable ascending sort: the best-scoring pair ends up last.
    ranked.sort(key=lambda pair: pair[1])
    best_label = ranked[-1][0]
    head, tail = best_label[0], best_label[1:]
    return head.upper() + tail.lower()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Sample env file for running a model through HuggingFace's Inference API | ||
HUGGINGFACE_API_TOKEN="..." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import json | ||
import os | ||
import time | ||
|
||
from enum import Enum | ||
|
||
import requests | ||
|
||
from llmebench.models.model_base import ModelBase | ||
|
||
class HuggingFaceTaskTypes(Enum):
    """Task types that can be selected for a HuggingFace Inference API model."""

    # Values are the consecutive integers starting at 1, in declaration order.
    Summarization = 1
    Sentence_Similarity = 2
    Text_Generation = 3
    Text2Text_Generation = 4
    Translation = 5
    Feature_Extraction = 6
    Fill_Mask = 7
    Question_Answering = 8
    Table_Question_Answering = 9
    Text_Classification = 10
    Token_Classification = 11
    Named_Entity_Recognition = 12
    Zero_Shot_Classification = 13
    Conversational = 14
|
||
|
||
class HuggingFaceModelLoadingError(Exception):
    """Error carrying the failure message of a HuggingFace model-loading problem.

    Args:
        failure_message: description of the failure; kept on the instance
            and embedded in the string form of the exception.
    """

    def __init__(self, failure_message):
        self.failure_message = failure_message

    def __str__(self):
        template = "HuggingFace model loading -- \n {}"
        return template.format(self.failure_message)
|
||
|
||
class HuggingFaceInferenceAPIModel(ModelBase):
    """An interface to HuggingFace Inference API

    Args:
        task_type: one of Summarization, Sentence_Similarity, Text_Generation, Text2Text_Generation, Translation,
            Feature_Extraction, Fill_Mask, Question_Answering, Table_Question_Answering, Text_Classification,
            Token_Classification, Named_Entity_Recognition, Zero_Shot_Classification, Conversational as found on
            HuggingFace model's page
        inference_api_url: the URL to the particular model, as found in the Deploy > Inference API menu in the model's page
        api_token: HuggingFace API access key (can also be read from environment variable HUGGINGFACE_API_TOKEN)
        request_timeout: keyword-only; seconds to wait for the API on a single request before
            `requests` raises a timeout error. Pass None to wait indefinitely.
        **kwargs: forwarded unchanged to the base class constructor

    Raises:
        Exception: if no non-empty API token is given and none is found in the environment
    """

    def __init__(
        self,
        task_type,
        inference_api_url,
        api_token=None,
        *,
        request_timeout=120,
        **kwargs,
    ):
        self.task_type = task_type
        self.inference_api_url = inference_api_url
        self.api_token = api_token or os.getenv("HUGGINGFACE_API_TOKEN")
        self.request_timeout = request_timeout

        # `not` (rather than `is None`) also rejects an empty-string token,
        # e.g. an environment variable that is set but blank.
        if not self.api_token:
            raise Exception(
                "API token must be provided as model config or environment variable (`HUGGINGFACE_API_TOKEN`)"
            )

        # `requests` signals timeouts with its own exception class, which is
        # not a subclass of the builtin TimeoutError, so both are listed.
        super().__init__(
            retry_exceptions=(
                TimeoutError,
                requests.exceptions.Timeout,
                HuggingFaceModelLoadingError,
            ),
            **kwargs,
        )

    def prompt(self, processed_input):
        """POST the JSON-encoded input to the model URL and return the decoded JSON reply.

        Args:
            processed_input: JSON-serialisable payload for the model

        Returns:
            The response body parsed from JSON.

        Raises:
            HuggingFaceModelLoadingError: if the API answers with status 503 (model loading)
            Exception: for any other unsuccessful status; the message is the HTTP reason
            requests.exceptions.Timeout: if the API does not answer within `request_timeout` seconds
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = json.dumps(processed_input)
        # Without a timeout a stalled connection would block forever.
        response = requests.request(
            "POST",
            self.inference_api_url,
            headers=headers,
            data=data,
            timeout=self.request_timeout,
        )
        if not response.ok:
            if response.status_code == 503:  # model loading
                raise HuggingFaceModelLoadingError(response.reason)
            raise Exception(response.reason)
        return response.json()

    def summarize_response(self, response):
        """
        This method will attempt to interpret the output based on the task type. Otherwise, it returns the response object as is.

        Args:
            response: the decoded reply returned by `prompt`

        Returns:
            A string summary when the response matches the layout expected for
            the task type, otherwise the response (possibly unwrapped from a
            single-element list).
        """
        # Task types whose reply is not summarised through dictionary keys.
        output_types = {
            HuggingFaceTaskTypes.Summarization: str,
            HuggingFaceTaskTypes.Sentence_Similarity: list,
            HuggingFaceTaskTypes.Text_Generation: str,
            HuggingFaceTaskTypes.Text2Text_Generation: str,
            HuggingFaceTaskTypes.Feature_Extraction: list,
        }
        # For dictionary replies: the keys whose values make up the summary.
        output_dict_summary_keys = {
            HuggingFaceTaskTypes.Fill_Mask: ["token_str"],
            HuggingFaceTaskTypes.Question_Answering: ["answer"],
            HuggingFaceTaskTypes.Table_Question_Answering: ["answer"],
            HuggingFaceTaskTypes.Text_Classification: ["label", "score"],
            HuggingFaceTaskTypes.Token_Classification: ["entity_group", "word"],
            HuggingFaceTaskTypes.Named_Entity_Recognition: ["entity_group", "word"],
            HuggingFaceTaskTypes.Zero_Shot_Classification: ["scores"],
            HuggingFaceTaskTypes.Conversational: ["generated_text"],
            HuggingFaceTaskTypes.Translation: ["translation_text"],
        }

        # Any task type not listed in output_types is treated as a dictionary reply.
        output_type = output_types.get(self.task_type, dict)

        try:
            if output_type is list:
                return ", ".join(str(s) for s in response)

            # Unwrap a single-element list so the element is summarised directly.
            if isinstance(response, list) and len(response) == 1:
                response = response[0]

            if output_type is not dict:
                return response

            keys = output_dict_summary_keys[self.task_type]
            if isinstance(response, list):  # list of dictionaries
                return ", ".join(
                    ":".join(str(d[key]) for key in keys) for d in response
                )
            return ":".join(str(response[key]) for key in keys)
        except Exception:
            # Deliberate best effort: a reply that does not have the expected
            # layout is returned unchanged instead of failing the summary.
            return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .HuggingFaceInferenceAPI import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes | ||
from .OpenAI import LegacyOpenAIModel, OpenAIModel | ||
from .Petals import PetalsModel |
Oops, something went wrong.