Skip to content

Commit

Permalink
profile overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
Nico Gelders committed Oct 22, 2023
1 parent ec1dc3b commit fca2613
Showing 7 changed files with 187 additions and 31 deletions.
3 changes: 3 additions & 0 deletions examples/sample_project/profiles.yml
Original file line number Diff line number Diff line change
@@ -10,3 +10,6 @@ sample_project:
test:
type: duckdb
path: "{{ env_var('DUCKDB_DB_FILE') }}"
override_in_test:
type: duckdb
path: /does/not/exist/so/override.duckdb
37 changes: 14 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions prefect_dbt_flow/dbt/__init__.py
Original file line number Diff line number Diff line change
@@ -38,9 +38,11 @@ class DbtProfile:
Args:
target: dbt target, usualy "dev" or "prod"
overrides: dbt profile overrides
"""

target: str
overrides: Optional[dict[str, str]] = None


@dataclass
75 changes: 75 additions & 0 deletions prefect_dbt_flow/dbt/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Logic to override dbt profiles.yml"""
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator

import yaml # type: ignore

from prefect_dbt_flow.dbt import DbtProfile, DbtProject


@contextmanager
def override_profile(
project: DbtProject, profile: DbtProfile
) -> Generator[DbtProject, None, None]:
"""
Override dbt profiles.yml with the given profile configuration.
Args:
project: A class that represents a dbt project configuration.
profile: A class that represents a dbt profile configuration.
Returns:
dbt_project: DbtProject.
"""
if not profile.overrides:
yield project
return

dbt_project_name = _get_dbt_project_name(Path(project.project_dir))
dbt_profile_path = Path(project.profiles_dir) / "profiles.yml"

existing_profile_content = {}

if dbt_profile_path.exists():
existing_profile_content = (
yaml.safe_load(dbt_profile_path.read_text())
.get(dbt_project_name, {})
.get("outputs", {})
.get(profile.target, {})
)

with TemporaryDirectory() as tmp_profiles_dir:
tmp_profiles_path = Path(tmp_profiles_dir) / "profiles.yml"
with open(tmp_profiles_path, "w") as tmp_profiles_file:
yaml.safe_dump(
{
dbt_project_name: {
"target": profile.target,
"outputs": {
profile.target: {
**existing_profile_content,
**profile.overrides,
}
},
},
},
tmp_profiles_file,
)

yield DbtProject(
name=project.name,
project_dir=project.project_dir,
profiles_dir=tmp_profiles_dir,
)


def _get_dbt_project_name(project_dir: Path) -> str:
dbt_project_path = project_dir / "dbt_project.yml"

if not dbt_project_path.exists():
raise ValueError(f"dbt_project.yml not found in {project_dir}")

with open(dbt_project_path) as f:
return yaml.safe_load(f)["name"]
21 changes: 13 additions & 8 deletions prefect_dbt_flow/dbt/tasks.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from prefect.futures import PrefectFuture

from prefect_dbt_flow.dbt import DbtNode, DbtProfile, DbtProject, DbtResourceType, cli
from prefect_dbt_flow.dbt.profile import override_profile

DBT_RUN_EMOJI = "🏃"
DBT_TEST_EMOJI = "🧪"
@@ -43,8 +44,9 @@ def dbt_snapshot():
Returns:
None
"""
dbt_snapshot_output = cli.dbt_snapshot(project, profile, dbt_node.name)
get_run_logger().info(dbt_snapshot_output)
with override_profile(project, profile) as _project:
dbt_snapshot_output = cli.dbt_snapshot(_project, profile, dbt_node.name)
get_run_logger().info(dbt_snapshot_output)

return dbt_snapshot

@@ -80,8 +82,9 @@ def dbt_seed():
Returns:
None
"""
dbt_seed_output = cli.dbt_seed(project, profile, dbt_node.name)
get_run_logger().info(dbt_seed_output)
with override_profile(project, profile) as _project:
dbt_seed_output = cli.dbt_seed(_project, profile, dbt_node.name)
get_run_logger().info(dbt_seed_output)

return dbt_seed

@@ -117,8 +120,9 @@ def dbt_run():
Returns:
None
"""
dbt_run_output = cli.dbt_run(project, profile, dbt_node.name)
get_run_logger().info(dbt_run_output)
with override_profile(project, profile) as _project:
dbt_run_output = cli.dbt_run(_project, profile, dbt_node.name)
get_run_logger().info(dbt_run_output)

return dbt_run

@@ -154,8 +158,9 @@ def dbt_test():
Returns:
None
"""
dbt_test_output = cli.dbt_test(project, profile, dbt_node.name)
get_run_logger().info(dbt_test_output)
with override_profile(project, profile) as _project:
dbt_test_output = cli.dbt_test(_project, profile, dbt_node.name)
get_run_logger().info(dbt_test_output)

return dbt_test

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ mkdocstrings = { extras = ["python"], version = "^0.23.0" }
coverage = "^7.3.2"
pytest-xdist = "^3.3.1"
pytest-cov = "^4.1.0"
types-pyyaml = "^6.0.12.12"

[build-system]
requires = ["poetry-core"]
79 changes: 79 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
@@ -163,3 +163,82 @@ def test_flow_jaffle_shop(duckdb_db_file: Path):

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 9


def test_flow_sample_project_overrides_new_profile(duckdb_db_file: Path):
dbt_project_path = SAMPLE_PROJECT_PATH

my_dbt_flow = dbt_flow(
project=DbtProject(
name="sample_project",
project_dir=dbt_project_path,
profiles_dir=dbt_project_path,
),
profile=DbtProfile(
target="something_else",
overrides={
"type": "duckdb",
"path": str(duckdb_db_file.absolute()),
},
),
flow_kwargs={
# Ensure only one process has access to the duckdb db
# file at the same time
"task_runner": SequentialTaskRunner(),
},
)

my_dbt_flow()

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4


def test_flow_sample_project_overrides_existing_profile(duckdb_db_file: Path):
dbt_project_path = SAMPLE_PROJECT_PATH

my_dbt_flow = dbt_flow(
project=DbtProject(
name="sample_project",
project_dir=dbt_project_path,
profiles_dir=dbt_project_path,
),
profile=DbtProfile(
target="override_in_test",
overrides={
"path": str(duckdb_db_file.absolute()),
},
),
flow_kwargs={
# Ensure only one process has access to the duckdb db
# file at the same time
"task_runner": SequentialTaskRunner(),
},
)

my_dbt_flow()

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4


def test_flow_sample_project_dont_specify_target(duckdb_db_file: Path):
dbt_project_path = SAMPLE_PROJECT_PATH

my_dbt_flow = dbt_flow(
project=DbtProject(
name="sample_project",
project_dir=dbt_project_path,
profiles_dir=dbt_project_path,
),
flow_kwargs={
# Ensure only one process has access to the duckdb db
# file at the same time
"task_runner": SequentialTaskRunner(),
},
)

my_dbt_flow()

with duckdb.connect(str(duckdb_db_file)) as ddb:
assert len(ddb.sql("SHOW ALL TABLES").fetchall()) == 4

0 comments on commit fca2613

Please sign in to comment.