diff --git a/examples/sample_project/profiles.yml b/examples/sample_project/profiles.yml index 50956a0..1125292 100644 --- a/examples/sample_project/profiles.yml +++ b/examples/sample_project/profiles.yml @@ -13,3 +13,6 @@ sample_project: override_in_test: type: duckdb path: /does/not/exist/so/override.duckdb + vars_test: + type: "{{ var('adapter_type') }}" + path: "{{ var('duckdb_db_path') }}" diff --git a/prefect_dbt_flow/dbt/__init__.py b/prefect_dbt_flow/dbt/__init__.py index bbf1732..80ebe2d 100644 --- a/prefect_dbt_flow/dbt/__init__.py +++ b/prefect_dbt_flow/dbt/__init__.py @@ -74,8 +74,12 @@ class DbtDagOptions: select: dbt module to include in the run exclude: dbt module to exclude in the run run_test_after_model: run test afeter run model + vars: dbt vars + install_deps: install dbt dependencies, default behavior install deps """ select: Optional[str] = None exclude: Optional[str] = None run_test_after_model: bool = False + vars: Optional[dict[str, str]] = None + install_deps: bool = True diff --git a/prefect_dbt_flow/dbt/cli.py b/prefect_dbt_flow/dbt/cli.py index b822f89..845235e 100644 --- a/prefect_dbt_flow/dbt/cli.py +++ b/prefect_dbt_flow/dbt/cli.py @@ -1,4 +1,5 @@ """Utility functions for interacting with dbt using command-line commands.""" +import json import shutil from typing import Optional @@ -11,6 +12,7 @@ def dbt_ls( project: DbtProject, dag_options: Optional[DbtDagOptions], + profile: Optional[DbtProfile], output: str = "json", ) -> str: """ @@ -19,6 +21,7 @@ def dbt_ls( Args: project: A class that represents a dbt project configuration. dag_options: A class to add dbt DAG configurations. + profile: A class that represents a dbt profile configuration. output: Format of output, default is JSON. Returns: @@ -29,11 +32,16 @@ def dbt_ls( dbt_ls_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) dbt_ls_cmd.extend(["--output", output]) + if profile: + dbt_ls_cmd.extend(["-t", profile.target]) + if dag_options: if dag_options.select: dbt_ls_cmd.extend(["--select", dag_options.select]) if dag_options.exclude: dbt_ls_cmd.extend(["--exclude", dag_options.exclude]) + if dag_options.vars: + dbt_ls_cmd.extend(["--vars", f"'{json.dumps(dag_options.vars)}'"]) return cmd.run(" ".join(dbt_ls_cmd)) @@ -42,6 +50,7 @@ def dbt_run( project: DbtProject, model: str, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], ) -> str: """ Function that executes `dbt run` command @@ -50,6 +59,7 @@ def dbt_run( project: A class that represents a dbt project configuration. model: Name of the model to run. profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. Returns: A string representing the output of the `dbt run` command. @@ -62,6 +72,10 @@ def dbt_run( if profile: dbt_run_cmd.extend(["-t", profile.target]) + if dag_options: + if dag_options.vars: + dbt_run_cmd.extend(["--vars", f"'{json.dumps(dag_options.vars)}'"]) + return cmd.run(" ".join(dbt_run_cmd)) @@ -69,6 +83,7 @@ def dbt_test( project: DbtProject, model: str, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], ) -> str: """ Function that executes `dbt test` command @@ -77,6 +92,7 @@ def dbt_test( project: A class that represents a dbt project configuration. model: Name of the model to run. profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. Returns: A string representing the output of the `dbt test` command. @@ -89,6 +105,10 @@ def dbt_test( if profile: dbt_test_cmd.extend(["-t", profile.target]) + if dag_options: + if dag_options.vars: + dbt_test_cmd.extend(["--vars", f"'{json.dumps(dag_options.vars)}'"]) + return cmd.run(" ".join(dbt_test_cmd)) @@ -96,6 +116,7 @@ def dbt_seed( project: DbtProject, seed: str, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], ) -> str: """ Function that executes `dbt seed` command @@ -104,7 +125,7 @@ def dbt_seed( project: A class that represents a dbt project configuration. seed: Name of the seed to run. profile: A class that represents a dbt profile configuration. - + dag_options: A class to add dbt DAG configurations. Returns: A string representing the output of the `dbt seed` command @@ -117,6 +138,10 @@ def dbt_seed( if profile: dbt_seed_cmd.extend(["-t", profile.target]) + if dag_options: + if dag_options.vars: + dbt_seed_cmd.extend(["--vars", f"'{json.dumps(dag_options.vars)}'"]) + return cmd.run(" ".join(dbt_seed_cmd)) @@ -124,6 +149,7 @@ def dbt_snapshot( project: DbtProject, snapshot: str, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], ) -> str: """ Function that executes `dbt snapshot` command @@ -132,17 +158,51 @@ def dbt_snapshot( project: A class that represents a dbt project configuration. snapshot: Name of the snapshot to run. profile: A class that represents a dbt profile configuration. - + dag_options: A class to add dbt DAG configurations. Returns: A string representing the output of the `dbt snapshot` command """ - dbt_seed_cmd = [DBT_EXE, "snapshot"] - dbt_seed_cmd.extend(["--project-dir", str(project.project_dir)]) - dbt_seed_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) - dbt_seed_cmd.extend(["--select", snapshot]) + dbt_snapshot_cmd = [DBT_EXE, "snapshot"] + dbt_snapshot_cmd.extend(["--project-dir", str(project.project_dir)]) + dbt_snapshot_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) + dbt_snapshot_cmd.extend(["--select", snapshot]) if profile: - dbt_seed_cmd.extend(["-t", profile.target]) + dbt_snapshot_cmd.extend(["-t", profile.target]) - return cmd.run(" ".join(dbt_seed_cmd)) + if dag_options: + if dag_options.vars: + dbt_snapshot_cmd.extend(["--vars", f"'{json.dumps(dag_options.vars)}'"]) + + return cmd.run(" ".join(dbt_snapshot_cmd)) + + +def dbt_deps( + project: DbtProject, + profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], +) -> str: + """ + Function that executes `dbt deps` command + + Args: + project: A class that represents a dbt project configuration. + profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. + + Returns: + A string representing the output of the `dbt deps` command + """ + dbt_deps_cmd = [DBT_EXE, "deps"] + dbt_deps_cmd.extend(["--project-dir", str(project.project_dir)]) + dbt_deps_cmd.extend(["--profiles-dir", str(project.profiles_dir)]) + + if profile: + dbt_deps_cmd.extend(["-t", profile.target]) + + if dag_options: + if dag_options.vars: + dbt_deps_cmd.extend(["--vars", f"'{json.dumps(dag_options.vars)}'"]) + + return cmd.run(" ".join(dbt_deps_cmd)) diff --git a/prefect_dbt_flow/dbt/graph.py b/prefect_dbt_flow/dbt/graph.py index 3cd6ca6..13f1844 100644 --- a/prefect_dbt_flow/dbt/graph.py +++ b/prefect_dbt_flow/dbt/graph.py @@ -5,20 +5,25 @@ from prefect_dbt_flow.dbt import ( DbtDagOptions, DbtNode, + DbtProfile, DbtProject, DbtResourceType, cli, ) +from prefect_dbt_flow.dbt.profile import override_profile def parse_dbt_project( - project: DbtProject, dag_options: Optional[DbtDagOptions] = None + project: DbtProject, + profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions] = None, ) -> List[DbtNode]: """ Parses a list of dbt nodes class objects from dbt ls cli command. Args: project: A class that represents a dbt project configuration. + profile: A class that represents a dbt profile configuration. dag_options: A class to add dbt DAG configurations. Returns: @@ -27,7 +32,11 @@ def parse_dbt_project( dbt_graph: List[DbtNode] = [] models_with_tests: List[str] = [] - dbt_ls_output = cli.dbt_ls(project, dag_options) + with override_profile(project, profile) as _project: + if not dag_options or dag_options.install_deps: + cli.dbt_deps(_project, profile, dag_options) + + dbt_ls_output = cli.dbt_ls(_project, dag_options, profile) for line in dbt_ls_output.split("\n"): try: diff --git a/prefect_dbt_flow/dbt/tasks.py b/prefect_dbt_flow/dbt/tasks.py index 8918901..a4aebc7 100644 --- a/prefect_dbt_flow/dbt/tasks.py +++ b/prefect_dbt_flow/dbt/tasks.py @@ -4,7 +4,14 @@ from prefect import Task, get_run_logger, task from prefect.futures import PrefectFuture -from prefect_dbt_flow.dbt import DbtNode, DbtProfile, DbtProject, DbtResourceType, cli +from prefect_dbt_flow.dbt import ( + DbtDagOptions, + DbtNode, + DbtProfile, + DbtProject, + DbtResourceType, + cli, +) from prefect_dbt_flow.dbt.profile import override_profile DBT_RUN_EMOJI = "🏃" @@ -16,6 +23,7 @@ def _task_dbt_snapshot( project: DbtProject, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -25,6 +33,7 @@ def _task_dbt_snapshot( Args: project: A class that represents a dbt project configuration. profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. dbt_node: A class that represents the dbt node (model) to run. task_kwargs: Additional task configuration. @@ -45,7 +54,13 @@ def dbt_snapshot(): None """ with override_profile(project, profile) as _project: - dbt_snapshot_output = cli.dbt_snapshot(_project, dbt_node.name, profile) + if not dag_options or dag_options.install_deps: + dbt_deps_output = cli.dbt_deps(_project, profile, dag_options) + get_run_logger().info(dbt_deps_output) + + dbt_snapshot_output = cli.dbt_snapshot( + _project, dbt_node.name, profile, dag_options + ) get_run_logger().info(dbt_snapshot_output) return dbt_snapshot @@ -54,6 +69,7 @@ def dbt_snapshot(): def _task_dbt_seed( project: DbtProject, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -63,6 +79,7 @@ def _task_dbt_seed( Args: project: A class that represents a dbt project configuration. profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. dbt_node: A class that represents the dbt node (model) to run. task_kwargs: Additional task configuration. @@ -83,7 +100,13 @@ def dbt_seed(): None """ with override_profile(project, profile) as _project: - dbt_seed_output = cli.dbt_seed(_project, dbt_node.name, profile) + if not dag_options or dag_options.install_deps: + dbt_deps_output = cli.dbt_deps(_project, profile, dag_options) + get_run_logger().info(dbt_deps_output) + + dbt_seed_output = cli.dbt_seed( + _project, dbt_node.name, profile, dag_options + ) get_run_logger().info(dbt_seed_output) return dbt_seed @@ -92,6 +115,7 @@ def dbt_seed(): def _task_dbt_run( project: DbtProject, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -101,6 +125,7 @@ def _task_dbt_run( Args: project: A class that represents a dbt project configuration. profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. dbt_node: A class that represents the dbt node (model) to run. task_kwargs: Additional task configuration. @@ -121,7 +146,11 @@ def dbt_run(): None """ with override_profile(project, profile) as _project: - dbt_run_output = cli.dbt_run(_project, dbt_node.name, profile) + if not dag_options or dag_options.install_deps: + dbt_deps_output = cli.dbt_deps(_project, profile, dag_options) + get_run_logger().info(dbt_deps_output) + + dbt_run_output = cli.dbt_run(_project, dbt_node.name, profile, dag_options) get_run_logger().info(dbt_run_output) return dbt_run @@ -130,6 +159,7 @@ def dbt_run(): def _task_dbt_test( project: DbtProject, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], dbt_node: DbtNode, task_kwargs: Optional[Dict] = None, ) -> Task: @@ -139,6 +169,7 @@ def _task_dbt_test( Args: project: A class that represents a dbt project configuration. profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. dbt_node: A class that represents the dbt node (model) to run. task_kwargs: Additional task configuration. @@ -159,7 +190,13 @@ def dbt_test(): None """ with override_profile(project, profile) as _project: - dbt_test_output = cli.dbt_test(_project, dbt_node.name, profile) + if not dag_options or dag_options.install_deps: + dbt_deps_output = cli.dbt_deps(_project, profile, dag_options) + get_run_logger().info(dbt_deps_output) + + dbt_test_output = cli.dbt_test( + _project, dbt_node.name, profile, dag_options + ) get_run_logger().info(dbt_test_output) return dbt_test @@ -175,6 +212,7 @@ def dbt_test(): def generate_tasks_dag( project: DbtProject, profile: Optional[DbtProfile], + dag_options: Optional[DbtDagOptions], dbt_graph: List[DbtNode], run_test_after_model: bool = False, ) -> None: @@ -184,6 +222,7 @@ def generate_tasks_dag( Args: project: A class that represents a dbt project configuration. profile: A class that represents a dbt profile configuration. + dag_options: A class to add dbt DAG configurations. dbt_graph: A list of dbt nodes (models) to include in the DAG. run_test_after_model: If True, run tests after running each model. @@ -196,6 +235,7 @@ def generate_tasks_dag( dbt_node.unique_id: RESOURCE_TYPE_TO_TASK[dbt_node.resource_type]( project=project, profile=profile, + dag_options=dag_options, dbt_node=dbt_node, ) for dbt_node in dbt_graph @@ -214,6 +254,7 @@ def generate_tasks_dag( test_task = _task_dbt_test( project=project, profile=profile, + dag_options=dag_options, dbt_node=node, ) test_task_future = test_task.submit(wait_for=run_task_future) diff --git a/prefect_dbt_flow/flow.py b/prefect_dbt_flow/flow.py index 63e81aa..81fe8f7 100644 --- a/prefect_dbt_flow/flow.py +++ b/prefect_dbt_flow/flow.py @@ -29,7 +29,7 @@ def dbt_flow( **(flow_kwargs or {}), } - dbt_graph = graph.parse_dbt_project(project, dag_options) + dbt_graph = graph.parse_dbt_project(project, profile, dag_options) @flow(**all_flow_kwargs) def dbt_flow(): @@ -42,6 +42,7 @@ def dbt_flow(): tasks.generate_tasks_dag( project, profile, + dag_options, dbt_graph, dag_options.run_test_after_model if dag_options else False, ) diff --git a/tests/test_flow.py b/tests/test_flow.py index b063277..5995f49 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1,3 +1,5 @@ +import shutil +from contextlib import contextmanager from pathlib import Path import duckdb @@ -26,6 +28,20 @@ def duckdb_db_file(monkeypatch, tmp_path: Path): yield duckdb_db_file +@contextmanager +def dbt_package(project_path: Path, content: str): + package_yaml_path = project_path / "packages.yml" + dbt_packages_path = project_path / "dbt_packages" + + package_yaml_path.write_text(content) + + try: + yield + finally: + shutil.rmtree(str(dbt_packages_path.absolute())) + package_yaml_path.unlink(missing_ok=True) + + def test_flow_sample_project(duckdb_db_file: Path): dbt_project_path = SAMPLE_PROJECT_PATH @@ -249,3 +265,68 @@ def test_flow_sample_project_dont_specify_target(duckdb_db_file: Path): with duckdb.connect(str(duckdb_db_file)) as ddb: assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4 + + +def test_flow_sample_project_vars(duckdb_db_file: Path): + dbt_project_path = SAMPLE_PROJECT_PATH + + my_dbt_flow = dbt_flow( + project=DbtProject( + name="sample_project", + project_dir=dbt_project_path, + profiles_dir=dbt_project_path, + ), + profile=DbtProfile( + target="vars_test", + ), + dag_options=DbtDagOptions( + vars={ + "adapter_type": "duckdb", + "duckdb_db_path": str(duckdb_db_file.absolute()), + }, + ), + flow_kwargs={ + # Ensure only one process has access to the duckdb db + # file at the same time + "task_runner": SequentialTaskRunner(), + }, + ) + + my_dbt_flow() + + with duckdb.connect(str(duckdb_db_file)) as ddb: + assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4 + + +def test_flow_sample_project_install_deps(duckdb_db_file: Path): + dbt_project_path = SAMPLE_PROJECT_PATH + packages_yml_content = ( + """packages:\n""" + """ - package: dbt-labs/dbt_utils\n""" + """ version: "{{ var('dbt_utils_version') }}"\n""" + ) + + my_dbt_flow = dbt_flow( + project=DbtProject( + name="sample_project", + project_dir=dbt_project_path, + profiles_dir=dbt_project_path, + ), + dag_options=DbtDagOptions( + vars={ + "dbt_utils_version": "1.1.1", + }, + install_deps=True, + ), + flow_kwargs={ + # Ensure only one process has access to the duckdb db + # file at the same time + "task_runner": SequentialTaskRunner(), + }, + ) + + with dbt_package(dbt_project_path, content=packages_yml_content): + my_dbt_flow() + + with duckdb.connect(str(duckdb_db_file)) as ddb: + assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4