Merge branch 'refs/heads/feat/workflow-parallel-support' into deploy/dev
takatost committed Aug 24, 2024
2 parents 845d31b + 85d3197 commit 02b24fe
Showing 5 changed files with 120 additions and 15 deletions.
3 changes: 2 additions & 1 deletion api/core/app/apps/workflow_app_runner.py
@@ -255,7 +255,8 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
         elif isinstance(event, NodeRunStreamChunkEvent):
             self._publish_event(
                 QueueTextChunkEvent(
-                    text=event.chunk_content
+                    text=event.chunk_content,
+                    from_variable_selector=event.from_variable_selector
                 )
             )
         elif isinstance(event, NodeRunRetrieverResourceEvent):
2 changes: 2 additions & 0 deletions api/core/app/entities/queue_entities.py
@@ -150,6 +150,8 @@ class QueueTextChunkEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.TEXT_CHUNK
     text: str
+    from_variable_selector: Optional[list[str]] = None
+    """from variable selector"""
 
 
 class QueueAgentMessageEvent(AppQueueEvent):
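Note on the new field: a streamed text chunk can now carry the variable selector (source node id plus variable name) it came from, which lets consumers tell apart chunks produced by different parallel branches. A minimal, self-contained sketch of the idea using plain pydantic (a stand-in, not Dify's actual AppQueueEvent/QueueTextChunkEvent classes; the node ids are made up):

from typing import Optional

from pydantic import BaseModel


class TextChunk(BaseModel):
    """Simplified stand-in for QueueTextChunkEvent (illustration only)."""
    text: str
    # e.g. ['llm_node_a', 'text']; stays None when the origin is unknown
    from_variable_selector: Optional[list[str]] = None


# Chunks emitted by two parallel branches can now be told apart by origin.
chunk_a = TextChunk(text="Hello ", from_variable_selector=["llm_node_a", "text"])
chunk_b = TextChunk(text="Bonjour ", from_variable_selector=["llm_node_b", "text"])
print(chunk_a.from_variable_selector[0])  # -> llm_node_a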
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -203,7 +203,7 @@ def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
         return files
 
     @classmethod
-    def _get_file_var_from_value(self, value: dict | list) -> Optional[dict]:
+    def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
         """
         Get file var from value
         :param value: variable value
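The only change here is renaming the first parameter of a @classmethod from self to cls. The decorator passes the class as the first argument regardless of the name, so behaviour is unchanged; the rename just makes the intent explicit. A quick standalone illustration:

class Demo:
    @classmethod
    def misleading(self, value):  # still receives the class, despite the name
        return self.__name__, value

    @classmethod
    def conventional(cls, value):  # identical behaviour, clearer intent
        return cls.__name__, value


print(Demo.misleading(1))    # ('Demo', 1)
print(Demo.conventional(2))  # ('Demo', 2)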
11 changes: 5 additions & 6 deletions api/core/workflow/nodes/end/end_stream_generate_router.py
@@ -62,14 +62,13 @@ def extract_stream_variable_selector_from_node_data(cls,
             if node_id != 'sys' and node_id in node_id_config_mapping:
                 node = node_id_config_mapping[node_id]
                 node_type = node.get('data', {}).get('type')
-                if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
+                if (
+                    variable_selector.value_selector not in value_selectors
+                    and node_type == NodeType.LLM.value
+                    and variable_selector.value_selector[1] == 'text'
+                ):
                     value_selectors.append(variable_selector.value_selector)
 
-        # remove duplicates
-        value_selector_tuples = [tuple(item) for item in value_selectors]
-        unique_value_selector_tuples = list(set(value_selector_tuples))
-        value_selectors = [list(item) for item in unique_value_selector_tuples]
-
         return value_selectors
 
     @classmethod
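This folds deduplication into the collection loop: instead of appending everything and deduplicating afterwards through a set of tuples (which loses ordering), a selector is skipped if it was already collected, so the list keeps first-seen order — which lines up with the position-based walk the stream processor now does over these selectors. A small illustration with made-up selector values:

selectors = [["llm_a", "text"], ["llm_b", "text"], ["llm_a", "text"]]

# Previous approach: collect everything, then dedupe via a set of tuples
# (the order of the result is arbitrary).
unique_unordered = [list(t) for t in set(tuple(s) for s in selectors)]

# New approach: skip duplicates while collecting, keeping first-seen order.
unique_ordered: list[list[str]] = []
for s in selectors:
    if s not in unique_ordered:
        unique_ordered.append(s)

print(unique_ordered)    # [['llm_a', 'text'], ['llm_b', 'text']]
print(unique_unordered)  # same two selectors, in no guaranteed order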
117 changes: 110 additions & 7 deletions api/core/workflow/nodes/end/end_stream_processor.py
@@ -18,9 +18,13 @@ class EndStreamProcessor(StreamProcessor):
 
     def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
         super().__init__(graph, variable_pool)
-        self.stream_param = graph.end_stream_param
-        self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
+        self.end_stream_param = graph.end_stream_param
+        self.route_position = {}
+        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
+            self.route_position[end_node_id] = 0
         self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
+        self.has_outputed = False
+        self.outputed_node_ids = set()
 
     def process(self,
                 generator: Generator[GraphEngineEvent, None, None]
@@ -32,6 +36,15 @@ def process(self,
 
                 yield event
             elif isinstance(event, NodeRunStreamChunkEvent):
+                if event.in_iteration_id:
+                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
+                        event.chunk_content = '\n' + event.chunk_content
+
+                    self.outputed_node_ids.add(event.node_id)
+                    self.has_outputed = True
+                    yield event
+                    continue
+
                 if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                     stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
                         event.route_node_state.node_id
@@ -42,23 +55,97 @@ def process(self,
                         event.route_node_state.node_id
                     ] = stream_out_end_node_ids
 
-                for _ in stream_out_end_node_ids:
+                if stream_out_end_node_ids:
+                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
+                        event.chunk_content = '\n' + event.chunk_content
+
+                    self.outputed_node_ids.add(event.node_id)
+                    self.has_outputed = True
                     yield event
             elif isinstance(event, NodeRunSucceededEvent):
                 yield event
                 if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    # update self.route_position after all stream event finished
+                    for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                        self.route_position[end_node_id] += 1
+
                     del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
 
                 # remove unreachable nodes
                 self._remove_unreachable_nodes(event)
 
+                # generate stream outputs
+                yield from self._generate_stream_outputs_when_node_finished(event)
             else:
                 yield event
 
     def reset(self) -> None:
-        self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
+        self.route_position = {}
+        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
+            self.route_position[end_node_id] = 0
         self.rest_node_ids = self.graph.node_ids.copy()
         self.current_stream_chunk_generating_node_ids = {}
 
+    def _generate_stream_outputs_when_node_finished(self,
+                                                    event: NodeRunSucceededEvent
+                                                    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Generate stream outputs.
+        :param event: node run succeeded event
+        :return:
+        """
+        for end_node_id, position in self.route_position.items():
+            # all depends on end node id not in rest node ids
+            if (event.route_node_state.node_id != end_node_id
+                    and (end_node_id not in self.rest_node_ids
+                         or not all(dep_id not in self.rest_node_ids
+                                    for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
+                continue
+
+            route_position = self.route_position[end_node_id]
+
+            position = 0
+            value_selectors = []
+            for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
+                if position >= route_position:
+                    value_selectors.append(current_value_selectors)
+
+                position += 1
+
+            for value_selector in value_selectors:
+                if not value_selector:
+                    continue
+
+                value = self.variable_pool.get(
+                    value_selector
+                )
+
+                if value is None:
+                    break
+
+                text = value.markdown
+
+                if text:
+                    current_node_id = value_selector[0]
+                    if self.has_outputed and current_node_id not in self.outputed_node_ids:
+                        text = '\n' + text
+
+                    self.outputed_node_ids.add(current_node_id)
+                    self.has_outputed = True
+                    yield NodeRunStreamChunkEvent(
+                        id=event.id,
+                        node_id=event.node_id,
+                        node_type=event.node_type,
+                        node_data=event.node_data,
+                        chunk_content=text,
+                        from_variable_selector=value_selector,
+                        route_node_state=event.route_node_state,
+                        parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
+                    )
+
+            self.route_position[end_node_id] += 1
+
     def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
         """
         Is stream out support
@@ -73,14 +160,30 @@ def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
             return []
 
         stream_out_end_node_ids = []
-        for end_node_id, variable_selectors in self.end_streamed_variable_selectors.items():
+        for end_node_id, route_position in self.route_position.items():
             if end_node_id not in self.rest_node_ids:
                 continue
 
             # all depends on end node id not in rest node ids
             if all(dep_id not in self.rest_node_ids
-                   for dep_id in self.stream_param.end_dependencies[end_node_id]):
-                if stream_output_value_selector not in variable_selectors:
+                   for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
+                if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
+                    continue
+
+                position = 0
+                value_selector = None
+                for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
+                    if position == route_position:
+                        value_selector = current_value_selectors
+                        break
+
+                    position += 1
+
+                if not value_selector:
+                    continue
+
+                # check chunk node id is before current node id or equal to current node id
+                if value_selector != stream_output_value_selector:
                     continue
 
                 stream_out_end_node_ids.append(end_node_id)
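Taken together, the new code does two things: each End node keeps a route_position cursor into its ordered list of stream variable selectors, so values already streamed as chunks are not replayed when a node finishes; and when chunks from different nodes interleave (parallel branches or iterations), a newline is prefixed the first time output switches to a node that has not streamed yet, tracked via has_outputed and outputed_node_ids. The sketch below isolates just that separator bookkeeping with hypothetical chunk tuples; it is an illustration, not Dify code:

def separate_interleaved_chunks(chunks):
    """chunks: iterable of (node_id, text) pairs; yields the text with a
    newline prefixed the first time output switches to a node that has not
    streamed anything yet, mirroring has_outputed / outputed_node_ids above."""
    has_outputed = False
    outputed_node_ids = set()
    for node_id, text in chunks:
        if has_outputed and node_id not in outputed_node_ids:
            text = '\n' + text
        outputed_node_ids.add(node_id)
        has_outputed = True
        yield text


chunks = [("llm_a", "Hello"), ("llm_a", " world."), ("llm_b", "Bonjour.")]
print("".join(separate_interleaved_chunks(chunks)))
# Hello world.
# Bonjour.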
