Add HuggingFaceInferenceAPI model and associated assets (#194)
* Add an implementation of HF model and an example sentiment analysis asset using it

* Address review suggestions

* Fix formatting

* Fix more formatting

* Format summary based on task type

* Delete NER task

The returned format does not include the original text, and the dataset ground truth consists of labeled sentences, which cannot be recovered from the model output alone.

* Rename models and other touches

* Add docstring to HuggingFaceInferenceAPI

* Fix ArSAS asset

* Modify model tests to use only ModelBase derived classes

* Add explicit check for api token and env var based config

* Removed hardcoded env var from assets

* Fix missing/spurious imports

* Add tests for HuggingFaceInferenceAPI models

* Remove dead code

---------

Co-authored-by: mhawasly <[email protected]>
Co-authored-by: Fahim Imaduddin Dalvi <[email protected]>
3 people authored Sep 10, 2023

1 parent 421839f commit 07f4bf6
Showing 11 changed files with 406 additions and 1 deletion.
79 changes: 79 additions & 0 deletions assets/ar/MT/AraBench_Ara2Eng_Helsinki_NLP_Opus_MT_ZeroShot.py
@@ -0,0 +1,79 @@
from llmebench.datasets import AraBenchDataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import MachineTranslationTask


def config():
    sets = [
        "bible.test.mgr.0.ma",
        "bible.test.mgr.0.tn",
        "bible.test.msa.0.ms",
        "bible.test.msa.1.ms",
        "ldc_web_eg.test.lev.0.jo",
        "ldc_web_eg.test.lev.0.ps",
        "ldc_web_eg.test.lev.0.sy",
        "ldc_web_eg.test.mgr.0.tn",
        "ldc_web_eg.test.msa.0.ms",
        "ldc_web_eg.test.nil.0.eg",
        "ldc_web_lv.test.lev.0.lv",
        "madar.test.glf.0.iq",
        "madar.test.glf.0.om",
        "madar.test.glf.0.qa",
        "madar.test.glf.0.sa",
        "madar.test.glf.0.ye",
        "madar.test.glf.1.iq",
        "madar.test.glf.1.sa",
        "madar.test.glf.2.iq",
        "madar.test.lev.0.jo",
        "madar.test.lev.0.lb",
        "madar.test.lev.0.pa",
        "madar.test.lev.0.sy",
        "madar.test.lev.1.jo",
        "madar.test.lev.1.sy",
        "madar.test.mgr.0.dz",
        "madar.test.mgr.0.ly",
        "madar.test.mgr.0.ma",
        "madar.test.mgr.0.tn",
        "madar.test.mgr.1.ly",
        "madar.test.mgr.1.ma",
        "madar.test.mgr.1.tn",
        "madar.test.msa.0.ms",
        "madar.test.nil.0.eg",
        "madar.test.nil.0.sd",
        "madar.test.nil.1.eg",
        "madar.test.nil.2.eg",
    ]

    configs = []
    for testset in sets:
        configs.append(
            {
                "name": testset,
                "config": {
                    "dataset": AraBenchDataset,
                    "dataset_args": {
                        "src": f"{testset}.ar",
                        "tgt": f"{testset}.en",
                    },
                    "task": MachineTranslationTask,
                    "task_args": {},
                    "model": HuggingFaceInferenceAPIModel,
                    "model_args": {
                        "task_type": HuggingFaceTaskTypes.Translation,
                        "inference_api_url": "https://api-inference.huggingface.co/models/Helsinki-NLP/opus-mt-ar-en",
                        "max_tries": 5,
                    },
                    "general_args": {"data_path": "data/MT/"},
                },
            }
        )

    return configs


def prompt(input_sample):
    return {"inputs": input_sample}


def post_process(response):
    return response[0]["translation_text"]
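For reference, the Inference API's translation endpoint is assumed to return a single-element list of dicts keyed by "translation_text", which is what post_process above unpacks. A minimal sketch under that assumption:

# Illustrative only: assumed response shape for HuggingFaceTaskTypes.Translation.
sample_response = [{"translation_text": "This is the translated sentence."}]

assert post_process(sample_response) == "This is the translated sentence."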
34 changes: 34 additions & 0 deletions assets/ar/QA/MLQA_mdeberta_v3_base_squad2_ZeroShot.py
@@ -0,0 +1,34 @@
from llmebench.datasets import MLQADataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import QATask


def config():
    return {
        "dataset": MLQADataset,
        "dataset_args": {},
        "task": QATask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": {
            "task_type": HuggingFaceTaskTypes.Question_Answering,
            "inference_api_url": "https://api-inference.huggingface.co/models/timpal0l/mdeberta-v3-base-squad2",
            "max_tries": 5,
        },
        "general_args": {
            "data_path": "data/QA/MLQA/test/test-context-ar-question-ar.json"
        },
    }


def prompt(input_sample):
    return {
        "inputs": {
            "context": input_sample["context"],
            "question": input_sample["question"],
        }
    }


def post_process(response):
    return response["answer"].strip()
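The question-answering endpoint is assumed to return a single dict with an "answer" field (alongside score/start/end), so post_process reads it directly rather than indexing into a list. A minimal sketch under that assumption:

# Illustrative only: assumed response shape for HuggingFaceTaskTypes.Question_Answering.
sample_response = {"score": 0.71, "start": 10, "end": 17, "answer": " القاهرة "}

assert post_process(sample_response) == "القاهرة"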
@@ -0,0 +1,38 @@
from llmebench.datasets import Q2QSimDataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import Q2QSimDetectionTask


def config():
    return {
        "dataset": Q2QSimDataset,
        "dataset_args": {},
        "task": Q2QSimDetectionTask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": {
            "task_type": HuggingFaceTaskTypes.Sentence_Similarity,
            "inference_api_url": "https://api-inference.huggingface.co/models/intfloat/multilingual-e5-small",
            "max_tries": 5,
        },
        "general_args": {
            "data_path": "data/STS/nsurl-2019-task8/test.tsv",
        },
    }


def prompt(input_sample):
    q1, q2 = input_sample.split("\t")

    return {"inputs": {"source_sentence": q1, "sentences": [q2]}}


def post_process(response):
    if response[0] > 0.7:
        pred_label = "1"
    elif response[0] < 0.3:
        pred_label = "0"
    else:
        pred_label = None

    return pred_label
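The sentence-similarity endpoint is assumed to return one similarity score per candidate in "sentences"; since the prompt sends a single candidate, post_process thresholds response[0], mapping scores above 0.7 to "1", scores below 0.3 to "0", and leaving the ambiguous middle band unlabeled. A minimal sketch under that assumption:

# Illustrative only: assumed single-score responses and the labels assigned.
assert post_process([0.91]) == "1"   # clearly similar
assert post_process([0.12]) == "0"   # clearly dissimilar
assert post_process([0.50]) is None  # ambiguous band between 0.3 and 0.7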
@@ -0,0 +1,31 @@
from llmebench.datasets import ArSASDataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import SentimentTask


def config():
    return {
        "dataset": ArSASDataset,
        "dataset_args": {},
        "task": SentimentTask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": {
            "task_type": HuggingFaceTaskTypes.Text_Classification,
            "inference_api_url": "https://api-inference.huggingface.co/models/CAMeL-Lab/bert-base-arabic-camelbert-da-sentiment",
            "max_tries": 5,
        },
        "general_args": {
            "data_path": "data/sentiment_emotion_others/sentiment/ArSAS-test.txt"
        },
    }


def prompt(input_sample):
    return {"inputs": input_sample}


def post_process(response):
    scores = [(c["label"], c["score"]) for c in response[0]]
    label = sorted(scores, key=lambda x: x[1])[-1][0]
    return label[0].upper() + label[1:].lower()
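The text-classification endpoint is assumed to return a nested list of label/score dicts; post_process keeps the highest-scoring label and title-cases it to match the ArSAS label set. A minimal sketch under that assumption:

# Illustrative only: assumed response shape for HuggingFaceTaskTypes.Text_Classification.
sample_response = [
    [
        {"label": "positive", "score": 0.83},
        {"label": "neutral", "score": 0.12},
        {"label": "negative", "score": 0.05},
    ]
]

assert post_process(sample_response) == "Positive"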
2 changes: 2 additions & 0 deletions envs/huggingface.env
@@ -0,0 +1,2 @@
# Sample env file for using a model using HuggingFace's Inference API
HUGGINGFACE_API_TOKEN="..."
122 changes: 122 additions & 0 deletions llmebench/models/HuggingFaceInferenceAPI.py
@@ -0,0 +1,122 @@
import json
import os
import time

from enum import Enum

import requests

from llmebench.models.model_base import ModelBase

HuggingFaceTaskTypes = Enum(
    "HuggingFaceTaskTypes",
    [
        "Summarization",
        "Sentence_Similarity",
        "Text_Generation",
        "Text2Text_Generation",
        "Translation",
        "Feature_Extraction",
        "Fill_Mask",
        "Question_Answering",
        "Table_Question_Answering",
        "Text_Classification",
        "Token_Classification",
        "Named_Entity_Recognition",
        "Zero_Shot_Classification",
        "Conversational",
    ],
)


class HuggingFaceModelLoadingError(Exception):
    def __init__(self, failure_message):
        self.failure_message = failure_message

    def __str__(self):
        return f"HuggingFace model loading -- \n {self.failure_message}"


class HuggingFaceInferenceAPIModel(ModelBase):
    """An interface to the HuggingFace Inference API

    Args:
        task_type: one of Summarization, Sentence_Similarity, Text_Generation, Text2Text_Generation, Translation,
            Feature_Extraction, Fill_Mask, Question_Answering, Table_Question_Answering, Text_Classification,
            Token_Classification, Named_Entity_Recognition, Zero_Shot_Classification or Conversational, as found
            on the HuggingFace model's page
        inference_api_url: the URL of the particular model, as found in the Deploy > Inference API menu on the model's page
        api_token: HuggingFace API access key (can also be read from the environment variable HUGGINGFACE_API_TOKEN)
    """

    def __init__(self, task_type, inference_api_url, api_token=None, **kwargs):
        self.task_type = task_type
        self.inference_api_url = inference_api_url
        self.api_token = api_token or os.getenv("HUGGINGFACE_API_TOKEN")

        if self.api_token is None:
            raise Exception(
                "API token must be provided as model config or environment variable (`HUGGINGFACE_API_TOKEN`)"
            )

        super(HuggingFaceInferenceAPIModel, self).__init__(
            retry_exceptions=(TimeoutError, HuggingFaceModelLoadingError), **kwargs
        )

    def prompt(self, processed_input):
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = json.dumps(processed_input)
        response = requests.request(
            "POST", self.inference_api_url, headers=headers, data=data
        )
        if not response.ok:
            if response.status_code == 503:  # model loading
                raise HuggingFaceModelLoadingError(response.reason)
            else:
                raise Exception(response.reason)
        return response.json()

    def summarize_response(self, response):
        """
        Attempt to interpret the output based on the task type; if this fails,
        the response object is returned as is.
        """
        output_types = {
            HuggingFaceTaskTypes.Summarization: str,
            HuggingFaceTaskTypes.Sentence_Similarity: list,
            HuggingFaceTaskTypes.Text_Generation: str,
            HuggingFaceTaskTypes.Text2Text_Generation: str,
            HuggingFaceTaskTypes.Feature_Extraction: list,
        }
        output_dict_summary_keys = {
            HuggingFaceTaskTypes.Fill_Mask: ["token_str"],
            HuggingFaceTaskTypes.Question_Answering: ["answer"],
            HuggingFaceTaskTypes.Table_Question_Answering: ["answer"],
            HuggingFaceTaskTypes.Text_Classification: ["label", "score"],
            HuggingFaceTaskTypes.Token_Classification: ["entity_group", "word"],
            HuggingFaceTaskTypes.Named_Entity_Recognition: ["entity_group", "word"],
            HuggingFaceTaskTypes.Zero_Shot_Classification: ["scores"],
            HuggingFaceTaskTypes.Conversational: ["generated_text"],
            HuggingFaceTaskTypes.Translation: ["translation_text"],
        }

        output_type = output_types.get(self.task_type, dict)

        try:
            if output_type == list:
                return ", ".join([str(s) for s in response])

            if isinstance(response, list) and len(response) == 1:
                response = response[0]

            if output_type == dict:
                keys = output_dict_summary_keys[self.task_type]
                if isinstance(response, list):  # list of dictionaries
                    return ", ".join(
                        [":".join([str(d[key]) for key in keys]) for d in response]
                    )
                else:
                    return ":".join([str(response[key]) for key in keys])
            else:
                return response
        except Exception:
            return response
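A minimal usage sketch of the class above, outside of an asset config; the model URL is a placeholder and HUGGINGFACE_API_TOKEN is assumed to be set in the environment:

# Illustrative only: hypothetical model URL; token read from HUGGINGFACE_API_TOKEN.
model = HuggingFaceInferenceAPIModel(
    task_type=HuggingFaceTaskTypes.Text_Classification,
    inference_api_url="https://api-inference.huggingface.co/models/<model-id>",
    max_tries=5,
)
response = model.prompt({"inputs": "Some input text"})
print(model.summarize_response(response))  # e.g. "positive:0.83"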
1 change: 1 addition & 0 deletions llmebench/models/__init__.py
@@ -1,2 +1,3 @@
from .HuggingFaceInferenceAPI import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from .OpenAI import LegacyOpenAIModel, OpenAIModel
from .Petals import PetalsModel