Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FastChat api model #198

Merged
merged 18 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions assets/ar/MT/AraBench_ar2en_Jais_ZeroShot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from llmebench.datasets import AraBenchDataset
from llmebench.models import FastChatModel
from llmebench.tasks import MachineTranslationTask


def metadata():
    """Return descriptive metadata (author, model, description) for this asset."""
    info = {
        "author": "Arabic Language Technologies, QCRI, HBKU",
        "model": "jais-13b-chat",
        "description": "Locally hosted Jais Chat 13b model using FastChat.",
    }
    return info


def config():
    """Assemble the asset configuration: dataset, task, model and their args."""
    # Arabic -> English translation over the AraBench dataset.
    dataset_args = {
        "src_lang": "ar",
        "tgt_lang": "en",
    }
    # Retry the model call up to 3 times on transient failures.
    model_args = {
        "max_tries": 3,
    }
    return {
        "dataset": AraBenchDataset,
        "dataset_args": dataset_args,
        "task": MachineTranslationTask,
        "model": FastChatModel,
        "model_args": model_args,
    }


def prompt(input_sample):
    """Build the two-message chat prompt translating `input_sample` (Arabic) to English."""
    # System message pins the translator persona; user message carries the text.
    system_message = {
        "role": "system",
        "content": "You are an expert translator specialized in translating texts from Arabic to English. You are concise as you only output the translation of the text without any illustrations or extra details",
    }
    user_message = {
        "role": "user",
        "content": f"Translate the following text to English.\nText: {input_sample}\nTranslation: ",
    }
    return [system_message, user_message]


def post_process(response):
    """Extract the translation from a chat-completion response dict.

    Strips double quotes and surrounding whitespace from the first choice's
    message content.
    """
    message = response["choices"][0]["message"]
    return message["content"].replace('"', "").strip()
4 changes: 4 additions & 0 deletions envs/fastchat.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Sample env file for using a model hosted using FastChat
FASTCHAT_MODEL="..."
FASTCHAT_API_BASE="..."
FASTCHAT_API_KEY="..."
50 changes: 50 additions & 0 deletions llmebench/models/FastChat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os

from llmebench.models.OpenAI import OpenAIModel


class FastChatModel(OpenAIModel):
    """
    FastChat Model interface. Can be used for models hosted using FastChat
    https://github.com/lm-sys/FastChat

    Accepts all arguments used by `OpenAIModel`, and overrides the arguments listed
    below with FastChat-specific variables.

    See the [model_support](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md)
    page in FastChat's documentation for supported models and instructions on extending
    to custom models.

    Arguments
    ---------
    api_base : str
        URL where the model is hosted. If not provided, the implementation will look at
        environment variable `FASTCHAT_API_BASE`
    api_key : str
        Authentication token for the API. If not provided, the implementation will derive it
        from environment variable `FASTCHAT_API_KEY`
    model_name : str
        Name of the model to use. If not provided, the implementation will derive it from
        environment variable `FASTCHAT_MODEL`

    Raises
    ------
    Exception
        If any of the three settings above is missing both as an argument and
        as an environment variable.
    """

    def __init__(self, api_base=None, api_key=None, model_name=None, **kwargs):
        # Explicit arguments take priority over environment variables.
        api_base = api_base or os.getenv("FASTCHAT_API_BASE")
        api_key = api_key or os.getenv("FASTCHAT_API_KEY")
        model_name = model_name or os.getenv("FASTCHAT_MODEL")

        # Fail early, naming the exact setting that is missing.
        if api_base is None:
            raise Exception(
                "API url must be provided as model config or environment variable (`FASTCHAT_API_BASE`)"
            )
        if api_key is None:
            raise Exception(
                "API key must be provided as model config or environment variable (`FASTCHAT_API_KEY`)"
            )
        if model_name is None:
            raise Exception(
                "Model name must be provided as model config or environment variable (`FASTCHAT_MODEL`)"
            )

        # Delegate remaining validation and setup to the OpenAI-compatible base class.
        super(FastChatModel, self).__init__(
            api_base=api_base, api_key=api_key, model_name=model_name, **kwargs
        )
1 change: 1 addition & 0 deletions llmebench/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .FastChat import FastChatModel
from .HuggingFaceInferenceAPI import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from .OpenAI import LegacyOpenAIModel, OpenAIModel
from .Petals import PetalsModel
Expand Down
76 changes: 76 additions & 0 deletions tests/models/test_FastChatModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import unittest
from unittest.mock import patch

import openai

from llmebench import Benchmark
from llmebench.models import FastChatModel

from tests.models.test_OpenAIModel import TestAssetsForOpenAIPrompts


class TestAssetsForFastChatPrompts(TestAssetsForOpenAIPrompts):
    @classmethod
    def setUpClass(cls):
        """Collect every benchmark asset that is configured to use FastChatModel."""
        benchmark = Benchmark(benchmark_dir="assets")
        cls.assets = [
            asset
            for asset in benchmark.find_assets()
            if asset["config"]["model"] in [FastChatModel]
        ]

    def test_fastchat_prompts(self):
        "Test if all assets using this model return data in an appropriate format for prompting"
        # FastChat exposes an OpenAI-compatible API, so the OpenAI prompt
        # format checks apply unchanged to the filtered assets.
        self.test_openai_prompts()


class TestFastChatConfig(unittest.TestCase):
    # Fake settings shared by the env-var based tests below.
    _FAKE_ENV = {
        "FASTCHAT_API_BASE": "llmebench.qcri.org",
        "FASTCHAT_API_KEY": "secret-key",
        "FASTCHAT_MODEL": "private-model",
    }

    def _check_openai_globals(self):
        """Assert the module-level openai settings FastChatModel must configure."""
        self.assertEqual(openai.api_type, "openai")
        self.assertEqual(openai.api_base, "llmebench.qcri.org")
        self.assertEqual(openai.api_key, "secret-key")

    def test_fastchat_config(self):
        "Test if model config parameters passed as arguments are used"
        model = FastChatModel(
            api_base="llmebench.qcri.org",
            api_key="secret-key",
            model_name="private-model",
        )

        self._check_openai_globals()
        self.assertEqual(model.model_params["model"], "private-model")

    @patch.dict("os.environ", _FAKE_ENV)
    def test_fastchat_config_env_var(self):
        "Test if model config parameters passed as environment variables are used"
        model = FastChatModel()

        self._check_openai_globals()
        self.assertEqual(model.model_params["model"], "private-model")

    @patch.dict("os.environ", _FAKE_ENV)
    def test_fastchat_config_priority(self):
        "Test if model config parameters override environment variables"
        model = FastChatModel(model_name="another-model")

        self._check_openai_globals()
        self.assertEqual(model.model_params["model"], "another-model")
3 changes: 3 additions & 0 deletions tests/models/test_OpenAIModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def test_openai_config_azure(self):
"AZURE_API_URL": "url",
"AZURE_API_KEY": "secret-key",
"AZURE_ENGINE_NAME": "private-model",
"OPENAI_API_BASE": "",
"OPENAI_API_KEY": "",
"OPENAI_MODEL": "",
},
)
def test_openai_config_env_var_azure(self):
Expand Down
Loading