Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/data 2002 generalise jinja params #42

Merged
merged 3 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions dagger/cli/module.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import click
from dagger.utilities.module import Module
from dagger.utils import Printer
import json


def parse_key_value(ctx, param, value):
#print('YYY', value)
if not value:
return {}
key_value_dict = {}
for pair in value:
try:
key, val_file_path = pair.split('=', 1)
#print('YYY', key, val_file_path, pair)
val = json.load(open(val_file_path))
key_value_dict[key] = val
except ValueError:
raise click.BadParameter(f"Key-value pair '{pair}' is not in the format key=value")
return key_value_dict

@click.command()
@click.option("--config_file", "-c", help="Path to module config file")
@click.option("--target_dir", "-t", help="Path to directory to generate the task configs to")
def generate_tasks(config_file: str, target_dir: str) -> None:
@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Path to jinja parameters json file in the format: <jinja_variable_name>=<path to json file>")
kiranvasudev marked this conversation as resolved.
Show resolved Hide resolved
def generate_tasks(config_file: str, target_dir: str, jinja_parameters: dict) -> None:
"""
Generating tasks for a module based on config
"""

module = Module(config_file, target_dir)
module = Module(config_file, target_dir, jinja_parameters)
module.generate_task_configs()

Printer.print_success("Tasks are successfully generated")
Expand Down
13 changes: 13 additions & 0 deletions dagger/dag_creator/airflow/operator_creator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from datetime import timedelta
from airflow.utils.task_group import TaskGroup

TIMEDELTA_PARAMETERS = ['execution_timeout']

Expand All @@ -11,6 +12,15 @@ def __init__(self, task, dag):
self._template_parameters = {}
self._airflow_parameters = {}

def _get_existing_task_group_or_create_new(self):
group_id = self._task.task_group
if self._dag.task_group:
for group in self._dag.task_group.children.values():
if isinstance(group, TaskGroup) and group.group_id == group_id:
return group

return TaskGroup(group_id=group_id, dag=self._dag)

@abstractmethod
def _create_operator(self, kwargs):
raise NotImplementedError
Expand All @@ -34,6 +44,9 @@ def _update_airflow_parameters(self):
if self._task.timeout_in_seconds:
self._airflow_parameters["execution_timeout"] = self._task.timeout_in_seconds

if self._task.task_group:
self._airflow_parameters["task_group"] = self._get_existing_task_group_or_create_new()

self._fix_timedelta_parameters()

def create_operator(self):
Expand Down
11 changes: 11 additions & 0 deletions dagger/pipeline/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def init_attributes(cls, orig_cls):
comment="Use dagger init-io cli",
),
Attribute(attribute_name="pool", required=False),
Attribute(
attribute_name="task_group",
required=False,
format_help=str,
comment="Task group name",
),
Attribute(
attribute_name="timeout_in_seconds",
required=False,
Expand Down Expand Up @@ -73,6 +79,7 @@ def __init__(self, name: str, pipeline_name, pipeline, config: dict):
self._outputs = []
self._pool = self.parse_attribute("pool") or self.default_pool
self._timeout_in_seconds = self.parse_attribute("timeout_in_seconds")
self._task_group = self.parse_attribute("task_group")
self.process_inputs(config["inputs"])
self.process_outputs(config["outputs"])

Expand Down Expand Up @@ -137,6 +144,10 @@ def pool(self):
def timeout_in_seconds(self):
return self._timeout_in_seconds

@property
def task_group(self):
return self._task_group

def add_input(self, task_input: IO):
_logger.info("Adding input: %s to task: %s", task_input.name, self._name)
self._inputs.append(task_input)
Expand Down
4 changes: 3 additions & 1 deletion dagger/utilities/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class Module:
def __init__(self, path_to_config, target_dir):
def __init__(self, path_to_config, target_dir, jinja_parameters):
kiranvasudev marked this conversation as resolved.
Show resolved Hide resolved
self._directory = path.dirname(path_to_config)
self._target_dir = target_dir or "./"
self._path_to_config = path_to_config
Expand All @@ -29,6 +29,7 @@ def __init__(self, path_to_config, target_dir):
self._branches_to_generate = config["branches_to_generate"]
self._override_parameters = config.get("override_parameters", {})
self._default_parameters = config.get("default_parameters", {})
self._jinja_parameters = jinja_parameters

@staticmethod
def read_yaml(yaml_str):
Expand Down Expand Up @@ -76,6 +77,7 @@ def generate_task_configs(self):
template_parameters.update(self._default_parameters or {})
template_parameters.update(attrs)
template_parameters['branch_name'] = branch_name
template_parameters.update(self._jinja_parameters)

dbt_manifest = None
if "dbt" in self._tasks.keys():
Expand Down
Loading