def parse_key_value(ctx, param, value):
    """Click option callback: turn ``key=path`` pairs into a dict of parsed JSON.

    Each pair is split on the first ``=`` only, so the path part may itself
    contain ``=``. The file at that path is read as UTF-8 JSON and stored
    under the key.

    Args:
        ctx: Click context handed to option callbacks (unused).
        param: Click parameter the callback is attached to (unused).
        value: Iterable of ``key=path`` strings; may be empty or ``None``.

    Returns:
        A dict mapping every key to the object decoded from its JSON file.
        An empty dict when no pairs were supplied.

    Raises:
        click.BadParameter: If a pair has no ``=`` or an empty key, if the
            file cannot be opened or read, or if it does not hold valid JSON.
    """
    if not value:
        return {}

    parsed = {}
    for pair in value:
        key, separator, json_path = pair.partition("=")
        if not separator or not key:
            raise click.BadParameter(
                f"Key-value pair '{pair}' is not in the format key=value"
            )

        try:
            # The context manager closes the handle even when decoding fails.
            with open(json_path, encoding="utf-8") as json_file:
                parsed[key] = json.load(json_file)
        except OSError as err:
            # Missing or unreadable file: report a usage error, not a traceback.
            raise click.BadParameter(
                f"Cannot read JSON file '{json_path}' for key '{key}': {err}"
            ) from err
        except ValueError as err:
            # json.JSONDecodeError and UnicodeDecodeError both subclass
            # ValueError; keep them apart from the key=value format error.
            raise click.BadParameter(
                f"File '{json_path}' for key '{key}' is not valid JSON: {err}"
            ) from err

    return parsed
def _get_existing_task_group_or_create_new(self):
    """Return the TaskGroup for this task, reusing one already on the DAG.

    Looks through the children of the DAG's ``task_group`` for a
    ``TaskGroup`` whose ``group_id`` equals the task's ``task_group`` value.
    When none is found (or the DAG's ``task_group`` is falsy) a new
    ``TaskGroup`` with that id is created for the DAG and returned.
    """
    wanted_id = self._task.task_group

    if self._dag.task_group:
        candidates = (
            child
            for child in self._dag.task_group.children.values()
            if isinstance(child, TaskGroup) and child.group_id == wanted_id
        )
        # First match wins, mirroring an early return from a plain loop.
        existing = next(candidates, None)
        if existing is not None:
            return existing

    return TaskGroup(group_id=wanted_id, dag=self._dag)
a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -14,7 +14,7 @@ class Module: - def __init__(self, path_to_config, target_dir): + def __init__(self, path_to_config, target_dir, jinja_parameters=None): self._directory = path.dirname(path_to_config) self._target_dir = target_dir or "./" self._path_to_config = path_to_config @@ -29,6 +29,7 @@ def __init__(self, path_to_config, target_dir): self._branches_to_generate = config["branches_to_generate"] self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) + self._jinja_parameters = jinja_parameters or {} @staticmethod def read_yaml(yaml_str): @@ -76,6 +77,7 @@ def generate_task_configs(self): template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) template_parameters['branch_name'] = branch_name + template_parameters.update(self._jinja_parameters) dbt_manifest = None if "dbt" in self._tasks.keys():