Skip to content

Commit

Permalink
Merge branch 'refs/heads/feat/workflow-parallel-support' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Sep 2, 2024
2 parents 08ea076 + 5ca9df6 commit a9e0776
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import queue
import threading
import time
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional

from flask import Flask, current_app
Expand Down Expand Up @@ -64,6 +64,8 @@ def __init__(
max_execution_steps: int,
max_execution_time: int
) -> None:
## init thread pool
self.thread_pool = ThreadPoolExecutor(max_workers=10)
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
Expand Down Expand Up @@ -368,7 +370,7 @@ def _run_parallel_branches(
q: queue.Queue = queue.Queue()

# Create a list to store the threads
threads = []
futures = []

# new thread
for edge in edge_mappings:
Expand All @@ -378,17 +380,16 @@ def _run_parallel_branches(
):
continue

thread = threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'parent_parallel_id': in_parallel_id,
'parent_parallel_start_node_id': parallel_start_node_id,
})

threads.append(thread)
thread.start()
futures.append(
self.thread_pool.submit(self._run_parallel_node, **{
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'parent_parallel_id': in_parallel_id,
'parent_parallel_start_node_id': parallel_start_node_id,
})
)

succeeded_count = 0
while True:
Expand All @@ -401,7 +402,7 @@ def _run_parallel_branches(
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(threads):
if succeeded_count == len(futures):
q.put(None)

continue
Expand All @@ -410,9 +411,8 @@ def _run_parallel_branches(
except queue.Empty:
continue

# Join all threads
for thread in threads:
thread.join()
# wait all threads
wait(futures)

# get final node id
final_node_id = parallel.end_to_node_id
Expand Down

0 comments on commit a9e0776

Please sign in to comment.