Skip to content

Commit

Permalink
feat: add support for finetuned OpenAI models (#682)
Browse files Browse the repository at this point in the history
  • Loading branch information
mspronesti authored Oct 24, 2023
1 parent 0d56e93 commit fe45cca
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
7 changes: 7 additions & 0 deletions pandasai/helpers/openai_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
"gpt-35-turbo-16k-0613-completion": 0.004,
# Others
"text-davinci-003": 0.02,
# Fine-tuned input
"gpt-3.5-turbo-0613-finetuned": 0.012,
# Fine-tuned output
"gpt-3.5-turbo-0613-finetuned-completion": 0.016,
}


Expand All @@ -62,10 +66,13 @@ def get_openai_token_cost_for_model(
float: Cost in USD.
"""
model_name = model_name.lower()
if "ft:" in model_name:
model_name = model_name.split(":")[1] + "-finetuned"
if is_completion and (
model_name.startswith("gpt-4")
or model_name.startswith("gpt-3.5")
or model_name.startswith("gpt-35")
or "finetuned" in model_name
):
# The cost of completion token is different from
# the cost of prompt tokens.
Expand Down
10 changes: 8 additions & 2 deletions pandasai/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,15 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
self.last_prompt = instruction.to_string() + suffix

if self.model in self._supported_chat_models:
if "ft:" in self.model:
# extract "standard" model name from fine-tuned model
model_name = self.model.split(":")[1]
else:
model_name = self.model

if model_name in self._supported_chat_models:
response = self.chat_completion(self.last_prompt)
elif self.model in self._supported_completion_models:
elif model_name in self._supported_completion_models:
response = self.completion(self.last_prompt)
else:
raise UnsupportedModelError(self.model)
Expand Down
7 changes: 7 additions & 0 deletions tests/llms/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,10 @@ def test_call_supported_chat_model(self, mocker, prompt):

result = openai.call(instruction=prompt)
assert result == "response"

def test_call_finetuned_model(self, mocker, prompt):
    """A fine-tuned model id (``ft:<base>:...``) must be recognized as a
    supported chat model and routed through ``chat_completion``."""
    llm = OpenAI(api_token="test", model="ft:gpt-3.5-turbo:my-org:custom_suffix:id")
    # Stub out the network-facing chat_completion call; call() should pick
    # this path because the base name extracted from the ft: id is a chat model.
    mocker.patch.object(llm, "chat_completion", return_value="response")

    assert llm.call(instruction=prompt) == "response"

0 comments on commit fe45cca

Please sign in to comment.