Skip to content

Commit

Permalink
Filter dbt cloud jobs by environment id
Browse files Browse the repository at this point in the history
  • Loading branch information
usefulalgorithm committed Jan 9, 2024
1 parent f23e5bf commit 761cfe9
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 3 deletions.
10 changes: 10 additions & 0 deletions metaphor/dbt/cloud/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ If you're using dbt [Single Tenancy](https://docs.getdbt.com/docs/cloud/about-cl
base_url: https://cloud.<tenant>.getdbt.com
```

#### Environment IDs

```yaml
environment_ids:
- <environment_id_1>
- <environment_id_2>
```

If `environment_ids` are specified, only jobs run within those environments are collected. If it is not provided, all dbt jobs will be collected.

## Testing

Follow the [Installation](../../README.md) instructions to install `metaphor-connectors` in your environment (or virtualenv). Make sure to include either `all` or `dbt` extra.
Expand Down
18 changes: 17 additions & 1 deletion metaphor/dbt/cloud/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@ class DbtAdminAPIClient:
See https://docs.getdbt.com/dbt-cloud/api-v2 for more details.
"""

def __init__(self, base_url: str, account_id: int, service_token: str):
def __init__(
self,
base_url: str,
account_id: int,
service_token: str,
included_env_ids: Set[int] = set(),
):
self.admin_api_base_url = f"{base_url}/api/v2"
self.account_id = account_id
self.service_token = service_token
self.included_env_ids = included_env_ids

def _get(self, path: str, params: Optional[Dict] = None):
url = f"{self.admin_api_base_url}/accounts/{self.account_id}/{path}"
Expand Down Expand Up @@ -71,6 +78,15 @@ def get_project_jobs(self, project_id: int) -> List[int]:
jobs |= new_jobs
offset += page_size

def job_is_included(self, job_id: int) -> bool:
if len(self.included_env_ids) == 0:
# No excluded environment, just return True
return True

Check warning on line 84 in metaphor/dbt/cloud/client.py

View check run for this annotation

Codecov / codecov/patch

metaphor/dbt/cloud/client.py#L84

Added line #L84 was not covered by tests

resp = self._get(f"jobs/{job_id}")
data = resp.get("data")
return int(data.get("environment_id", -1)) in self.included_env_ids

def get_last_successful_run(self, job_id: int) -> DbtRun:
"""Get the run ID of the last successful run for a job"""

Expand Down
5 changes: 4 additions & 1 deletion metaphor/dbt/cloud/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import field as dataclass_field
from typing import List
from typing import List, Set

from pydantic.dataclasses import dataclass

Expand All @@ -22,6 +22,9 @@ class DbtCloudConfig(BaseConfig):
# dbt cloud project IDs
project_ids: List[int] = dataclass_field(default_factory=list)

# dbt cloud environment IDs to include. If specified, only jobs run in the provided environments will be crawled.
environment_ids: Set[int] = dataclass_field(default_factory=set)

# map meta field to ownerships
meta_ownerships: List[MetaOwnership] = dataclass_field(default_factory=list)

Expand Down
5 changes: 5 additions & 0 deletions metaphor/dbt/cloud/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, config: DbtCloudConfig):
base_url=self._base_url,
account_id=self._account_id,
service_token=self._service_token,
included_env_ids=config.environment_ids,
)

async def extract(self) -> Collection[ENTITY_TYPES]:
Expand All @@ -53,6 +54,10 @@ async def extract(self) -> Collection[ENTITY_TYPES]:
return [item for ls in self._entities.values() for item in ls]

async def _extract_last_run(self, job_id: int):
if not self._client.job_is_included(job_id):
logger.info(f"Ignoring job ID: {job_id}")
return

logger.info(f"Fetching metadata for job ID: {job_id}")

run = self._client.get_last_successful_run(job_id)
Expand Down
50 changes: 50 additions & 0 deletions tests/dbt/cloud/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,53 @@ def test_get_run_artifact(mock_requests):
},
timeout=600,
)


@patch("metaphor.dbt.cloud.client.requests")
def test_job_is_included(mock_requests):
client = DbtAdminAPIClient(
base_url="http://base.url",
account_id=1111,
service_token="service_token",
included_env_ids={1, 3},
)

def mock_get(url: str, **kwargs):
job_id = int(url.rsplit("/", 1)[-1])
if job_id == 1:
return Response(
200,
{
"data": {
"environment_id": 1,
}
},
)
elif job_id == 2:
return Response(
200,
{
"data": {
"environment_id": 2,
}
},
)
elif job_id == 3:
return Response(
200,
{
"data": {
"environment_id": 4,
}
},
)
return Response(404, {})

mock_requests.get = mock_get

for i in range(1, 4):
included = client.job_is_included(i)
if i == 1:
assert included
else:
assert not included
8 changes: 7 additions & 1 deletion tests/dbt/cloud/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ async def test_extractor(
)
)
mock_client.get_project_jobs = MagicMock(side_effect=[[8888], [2222]])

def mock_job_is_included(job_id: int) -> bool:
return job_id != 3333

mock_client.job_is_included = mock_job_is_included
mock_client.get_snowflake_account = MagicMock(return_value="snowflake_account")
mock_client.get_run_artifact = MagicMock(return_value="tempfile")

Expand All @@ -39,8 +44,9 @@ async def fake_extract():
config = DbtCloudConfig(
output=OutputConfig(),
account_id=1111,
job_ids=[2222],
job_ids=[2222, 3333],
project_ids=[6666, 4444],
environment_ids={1},
base_url="https://cloud.metaphor.getdbt.com",
service_token="service_token",
)
Expand Down

0 comments on commit 761cfe9

Please sign in to comment.