From c44a7b5ccb83e7e53810766d1a0b57c5f50d78c4 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Wed, 12 Jun 2024 11:37:16 +0200 Subject: [PATCH] Replacing string replacement with jinja in module processor --- dagger/utilities/dbt_config_parser.py | 12 ++++++++++++ dagger/utilities/module.py | 28 ++++++++++++++++----------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 1b64132..3be57fe 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -35,6 +35,18 @@ def __init__(self, config_parameters: dict): self._nodes_in_manifest = self._manifest_data.get("nodes", {}) self._sources_in_manifest = self._manifest_data.get("sources", {}) + @property + def nodes_in_manifest(self): + return self._nodes_in_manifest + + @property + def sources_in_manifest(self): + return self._sources_in_manifest + + @property + def dbt_default_schema(self): + return self._default_schema + def _get_manifest_path(self) -> str: """ Construct path for manifest.json file based on configuration parameters. diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 968e196..ff1329f 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -6,6 +6,8 @@ DatabricksDBTConfigParser, ) +import jinja2 + import yaml _logger = logging.getLogger("root") @@ -48,19 +50,13 @@ def read_task_config(self, task): @staticmethod def replace_template_parameters(_task_str, _template_parameters): - for _key, _value in _template_parameters.items(): - if type(_value) == str: - try: - int_value = int(_value) - _value = f'"{_value}"' - except: - pass - locals()[_key] = _value + environment = jinja2.Environment() + template = environment.from_string(_task_str) + rendered_task = template.render(_template_parameters) return ( - _task_str.format(**locals()) - .replace("{", "{{") - .replace("}", "}}") + rendered_task + # TODO Remove this hack and use Jinja escaping instead of special expression in template files .replace("__CBS__", "{") .replace("__CBE__", "}") ) @@ -79,12 +75,22 @@ def generate_task_configs(self): template_parameters = {} template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) + template_parameters['branch_name'] = branch_name + + dbt_manifest = None if "dbt" in self._tasks.keys(): if template_parameters.get("profile_name") == "athena": self._dbt_module = AthenaDBTConfigParser(template_parameters) if template_parameters.get("profile_name") == "databricks": self._dbt_module = DatabricksDBTConfigParser(template_parameters) + dbt_manifest = {} + dbt_manifest['nodes'] = self._dbt_module.nodes_in_manifest + dbt_manifest['sources'] = self._dbt_module.sources_in_manifest + + template_parameters["dbt_manifest"] = dbt_manifest + template_parameters["dbt_default_schema"] = self._dbt_module.dbt_default_schema + for task, task_yaml in self._tasks.items(): task_name = f"{branch_name}_{task}" _logger.info(f"Generating task {task_name}")