Skip to content

Commit

Permalink
add finish_reason to the LLM node output (#7498)
Browse files Browse the repository at this point in the history
  • Loading branch information
orangeclk authored Aug 21, 2024
1 parent 784b11c commit f53454f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def get_tool_call(tool_call_id: str):
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments

finish_reason = 'Unknown'
finish_reason = None # The default value of finish_reason is None

for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
chunk = chunk.strip()
Expand All @@ -437,6 +437,8 @@ def get_tool_call(tool_call_id: str):
if chunk.startswith(':'):
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
if decoded_chunk == '[DONE]': # Some providers return "data: [DONE]"
continue

try:
chunk_json = json.loads(decoded_chunk)
Expand Down
15 changes: 10 additions & 5 deletions api/core/workflow/nodes/llm/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
}

# handle invoke result
result_text, usage = self._invoke_llm(
result_text, usage, finish_reason = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
Expand All @@ -129,7 +129,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:

outputs = {
'text': result_text,
'usage': jsonable_encoder(usage)
'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
}

return NodeRunResult(
Expand Down Expand Up @@ -167,14 +168,14 @@ def _invoke_llm(self, node_data_model: ModelConfig,
)

# handle invoke result
text, usage = self._handle_invoke_result(
text, usage, finish_reason = self._handle_invoke_result(
invoke_result=invoke_result
)

# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

return text, usage
return text, usage, finish_reason

def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
Expand All @@ -186,6 +187,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage
prompt_messages = []
full_text = ''
usage = None
finish_reason = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
Expand All @@ -201,10 +203,13 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage
if not usage and result.delta.usage:
usage = result.delta.usage

if not finish_reason and result.delta.finish_reason:
finish_reason = result.delta.finish_reason

if not usage:
usage = LLMUsage.empty_usage()

return full_text, usage
return full_text, usage, finish_reason

def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
)

# handle invoke result
result_text, usage = self._invoke_llm(
result_text, usage, finish_reason = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
Expand Down Expand Up @@ -93,6 +93,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
prompt_messages=prompt_messages
),
'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
}
outputs = {
'class_name': category_name
Expand Down

0 comments on commit f53454f

Please sign in to comment.