diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 048f27f30592e1..95533ccfaf771a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -620,6 +620,9 @@ def _chat_generate( if "stream_options" in extra_model_kwargs: del extra_model_kwargs["stream_options"] + if "stop" in extra_model_kwargs: + del extra_model_kwargs["stop"] + # chat model response = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], @@ -635,7 +638,7 @@ def _chat_generate( block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) if block_as_stream: - return self._handle_chat_block_as_stream_response(block_result, prompt_messages) + return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) return block_result @@ -643,6 +646,7 @@ def _handle_chat_block_as_stream_response( self, block_result: LLMResult, prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, ) -> Generator[LLMResultChunk, None, None]: """ Handle llm chat response @@ -652,15 +656,22 @@ def _handle_chat_block_as_stream_response( :param response: response :param prompt_messages: prompt messages :param tools: tools for tool calling + :param stop: stop words :return: llm response chunk generator """ + text = block_result.message.content + text = cast(str, text) + + if stop: + text = self.enforce_stop_tokens(text, stop) + yield LLMResultChunk( model=block_result.model, prompt_messages=prompt_messages, system_fingerprint=block_result.system_fingerprint, delta=LLMResultChunkDelta( index=0, - message=block_result.message, + message=AssistantPromptMessage(content=text), finish_reason="stop", usage=block_result.usage, ), @@ -912,6 +923,20 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp ] ) + if model.startswith("o1"): + system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) + if system_message_count > 0: + new_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + prompt_message = UserPromptMessage( + content=prompt_message.content, + name=prompt_message.name, + ) + + new_prompt_messages.append(prompt_message) + prompt_messages = new_prompt_messages + return prompt_messages def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: