Add tests for HuggingFaceInferenceAPI models
fdalvi committed Sep 10, 2023
1 parent 76b1ba3 commit 33a3bf4
Showing 3 changed files with 92 additions and 0 deletions.
88 changes: 88 additions & 0 deletions tests/models/test_HuggingFaceInferenceAPI.py
@@ -0,0 +1,88 @@
import unittest
from unittest.mock import patch

from llmebench import Benchmark
from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes


class TestAssetsForHuggingFaceInferenceAPIPrompts(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Load the benchmark assets
        benchmark = Benchmark(benchmark_dir="assets")
        all_assets = benchmark.find_assets()

        # Filter out assets not using the HuggingFaceInferenceAPI model
        cls.assets = [
            asset
            for asset in all_assets
            if asset["config"]["model"] in [HuggingFaceInferenceAPIModel]
        ]

    def test_huggingface_inference_api_prompts(self):
        "Test if all assets using this model return data in an appropriate format for prompting"

        n_shots = 3  # Sample for few shot prompts

        for asset in self.assets:
            with self.subTest(msg=asset["name"]):
                config = asset["config"]
                dataset = config["dataset"](**config["dataset_args"])
                data_sample = dataset.get_data_sample()
                if "fewshot" in config["general_args"]:
                    prompt = asset["module"].prompt(
                        data_sample["input"],
                        [data_sample for _ in range(n_shots)],
                    )
                else:
                    prompt = asset["module"].prompt(data_sample["input"])

                self.assertIsInstance(prompt, dict)
                self.assertIn("inputs", prompt)

    def test_asset_config(self):
        "Test if all assets using this model specify the required model arguments"

        n_shots = 3  # Sample for few shot prompts

        for asset in self.assets:
            with self.subTest(msg=asset["name"]):
                config = asset["config"]
                model_args = config["model_args"]

                self.assertIsInstance(model_args, dict)
                self.assertIn("task_type", model_args)
                self.assertIsInstance(model_args["task_type"], HuggingFaceTaskTypes)
                self.assertIn("inference_api_url", model_args)


class TestHuggingFaceInferenceAPIConfig(unittest.TestCase):
    def test_huggingface_inference_api_config(self):
        "Test if model config parameters passed as arguments are used"
        model = HuggingFaceInferenceAPIModel("task", "url", api_token="secret-token")

        self.assertEqual(model.api_token, "secret-token")

    @patch.dict(
        "os.environ",
        {
            "HUGGINGFACE_API_TOKEN": "secret-token",
        },
    )
    def test_huggingface_inference_api_config_env_var(self):
        "Test if model config parameters passed as environment variables are used"
        model = HuggingFaceInferenceAPIModel("task", "url")

        self.assertEqual(model.api_token, "secret-token")

    @patch.dict(
        "os.environ",
        {
            "HUGGINGFACE_API_TOKEN": "secret-token",
        },
    )
    def test_huggingface_inference_api_config_priority(self):
        "Test that model config parameters passed as arguments take priority over environment variables"
        model = HuggingFaceInferenceAPIModel("task", "url", api_token="secret-token-2")

        self.assertEqual(model.api_token, "secret-token-2")
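The config tests above pin down how the model is expected to resolve its API token: an explicitly passed api_token wins, and the HUGGINGFACE_API_TOKEN environment variable is the fallback. Below is a minimal sketch of that resolution logic, assuming the (task_type, inference_api_url) positional signature implied by the tests; the class name and structure here are illustrative, not the library's actual implementation.

```python
import os


class HuggingFaceInferenceAPIModelSketch:
    """Illustrative stand-in for the real model class, not the library's code."""

    def __init__(self, task_type, inference_api_url, api_token=None):
        self.task_type = task_type
        self.inference_api_url = inference_api_url
        # An explicitly passed token takes priority; otherwise fall back to
        # the HUGGINGFACE_API_TOKEN environment variable, as the tests expect.
        self.api_token = (
            api_token
            if api_token is not None
            else os.environ.get("HUGGINGFACE_API_TOKEN")
        )
```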
2 changes: 2 additions & 0 deletions tests/models/test_OpenAIModel.py
@@ -46,6 +46,8 @@ def test_openai_prompts(self):
self.assertIn("content", message)
self.assertIsInstance(message["content"], str)


class TestOpenAIConfig(unittest.TestCase):
    def test_openai_config(self):
        "Test if model config parameters passed as arguments are used"
        model = OpenAIModel(
2 changes: 2 additions & 0 deletions tests/models/test_Petals.py
@@ -39,6 +39,8 @@ def test_petals_prompts(self):
self.assertIn("prompt", prompt)
self.assertIsInstance(prompt["prompt"], str)


class TestPetalsConfig(unittest.TestCase):
    def test_petals_config(self):
        "Test if model config parameters passed as arguments are used"
        model = PetalsModel(api_url="petals.llmebench.org")
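For reference, a short sketch of one way to run just these model tests locally with standard unittest discovery; this assumes the repository root as the working directory and may differ from the invocation used in the project's CI.

```python
import unittest

# Discover and run the model tests under tests/models/
# (assumes the repository root is the current working directory).
suite = unittest.defaultTestLoader.discover("tests/models", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```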
