From 33a3bf4ee188f7d031fb2a1d747de144becff6b2 Mon Sep 17 00:00:00 2001
From: Fahim Imaduddin Dalvi
Date: Sun, 10 Sep 2023 13:54:43 +0300
Subject: [PATCH] Add tests for HuggingFaceInferenceAPI models

---
 tests/models/test_HuggingFaceInferenceAPI.py | 88 ++++++++++++++++++++
 tests/models/test_OpenAIModel.py             |  2 +
 tests/models/test_Petals.py                  |  2 +
 3 files changed, 92 insertions(+)
 create mode 100644 tests/models/test_HuggingFaceInferenceAPI.py

diff --git a/tests/models/test_HuggingFaceInferenceAPI.py b/tests/models/test_HuggingFaceInferenceAPI.py
new file mode 100644
index 00000000..4dc5cd02
--- /dev/null
+++ b/tests/models/test_HuggingFaceInferenceAPI.py
@@ -0,0 +1,88 @@
+import unittest
+from unittest.mock import patch
+
+from llmebench import Benchmark
+from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
+
+
+class TestAssetsForHuggingFaceInferenceAPIPrompts(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Load the benchmark assets
+        benchmark = Benchmark(benchmark_dir="assets")
+        all_assets = benchmark.find_assets()
+
+        # Filter out assets not using the HuggingFaceInferenceAPI model
+        cls.assets = [
+            asset
+            for asset in all_assets
+            if asset["config"]["model"] in [HuggingFaceInferenceAPIModel]
+        ]
+
+    def test_huggingface_inference_api_prompts(self):
+        "Test if all assets using this model return data in an appropriate format for prompting"
+
+        n_shots = 3  # Sample for few shot prompts
+
+        for asset in self.assets:
+            with self.subTest(msg=asset["name"]):
+                config = asset["config"]
+                dataset = config["dataset"](**config["dataset_args"])
+                data_sample = dataset.get_data_sample()
+                if "fewshot" in config["general_args"]:
+                    prompt = asset["module"].prompt(
+                        data_sample["input"],
+                        [data_sample for _ in range(n_shots)],
+                    )
+                else:
+                    prompt = asset["module"].prompt(data_sample["input"])
+
+                self.assertIsInstance(prompt, dict)
+                self.assertIn("inputs", prompt)
+
+    def test_asset_config(self):
+        "Test if all assets using this model specify the required model arguments"
+
+        n_shots = 3  # Sample for few shot prompts
+
+        for asset in self.assets:
+            with self.subTest(msg=asset["name"]):
+                config = asset["config"]
+                model_args = config["model_args"]
+
+                self.assertIsInstance(model_args, dict)
+                self.assertIn("task_type", model_args)
+                self.assertIsInstance(model_args["task_type"], HuggingFaceTaskTypes)
+                self.assertIn("inference_api_url", model_args)
+
+
+class TestHuggingFaceInferenceAPIConfig(unittest.TestCase):
+    def test_huggingface_inference_api_config(self):
+        "Test if model config parameters passed as arguments are used"
+        model = HuggingFaceInferenceAPIModel("task", "url", api_token="secret-token")
+
+        self.assertEqual(model.api_token, "secret-token")
+
+    @patch.dict(
+        "os.environ",
+        {
+            "HUGGINGFACE_API_TOKEN": "secret-token",
+        },
+    )
+    def test_huggingface_inference_api_config_env_var(self):
+        "Test if model config parameters passed as environment variables are used"
+        model = HuggingFaceInferenceAPIModel("task", "url")
+
+        self.assertEqual(model.api_token, "secret-token")
+
+    @patch.dict(
+        "os.environ",
+        {
+            "HUGGINGFACE_API_TOKEN": "secret-token",
+        },
+    )
+    def test_huggingface_inference_api_config_priority(self):
+        "Test that config parameters passed as arguments take priority over environment variables"
+        model = HuggingFaceInferenceAPIModel("task", "url", api_token="secret-token-2")
+
+        self.assertEqual(model.api_token, "secret-token-2")
diff --git a/tests/models/test_OpenAIModel.py b/tests/models/test_OpenAIModel.py
index d4a7a97e..7f2ee1ba 100644
--- a/tests/models/test_OpenAIModel.py
+++ b/tests/models/test_OpenAIModel.py
@@ -46,6 +46,8 @@ def test_openai_prompts(self):
                     self.assertIn("content", message)
                     self.assertIsInstance(message["content"], str)
 
+
+class TestOpenAIConfig(unittest.TestCase):
     def test_openai_config(self):
         "Test if model config parameters passed as arguments are used"
         model = OpenAIModel(
diff --git a/tests/models/test_Petals.py b/tests/models/test_Petals.py
index 9920b7b1..e661d227 100644
--- a/tests/models/test_Petals.py
+++ b/tests/models/test_Petals.py
@@ -39,6 +39,8 @@ def test_petals_prompts(self):
                 self.assertIn("prompt", prompt)
                 self.assertIsInstance(prompt["prompt"], str)
 
+
+class TestPetalsConfig(unittest.TestCase):
     def test_petals_config(self):
         "Test if model config parameters passed as arguments are used"
         model = PetalsModel(api_url="petals.llmebench.org")
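
Note: a minimal sketch of how the new test module could be run on its own, assuming the llmebench repository root is the working directory and an assets/ directory is available (setUpClass loads it via Benchmark(benchmark_dir="assets")):

import unittest

# Discover and run only the new HuggingFaceInferenceAPI tests; the path and
# filename pattern below assume the repository layout shown in this patch.
suite = unittest.defaultTestLoader.discover(
    "tests/models", pattern="test_HuggingFaceInferenceAPI.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)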
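
For context, test_asset_config only asserts the shape of an asset's model_args. A hypothetical config fragment satisfying those assertions might look like the sketch below; it assumes HuggingFaceTaskTypes is an Enum (consistent with the isinstance() check), and the URL and the way a task type is picked are illustrative assumptions, not taken from this patch:

from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes

# Pick an arbitrary task type; assumes HuggingFaceTaskTypes is an Enum.
some_task_type = list(HuggingFaceTaskTypes)[0]

config = {
    "model": HuggingFaceInferenceAPIModel,
    "model_args": {
        "task_type": some_task_type,
        # Only the presence of this key is asserted; the URL is a made-up example.
        "inference_api_url": "https://api-inference.huggingface.co/models/<model-id>",
    },
}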