fix(llms): restore completion API for gpt-3.5-turbo-instruct (#610)
mspronesti authored Oct 4, 2023
1 parent 7856100 commit 39f9490
Showing 2 changed files with 13 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pandasai/llm/openai.py
@@ -29,7 +29,8 @@ class OpenAI(BaseOpenAI):
The list of supported Chat models includes ["gpt-4", "gpt-4-0613", "gpt-4-32k",
"gpt-4-32k-0613", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-instruct"].
The list of supported Completion models includes "gpt-3.5-turbo-instruct" and
"text-davinci-003" (soon to be deprecated).
"""

_supported_chat_models = [
@@ -41,8 +42,8 @@ class OpenAI(BaseOpenAI):
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct",
]
_supported_completion_models = ["text-davinci-003", "gpt-3.5-turbo-instruct"]

model: str = "gpt-3.5-turbo"

@@ -101,7 +102,9 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
self.last_prompt = instruction.to_string() + suffix

if self.model in self._supported_chat_models:
if self.model in self._supported_completion_models:
response = self.completion(self.last_prompt)
elif self.model in self._supported_chat_models:
response = self.chat_completion(self.last_prompt)
else:
raise UnsupportedOpenAIModelError("Unsupported model")
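For context, a hedged sketch of the routing this commit restores: models listed in _supported_completion_models (such as "gpt-3.5-turbo-instruct") are checked before the chat-model list, so they reach the completions endpoint rather than chat completions. The list contents mirror the diff above, but the standalone route function below is invented for illustration and is not part of pandasai's API.

# Illustrative sketch only; `route` is a hypothetical helper, not pandasai code.
_supported_completion_models = ["text-davinci-003", "gpt-3.5-turbo-instruct"]
_supported_chat_models = ["gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]

def route(model: str, prompt: str) -> str:
    # Completion models are checked first, so "gpt-3.5-turbo-instruct"
    # is dispatched to completion() instead of chat_completion().
    if model in _supported_completion_models:
        return f"completion: {prompt}"
    if model in _supported_chat_models:
        return f"chat_completion: {prompt}"
    raise ValueError(f"Unsupported model: {model}")

print(route("gpt-3.5-turbo-instruct", "Describe the dataframe"))
# -> completion: Describe the dataframe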
7 changes: 7 additions & 0 deletions tests/llms/test_openai.py
@@ -111,6 +111,13 @@ def test_call_with_unsupported_model(self, prompt):
llm = OpenAI(api_token="test", model="not a model")
llm.call(instruction=prompt)

def test_call_supported_completion_model(self, mocker, prompt):
openai = OpenAI(api_token="test", model="gpt-3.5-turbo-instruct")
mocker.patch.object(openai, "completion", return_value="response")

result = openai.call(instruction=prompt)
assert result == "response"

def test_call_supported_chat_model(self, mocker, prompt):
openai = OpenAI(api_token="test", model="gpt-4")
mocker.patch.object(openai, "chat_completion", return_value="response")