Add an implementation of HF model and an example sentiment analysis a… #194

Merged
merged 16 commits into from Sep 10, 2023
81 changes: 81 additions & 0 deletions assets/ar/MT/AraBench_Ara2Eng_Helsinki_NLP_Opus_MT_ZeroShot.py
@@ -0,0 +1,81 @@
import os

from llmebench.datasets import AraBenchDataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import MachineTranslationTask


def config():
    sets = [
        "bible.test.mgr.0.ma",
        "bible.test.mgr.0.tn",
        "bible.test.msa.0.ms",
        "bible.test.msa.1.ms",
        "ldc_web_eg.test.lev.0.jo",
        "ldc_web_eg.test.lev.0.ps",
        "ldc_web_eg.test.lev.0.sy",
        "ldc_web_eg.test.mgr.0.tn",
        "ldc_web_eg.test.msa.0.ms",
        "ldc_web_eg.test.nil.0.eg",
        "ldc_web_lv.test.lev.0.lv",
        "madar.test.glf.0.iq",
        "madar.test.glf.0.om",
        "madar.test.glf.0.qa",
        "madar.test.glf.0.sa",
        "madar.test.glf.0.ye",
        "madar.test.glf.1.iq",
        "madar.test.glf.1.sa",
        "madar.test.glf.2.iq",
        "madar.test.lev.0.jo",
        "madar.test.lev.0.lb",
        "madar.test.lev.0.pa",
        "madar.test.lev.0.sy",
        "madar.test.lev.1.jo",
        "madar.test.lev.1.sy",
        "madar.test.mgr.0.dz",
        "madar.test.mgr.0.ly",
        "madar.test.mgr.0.ma",
        "madar.test.mgr.0.tn",
        "madar.test.mgr.1.ly",
        "madar.test.mgr.1.ma",
        "madar.test.mgr.1.tn",
        "madar.test.msa.0.ms",
        "madar.test.nil.0.eg",
        "madar.test.nil.0.sd",
        "madar.test.nil.1.eg",
        "madar.test.nil.2.eg",
    ]

    configs = []
    for testset in sets:
        configs.append(
            {
                "name": testset,
                "config": {
                    "dataset": AraBenchDataset,
                    "dataset_args": {
                        "src": f"{testset}.ar",
                        "tgt": f"{testset}.en",
                    },
                    "task": MachineTranslationTask,
                    "task_args": {},
                    "model": HuggingFaceInferenceAPIModel,
                    "model_args": {
                        "task_type": HuggingFaceTaskTypes.Translation,
                        "inference_api_url": "https://api-inference.huggingface.co/models/Helsinki-NLP/opus-mt-ar-en",
                        "max_tries": 5,
                    },
                    "general_args": {"data_path": "data/MT/"},
                },
            }
        )

    return configs


def prompt(input_sample):
    return {"inputs": input_sample}


def post_process(response):
    return response[0]["translation_text"]
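For reference, the request/response cycle this asset drives through HuggingFaceInferenceAPIModel can be reproduced with a direct call. A minimal sketch; the Arabic input sentence is a made-up placeholder, and the token is read from the environment as in the model class below:

import os

import requests

url = "https://api-inference.huggingface.co/models/Helsinki-NLP/opus-mt-ar-en"
headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACE_API_TOKEN']}"}

# prompt() wraps the raw source sentence as {"inputs": ...}
payload = {"inputs": "مرحبا بالعالم"}
response = requests.post(url, headers=headers, json=payload).json()

# The API returns a list of hypotheses; post_process() keeps the first one.
print(response[0]["translation_text"])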
36 changes: 36 additions & 0 deletions assets/ar/QA/MLQA_mdeberta_v3_base_squad2_ZeroShot.py
@@ -0,0 +1,36 @@
import os

from llmebench.datasets import MLQADataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import QATask


def config():
    return {
        "dataset": MLQADataset,
        "dataset_args": {},
        "task": QATask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": {
            "task_type": HuggingFaceTaskTypes.Question_Answering,
            "inference_api_url": "https://api-inference.huggingface.co/models/timpal0l/mdeberta-v3-base-squad2",
            "max_tries": 5,
        },
        "general_args": {
            "data_path": "data/QA/MLQA/test/test-context-ar-question-ar.json"
        },
    }


def prompt(input_sample):
    return {
        "inputs": {
            "context": input_sample["context"],
            "question": input_sample["question"],
        }
    }


def post_process(response):
    return response["answer"].strip()
@@ -0,0 +1,41 @@
import os
import re

from llmebench.datasets import Q2QSimDataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import Q2QSimDetectionTask


def config():
    return {
        "dataset": Q2QSimDataset,
        "dataset_args": {},
        "task": Q2QSimDetectionTask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": {
            "task_type": HuggingFaceTaskTypes.Sentence_Similarity,
            "inference_api_url": "https://api-inference.huggingface.co/models/intfloat/multilingual-e5-small",
            "max_tries": 5,
        },
        "general_args": {
            "data_path": "data/STS/nsurl-2019-task8/test.tsv",
        },
    }


def prompt(input_sample):
    q1, q2 = input_sample.split("\t")

    return {"inputs": {"source_sentence": q1, "sentences": [q2]}}


def post_process(response):
    if response[0] > 0.7:
        pred_label = "1"
    elif response[0] < 0.3:
        pred_label = "0"
    else:
        pred_label = None

    return pred_label
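The sentence-similarity endpoint scores each entry of "sentences" against "source_sentence" and returns one score per candidate; post_process() maps that single score to a label, abstaining (None) in the uncertain 0.3-0.7 band. A minimal sketch, with placeholder questions standing in for a real test pair:

import os

import requests

url = "https://api-inference.huggingface.co/models/intfloat/multilingual-e5-small"
headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACE_API_TOKEN']}"}

# Same shape as prompt() builds from the tab-separated question pair.
payload = {"inputs": {"source_sentence": "q1 ...", "sentences": ["q2 ..."]}}
scores = requests.post(url, headers=headers, json=payload).json()

# e.g. [0.83] -> "1"; [0.12] -> "0"; [0.55] -> None (abstain)
print(scores[0])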
@@ -0,0 +1,33 @@
import os

from llmebench.datasets import ArSASDataset
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from llmebench.tasks import SentimentTask


def config():
    return {
        "dataset": ArSASDataset,
        "dataset_args": {},
        "task": SentimentTask,
        "task_args": {},
        "model": HuggingFaceInferenceAPIModel,
        "model_args": {
            "task_type": HuggingFaceTaskTypes.Text_Classification,
            "inference_api_url": "https://api-inference.huggingface.co/models/CAMeL-Lab/bert-base-arabic-camelbert-da-sentiment",
            "max_tries": 5,
        },
        "general_args": {
            "data_path": "data/sentiment_emotion_others/sentiment/ArSAS-test.txt"
        },
    }


def prompt(input_sample):
    return {"inputs": input_sample}


def post_process(response):
    scores = [(c["label"], c["score"]) for c in response[0]]
    label = sorted(scores, key=lambda x: x[1])[-1][0]
    return label[0].upper() + label[1:].lower()
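Text classification returns a list of label/score dicts per input; post_process() takes the highest-scoring label and normalizes its casing to match the ArSAS gold labels. A sketch with a made-up response object (the label names and scores are illustrative, not the model's actual label set):

# Hypothetical response in the shape post_process() consumes:
response = [
    [
        {"label": "positive", "score": 0.91},
        {"label": "negative", "score": 0.05},
        {"label": "neutral", "score": 0.04},
    ]
]

scores = [(c["label"], c["score"]) for c in response[0]]
label = sorted(scores, key=lambda x: x[1])[-1][0]  # argmax over scores
print(label[0].upper() + label[1:].lower())  # -> "Positive"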
2 changes: 2 additions & 0 deletions envs/huggingface.env
@@ -0,0 +1,2 @@
# Sample env file for using a model via HuggingFace's Inference API
HUGGINGFACE_API_TOKEN="..."
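Besides this env file, the token can be supplied per asset, assuming the framework forwards model_args to the model constructor as the asset configs above suggest (api_token is the constructor parameter defined in HuggingFaceInferenceAPI.py below; the value stays elided):

from llmebench.models import HuggingFaceTaskTypes

model_args = {
    "task_type": HuggingFaceTaskTypes.Translation,
    "inference_api_url": "https://api-inference.huggingface.co/models/Helsinki-NLP/opus-mt-ar-en",
    "api_token": "...",  # optional; falls back to HUGGINGFACE_API_TOKEN
    "max_tries": 5,
}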
120 changes: 120 additions & 0 deletions llmebench/models/HuggingFaceInferenceAPI.py
@@ -0,0 +1,120 @@
import json
import os
import time
from enum import Enum

import requests

from llmebench.models.model_base import ModelBase

HuggingFaceTaskTypes = Enum(
    "HuggingFaceTaskTypes",
    [
        "Summarization",
        "Sentence_Similarity",
        "Text_Generation",
        "Text2Text_Generation",
        "Translation",
        "Feature_Extraction",
        "Fill_Mask",
        "Question_Answering",
        "Table_Question_Answering",
        "Text_Classification",
        "Token_Classification",
        "Named_Entity_Recognition",
        "Zero_Shot_Classification",
        "Conversational",
    ],
)


class HuggingFaceModelLoadingError(Exception):
    def __init__(self, failure_message):
        self.failure_message = failure_message

    def __str__(self):
        return f"HuggingFace model loading -- \n {self.failure_message}"


class HuggingFaceInferenceAPIModel(ModelBase):
    """An interface to the HuggingFace Inference API

    Args:
        task_type: one of Summarization, Sentence_Similarity, Text_Generation, Text2Text_Generation, Translation,
            Feature_Extraction, Fill_Mask, Question_Answering, Table_Question_Answering, Text_Classification,
            Token_Classification, Named_Entity_Recognition, Zero_Shot_Classification, Conversational, as found on
            the HuggingFace model's page
        inference_api_url: the URL of the particular model, as found in the Deploy > Inference API menu on the model's page
        api_token: HuggingFace API access key (can also be read from the environment variable HUGGINGFACE_API_TOKEN)
    """

    def __init__(self, task_type, inference_api_url, api_token=None, **kwargs):
        self.task_type = task_type
        self.inference_api_url = inference_api_url
        self.api_token = api_token or os.getenv("HUGGINGFACE_API_TOKEN")

        if self.api_token is None:
            raise Exception(
                "API token must be provided as model config or environment variable (`HUGGINGFACE_API_TOKEN`)"
            )

        # 503 responses ("model loading") surface as HuggingFaceModelLoadingError,
        # which the base class treats as retryable.
        super(HuggingFaceInferenceAPIModel, self).__init__(
            retry_exceptions=(TimeoutError, HuggingFaceModelLoadingError), **kwargs
        )

    def prompt(self, processed_input):
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = json.dumps(processed_input)
        response = requests.request(
            "POST", self.inference_api_url, headers=headers, data=data
        )
        if not response.ok:
            if response.status_code == 503:  # model loading
                raise HuggingFaceModelLoadingError(response.reason)
            else:
                raise Exception(response.reason)
        return response.json()

    def summarize_response(self, response):
        """Attempt to interpret the output based on the task type; if that
        fails, return the response object as-is."""
        # Task types whose raw output is already a string or a flat list
        output_types = {
            HuggingFaceTaskTypes.Summarization: str,
            HuggingFaceTaskTypes.Sentence_Similarity: list,
            HuggingFaceTaskTypes.Text_Generation: str,
            HuggingFaceTaskTypes.Text2Text_Generation: str,
            HuggingFaceTaskTypes.Feature_Extraction: list,
        }
        # For dict-shaped outputs, the keys worth keeping in the summary
        output_dict_summary_keys = {
            HuggingFaceTaskTypes.Fill_Mask: ["token_str"],
            HuggingFaceTaskTypes.Question_Answering: ["answer"],
            HuggingFaceTaskTypes.Table_Question_Answering: ["answer"],
            HuggingFaceTaskTypes.Text_Classification: ["label", "score"],
            HuggingFaceTaskTypes.Token_Classification: ["entity_group", "word"],
            HuggingFaceTaskTypes.Named_Entity_Recognition: ["entity_group", "word"],
            HuggingFaceTaskTypes.Zero_Shot_Classification: ["scores"],
            HuggingFaceTaskTypes.Conversational: ["generated_text"],
            HuggingFaceTaskTypes.Translation: ["translation_text"],
        }

        output_type = output_types.get(self.task_type, dict)

        try:
            if output_type == list:
                return ", ".join([str(s) for s in response])

            if isinstance(response, list) and len(response) == 1:
                response = response[0]

            if output_type == dict:
                keys = output_dict_summary_keys[self.task_type]
                if isinstance(response, list):  # list of dictionaries
                    return ", ".join(
                        [":".join([str(d[key]) for key in keys]) for d in response]
                    )
                else:
                    return ":".join([str(response[key]) for key in keys])
            else:
                return response
        except Exception:
            return response
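A minimal end-to-end sketch of the class above, outside the benchmark runner. The payload shape is the translation one used by the AraBench asset; how max_tries is consumed is assumed from the retry_exceptions wiring in the constructor:

import os

from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes

os.environ.setdefault("HUGGINGFACE_API_TOKEN", "...")  # or pass api_token=...

model = HuggingFaceInferenceAPIModel(
    task_type=HuggingFaceTaskTypes.Translation,
    inference_api_url="https://api-inference.huggingface.co/models/Helsinki-NLP/opus-mt-ar-en",
    max_tries=5,
)

# prompt() POSTs the payload and raises HuggingFaceModelLoadingError on 503,
# which the base class is configured to treat as retryable.
response = model.prompt({"inputs": "مرحبا"})

# For Translation, summarize_response() joins the "translation_text" fields.
print(model.summarize_response(response))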
1 change: 1 addition & 0 deletions llmebench/models/__init__.py
@@ -1,2 +1,3 @@
+from .HuggingFaceInferenceAPI import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
 from .OpenAI import LegacyOpenAIModel, OpenAIModel
 from .Petals import PetalsModel
10 changes: 9 additions & 1 deletion tests/models/test_implementation.py
@@ -5,6 +5,8 @@

 import llmebench.models as models

+from llmebench.models.model_base import ModelBase
+
 from tests.utils import base_class_constructor_checker


@@ -13,7 +15,13 @@ class TestModelImplementation(unittest.TestCase):
     def setUpClass(cls):
         # Search for all implemented models
         framework_dir = Path("llmebench")
-        cls.models = set([m[1] for m in inspect.getmembers(models, inspect.isclass)])
+        cls.models = set(
+            [
+                m[1]
+                for m in inspect.getmembers(models, inspect.isclass)
+                if issubclass(m[1], ModelBase)
+            ]
+        )

     def test_base_constructor(self):
         "Test if all models also call the base class constructor"