feat: add o1-series models support in Agent App (ReACT only) #8350

Merged · 1 commit · Sep 13, 2024
29 changes: 27 additions & 2 deletions api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -620,6 +620,9 @@ def _chat_generate(
             if "stream_options" in extra_model_kwargs:
                 del extra_model_kwargs["stream_options"]
 
+            if "stop" in extra_model_kwargs:
+                del extra_model_kwargs["stop"]
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
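This hunk strips "stop" (alongside the already-removed "stream_options") from the extra kwargs before the request is sent, since the o1-series chat completions endpoint does not accept it; the stop words are enforced client-side further down instead. A minimal standalone sketch of that sanitization step, with the helper name and surrounding call being illustrative assumptions rather than the PR's code:

from openai import OpenAI

def sanitize_o1_kwargs(extra_model_kwargs: dict) -> dict:
    # o1-series chat completions reject `stream_options` and `stop`,
    # so both are dropped; `stop` is applied client-side after the response.
    kwargs = dict(extra_model_kwargs)
    for unsupported in ("stream_options", "stop"):
        kwargs.pop(unsupported, None)
    return kwargs

# Assumes OPENAI_API_KEY is set in the environment.
client = OpenAI()
response = client.chat.completions.create(
    model="o1-preview",
    messages=[{"role": "user", "content": "Hello"}],
    **sanitize_o1_kwargs({"stop": ["Observation:"], "stream_options": {"include_usage": True}}),
)
print(response.choices[0].message.content)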
@@ -635,14 +638,15 @@ def _chat_generate(
         block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
 
         if block_as_stream:
-            return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
 
         return block_result
 
     def _handle_chat_block_as_stream_response(
         self,
         block_result: LLMResult,
         prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Handle llm chat response
@@ -652,15 +656,22 @@ def _handle_chat_block_as_stream_response(
         :param response: response
         :param prompt_messages: prompt messages
         :param tools: tools for tool calling
+        :param stop: stop words
         :return: llm response chunk generator
         """
         text = block_result.message.content
         text = cast(str, text)
 
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
         yield LLMResultChunk(
             model=block_result.model,
             prompt_messages=prompt_messages,
             system_fingerprint=block_result.system_fingerprint,
             delta=LLMResultChunkDelta(
                 index=0,
-                message=block_result.message,
+                message=AssistantPromptMessage(content=text),
                 finish_reason="stop",
                 usage=block_result.usage,
             ),
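Because o1 models cannot stream, the blocking result is re-emitted here as a single pseudo-stream chunk, and the stop words removed from the request above are applied by truncating the generated text. A rough standalone equivalent of the `enforce_stop_tokens` call (a sketch under that assumption, not the repository's actual helper):

def enforce_stop_tokens(text: str, stop: list[str]) -> str:
    """Truncate `text` at the earliest occurrence of any stop word."""
    cut = len(text)
    for token in stop:
        idx = text.find(token)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]

# ReACT-style output is cut before the model's self-generated "Observation:".
assert enforce_stop_tokens("Thought: x\nObservation: y", ["Observation:"]) == "Thought: x\n"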
@@ -912,6 +923,20 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp
                             ]
                         )
 
+        if model.startswith("o1"):
+            system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
+            if system_message_count > 0:
+                new_prompt_messages = []
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, SystemPromptMessage):
+                        prompt_message = UserPromptMessage(
+                            content=prompt_message.content,
+                            name=prompt_message.name,
+                        )
+
+                    new_prompt_messages.append(prompt_message)
+                prompt_messages = new_prompt_messages
+
         return prompt_messages
 
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
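The last hunk rewrites system prompts as user prompts whenever the model name starts with "o1", presumably because those models did not accept the system role when this PR was written. The same idea on plain OpenAI-style message dicts (the function name and structure are illustrative, not this repository's API):

def downgrade_system_messages(model: str, messages: list[dict]) -> list[dict]:
    # For o1-series models, replay system prompts as user messages.
    if not model.startswith("o1"):
        return messages
    return [
        {**m, "role": "user"} if m.get("role") == "system" else m
        for m in messages
    ]

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
]
print(downgrade_system_messages("o1-preview", messages))
# [{'role': 'user', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hi'}]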