diff --git a/src/conductor/client/workflow/task/dynamic_fork_task.py b/src/conductor/client/workflow/task/dynamic_fork_task.py index daffa03f..439f1660 100644 --- a/src/conductor/client/workflow/task/dynamic_fork_task.py +++ b/src/conductor/client/workflow/task/dynamic_fork_task.py @@ -9,27 +9,21 @@ class DynamicForkTask(TaskInterface): - def __init__(self, task_ref_name: str, pre_fork_task: TaskInterface, join_task: JoinTask = None) -> Self: + def __init__(self, task_ref_name: str, tasks_param: str = 'dynamicTasks', tasks_input_param_name: str = 'dynamicTasksInputs', join_task: JoinTask = None) -> Self: super().__init__( task_reference_name=task_ref_name, task_type=TaskType.FORK_JOIN_DYNAMIC ) - self._pre_fork_task = deepcopy(pre_fork_task) + self.tasks_param = tasks_param + self.tasks_input_param_name = tasks_input_param_name self._join_task = deepcopy(join_task) def to_workflow_task(self) -> WorkflowTask: - workflow = super().to_workflow_task() - workflow.dynamic_fork_join_tasks_param = 'forkedTasks' - workflow.dynamic_fork_tasks_input_param_name = 'forkedTasksInputs' - workflow.input_parameters['forkedTasks'] = self._pre_fork_task.output_ref( - 'forkedTasks' - ) - workflow.input_parameters['forkedTasksInputs'] = self._pre_fork_task.output_ref( - 'forkedTasksInputs' - ) + wf_task = super().to_workflow_task() + wf_task.dynamic_fork_join_tasks_param = self.tasks_param + wf_task.dynamic_fork_tasks_input_param_name = self.tasks_input_param_name tasks = [ - self._pre_fork_task.to_workflow_task(), - workflow, + wf_task, ] if self._join_task != None: tasks.append(self._join_task.to_workflow_task()) diff --git a/src/conductor/client/workflow/task/task.py b/src/conductor/client/workflow/task/task.py index 48a6bf17..dfa8b18d 100644 --- a/src/conductor/client/workflow/task/task.py +++ b/src/conductor/client/workflow/task/task.py @@ -11,7 +11,13 @@ def get_task_interface_list_as_workflow_task_list(*tasks: Self) -> List[WorkflowTask]: converted_tasks = [] for task in tasks: - converted_tasks.append(task.to_workflow_task()) + wf_task = task.to_workflow_task() + if isinstance(wf_task, list): + # to_workflow_task() returned a list. E.g.: DynamicFork.to_workflow_task() returns the DynamicFork and the Join task. + for t in wf_task: + converted_tasks.append(t) + else: + converted_tasks.append(task.to_workflow_task()) return converted_tasks diff --git a/tests/integration/metadata/test_workflow_definition.py b/tests/integration/metadata/test_workflow_definition.py index ba7528bc..59aafa6a 100644 --- a/tests/integration/metadata/test_workflow_definition.py +++ b/tests/integration/metadata/test_workflow_definition.py @@ -167,7 +167,6 @@ def generate_set_variable_task() -> SetVariableTask: def generate_dynamic_fork_task() -> DynamicForkTask: return DynamicForkTask( task_ref_name='dynamic_fork', - pre_fork_task=generate_simple_task(10), join_task=JoinTask( 'join', join_on=[] ),