From b538cf1eae68a87810e804510a85ed29832a3a92 Mon Sep 17 00:00:00 2001
From: mspronesti
Date: Wed, 25 Oct 2023 00:23:42 +0200
Subject: [PATCH] feat: add support for finetuned OpenAI models

---
 pandasai/helpers/openai_info.py | 7 +++++++
 pandasai/llm/openai.py          | 10 ++++++++--
 tests/llms/test_openai.py       | 7 +++++++
 3 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/pandasai/helpers/openai_info.py b/pandasai/helpers/openai_info.py
index c5438e5a1..5030f8d58 100644
--- a/pandasai/helpers/openai_info.py
+++ b/pandasai/helpers/openai_info.py
@@ -41,6 +41,10 @@
     "gpt-35-turbo-16k-0613-completion": 0.004,
     # Others
     "text-davinci-003": 0.02,
+    # Fine-tuned input
+    "gpt-3.5-turbo-0613-finetuned": 0.012,
+    # Fine-tuned output
+    "gpt-3.5-turbo-0613-finetuned-completion": 0.016,
 }
 
 
@@ -62,10 +66,13 @@ def get_openai_token_cost_for_model(
         float: Cost in USD.
     """
     model_name = model_name.lower()
+    if "ft:" in model_name:
+        model_name = model_name.split(":")[1] + "-finetuned"
     if is_completion and (
         model_name.startswith("gpt-4")
         or model_name.startswith("gpt-3.5")
         or model_name.startswith("gpt-35")
+        or "finetuned" in model_name
     ):
         # The cost of completion token is different from
         # the cost of prompt tokens.
diff --git a/pandasai/llm/openai.py b/pandasai/llm/openai.py
index 5fc536335..b7f84e2ab 100644
--- a/pandasai/llm/openai.py
+++ b/pandasai/llm/openai.py
@@ -102,9 +102,15 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
         """
         self.last_prompt = instruction.to_string() + suffix
 
-        if self.model in self._supported_chat_models:
+        if "ft:" in self.model:
+            # extract "standard" model name from fine-tuned model
+            model_name = self.model.split(":")[1]
+        else:
+            model_name = self.model
+
+        if model_name in self._supported_chat_models:
             response = self.chat_completion(self.last_prompt)
-        elif self.model in self._supported_completion_models:
+        elif model_name in self._supported_completion_models:
             response = self.completion(self.last_prompt)
         else:
             raise UnsupportedModelError(self.model)
diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py
index 4a5b4c681..5fc5f7e3e 100644
--- a/tests/llms/test_openai.py
+++ b/tests/llms/test_openai.py
@@ -130,3 +130,10 @@ def test_call_supported_chat_model(self, mocker, prompt):
         result = openai.call(instruction=prompt)
 
         assert result == "response"
+
+    def test_call_finetuned_model(self, mocker, prompt):
+        openai = OpenAI(api_token="test", model="ft:gpt-3.5-turbo:my-org:custom_suffix:id")
+        mocker.patch.object(openai, "chat_completion", return_value="response")
+
+        result = openai.call(instruction=prompt)
+        assert result == "response"