diff --git a/assets/ar/MT/AraBench_Ara2Eng_Helsinki_NLP_Opus_MT_ZeroShot.py b/assets/ar/MT/AraBench_ar2en_Helsinki_NLP_Opus_MT_ZeroShot.py similarity index 100% rename from assets/ar/MT/AraBench_Ara2Eng_Helsinki_NLP_Opus_MT_ZeroShot.py rename to assets/ar/MT/AraBench_ar2en_Helsinki_NLP_Opus_MT_ZeroShot.py diff --git a/assets/ar/MT/AraBench_ar2en_Jais_ZeroShot.py b/assets/ar/MT/AraBench_ar2en_Jais_ZeroShot.py new file mode 100644 index 00000000..9cb974fd --- /dev/null +++ b/assets/ar/MT/AraBench_ar2en_Jais_ZeroShot.py @@ -0,0 +1,46 @@ +from llmebench.datasets import AraBenchDataset +from llmebench.models import FastChatModel +from llmebench.tasks import MachineTranslationTask + + +def metadata(): + return { + "author": "Arabic Language Technologies, QCRI, HBKU", + "model": "jais-13b-chat", + "description": "Locally hosted Jais Chat 13b model using FastChat.", + } + + +def config(): + return { + "dataset": AraBenchDataset, + "dataset_args": { + "src_lang": "ar", + "tgt_lang": "en", + }, + "task": MachineTranslationTask, + "model": FastChatModel, + "model_args": { + "max_tries": 3, + }, + } + + +def prompt(input_sample): + return [ + { + "role": "system", + "content": "You are an expert translator specialized in translating texts from Arabic to English. You are concise as you only output the translation of the text without any illustrations or extra details", + }, + { + "role": "user", + "content": f"Translate the following text to English.\nText: {input_sample}\nTranslation: ", + }, + ] + + +def post_process(response): + response = response["choices"][0]["message"]["content"] + response = response.replace('"', "") + response = response.strip() + return response diff --git a/envs/fastchat.env b/envs/fastchat.env new file mode 100644 index 00000000..6fe9ec30 --- /dev/null +++ b/envs/fastchat.env @@ -0,0 +1,4 @@ +# Sample env file for using a model hosted using FastChat +FASTCHAT_MODEL="..." +FASTCHAT_API_BASE="..." +FASTCHAT_API_KEY="..." 
\ No newline at end of file diff --git a/llmebench/models/FastChat.py b/llmebench/models/FastChat.py new file mode 100644 index 00000000..ebd75b3b --- /dev/null +++ b/llmebench/models/FastChat.py @@ -0,0 +1,50 @@ +import os + +from llmebench.models.OpenAI import OpenAIModel + + +class FastChatModel(OpenAIModel): + """ + FastChat Model interface. Can be used for models hosted using FastChat + https://github.com/lm-sys/FastChat + + Accepts all arguments used by `OpenAIModel`, and overrides the arguments listed + below with FastChat-specific variables. + + See the [model_support](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md) + page in FastChat's documentation for supported models and instructions on extending + to custom models. + + Arguments + --------- + api_base : str + URL where the model is hosted. If not provided, the implementation will look at + environment variable `FASTCHAT_API_BASE` + api_key : str + Authentication token for the API. If not provided, the implementation will derive it + from environment variable `FASTCHAT_API_KEY` + model_name : str + Name of the model to use. 
If not provided, the implementation will derive it from + environment variable `FASTCHAT_MODEL` + """ + + def __init__(self, api_base=None, api_key=None, model_name=None, **kwargs): + api_base = api_base or os.getenv("FASTCHAT_API_BASE") + api_key = api_key or os.getenv("FASTCHAT_API_KEY") + model_name = model_name or os.getenv("FASTCHAT_MODEL") + if api_base is None: + raise Exception( + "API url must be provided as model config or environment variable (`FASTCHAT_API_BASE`)" + ) + if api_key is None: + raise Exception( + "API key must be provided as model config or environment variable (`FASTCHAT_API_KEY`)" + ) + if model_name is None: + raise Exception( + "Model name must be provided as model config or environment variable (`FASTCHAT_MODEL`)" + ) + # checks for valid config settings + super(FastChatModel, self).__init__( + api_base=api_base, api_key=api_key, model_name=model_name, **kwargs + ) diff --git a/llmebench/models/__init__.py b/llmebench/models/__init__.py index facffec8..c5a10bb5 100644 --- a/llmebench/models/__init__.py +++ b/llmebench/models/__init__.py @@ -1,3 +1,4 @@ +from .FastChat import FastChatModel from .HuggingFaceInferenceAPI import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes from .OpenAI import LegacyOpenAIModel, OpenAIModel from .Petals import PetalsModel diff --git a/tests/models/test_FastChatModel.py b/tests/models/test_FastChatModel.py new file mode 100644 index 00000000..bde0f2fe --- /dev/null +++ b/tests/models/test_FastChatModel.py @@ -0,0 +1,76 @@ +import unittest +from unittest.mock import patch + +import openai + +from llmebench import Benchmark +from llmebench.models import FastChatModel + +from tests.models.test_OpenAIModel import TestAssetsForOpenAIPrompts + + +class TestAssetsForFastChatPrompts(TestAssetsForOpenAIPrompts): + @classmethod + def setUpClass(cls): + # Load the benchmark assets + benchmark = Benchmark(benchmark_dir="assets") + all_assets = benchmark.find_assets() + + # Keep only the assets that use the FastChat model
+ cls.assets = [ + asset for asset in all_assets if asset["config"]["model"] in [FastChatModel] + ] + + def test_fastchat_prompts(self): + "Test if all assets using this model return data in an appropriate format for prompting" + + self.test_openai_prompts() + + +class TestFastChatConfig(unittest.TestCase): + def test_fastchat_config(self): + "Test if model config parameters passed as arguments are used" + model = FastChatModel( + api_base="llmebench.qcri.org", + api_key="secret-key", + model_name="private-model", + ) + + self.assertEqual(openai.api_type, "openai") + self.assertEqual(openai.api_base, "llmebench.qcri.org") + self.assertEqual(openai.api_key, "secret-key") + self.assertEqual(model.model_params["model"], "private-model") + + @patch.dict( + "os.environ", + { + "FASTCHAT_API_BASE": "llmebench.qcri.org", + "FASTCHAT_API_KEY": "secret-key", + "FASTCHAT_MODEL": "private-model", + }, + ) + def test_fastchat_config_env_var(self): + "Test if model config parameters passed as environment variables are used" + model = FastChatModel() + + self.assertEqual(openai.api_type, "openai") + self.assertEqual(openai.api_base, "llmebench.qcri.org") + self.assertEqual(openai.api_key, "secret-key") + self.assertEqual(model.model_params["model"], "private-model") + + @patch.dict( + "os.environ", + { + "FASTCHAT_API_BASE": "llmebench.qcri.org", + "FASTCHAT_API_KEY": "secret-key", + "FASTCHAT_MODEL": "private-model", + }, + ) + def test_fastchat_config_priority(self): + "Test if model config parameters override environment variables" + model = FastChatModel(model_name="another-model") + + self.assertEqual(openai.api_type, "openai") + self.assertEqual(openai.api_base, "llmebench.qcri.org") + self.assertEqual(openai.api_key, "secret-key") + self.assertEqual(model.model_params["model"], "another-model") diff --git a/tests/models/test_OpenAIModel.py b/tests/models/test_OpenAIModel.py index 4ba39cb1..e65d35d4 100644 --- a/tests/models/test_OpenAIModel.py +++ 
b/tests/models/test_OpenAIModel.py @@ -84,6 +84,9 @@ def test_openai_config_azure(self): "AZURE_API_URL": "url", "AZURE_API_KEY": "secret-key", "AZURE_ENGINE_NAME": "private-model", + "OPENAI_API_BASE": "", + "OPENAI_API_KEY": "", + "OPENAI_MODEL": "", }, ) def test_openai_config_env_var_azure(self):