Skip to content

Commit

Permalink
[dagster-dlift] job + run methods (#25486)
Browse files Browse the repository at this point in the history
## Summary & Motivation
Methods for manipulating jobs, launching runs, and retrieving run
artifacts.
## How I Tested These Changes
Had to make a few changes to the unit tests to support these methods. Also added a new
live test which creates a job, runs a few methods on the job, launches a
run, waits for its completion, and then destroys the job.
  • Loading branch information
dpeng817 authored Nov 1, 2024
1 parent 0ce6970 commit d4fa3bc
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def build_dagster_oss_nightly_steps() -> List[BuildkiteStep]:
name="dbt-cloud-live-tests",
env_vars=[
"KS_DBT_CLOUD_ACCOUNT_ID",
"KS_DBT_CLOUD_PROJECT_ID",
"KS_DBT_CLOUD_TOKEN",
"KS_DBT_CLOUD_ACCESS_URL",
"KS_DBT_CLOUD_DISCOVERY_API_URL",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def k8s_extra_cmds(version: str, _) -> List[str]:
name=":dbt: dlift live",
env_vars=[
"KS_DBT_CLOUD_ACCOUNT_ID",
"KS_DBT_CLOUD_PROJECT_ID",
"KS_DBT_CLOUD_TOKEN",
"KS_DBT_CLOUD_ACCESS_URL",
"KS_DBT_CLOUD_DISCOVERY_API_URL",
Expand Down
96 changes: 90 additions & 6 deletions examples/experimental/dagster-dlift/dagster_dlift/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Iterator, Mapping, Sequence
import time
from typing import Any, Iterator, Mapping, Optional, Sequence

import requests

Expand All @@ -8,6 +9,7 @@
GET_DBT_TESTS_QUERY,
VERIFICATION_QUERY,
)
from dagster_dlift.utils import get_job_name

ENVIRONMENTS_SUBPATH = "environments/"

Expand All @@ -16,7 +18,7 @@ class DbtCloudClient:
def __init__(
self,
# Can be found on the Account Info page of dbt.
account_id: str,
account_id: int,
# Can be either a personal token or a service token.
token: str,
# Can be found on the
Expand Down Expand Up @@ -44,12 +46,28 @@ def get_session(self) -> requests.Session:
)
return session

def make_access_api_request(self, subpath: str) -> Mapping[str, Any]:
def get_artifact_session(self) -> requests.Session:
    """Build a requests session used for fetching run artifacts.

    The session authenticates with ``self.token`` via a ``Token`` authorization
    header and declares a JSON content type.
    """
    artifact_session = requests.Session()
    artifact_headers = {
        "Authorization": f"Token {self.token}",
        "Content-Type": "application/json",
    }
    artifact_session.headers.update(artifact_headers)
    return artifact_session

def make_access_api_request(
self, subpath: str, params: Optional[Mapping[str, Any]] = None
) -> Mapping[str, Any]:
session = self.get_session()
return self.ensure_valid_response(session.get(f"{self.get_api_v2_url()}/{subpath}")).json()
return self.ensure_valid_response(
session.get(f"{self.get_api_v2_url()}/{subpath}", params=params)
).json()

def ensure_valid_response(self, response: requests.Response) -> requests.Response:
if response.status_code != 200:
def ensure_valid_response(
self, response: requests.Response, expected_code: int = 200
) -> requests.Response:
if response.status_code != expected_code:
raise Exception(f"Request to DBT Cloud failed: {response.text}")
return response

Expand Down Expand Up @@ -134,3 +152,69 @@ def get_dbt_tests(self, environment_id: int) -> Sequence[Mapping[str, Any]]:
GET_DBT_TESTS_QUERY, {"environmentId": environment_id}, key="tests"
)
]

def create_dagster_job(self, project_id: int, environment_id: int) -> int:
    """Create a dbt Cloud job spec'ed to do what dagster expects.

    Posts a job definition to the ``jobs/`` endpoint under the URL returned by
    ``get_api_v2_url()``. The job is named with ``get_job_name`` so it can be
    looked up again by name.

    Args:
        project_id: Id of the dbt Cloud project the job belongs to.
        environment_id: Id of the dbt Cloud environment the job runs in.

    Returns:
        The ``id`` field found under ``data`` in the response payload.

    Raises:
        Exception: Via ``ensure_valid_response`` when the response status is not 201.
    """
    session = self.get_session()
    jobs_url = f"{self.get_api_v2_url()}/jobs/"
    job_spec = {
        "account_id": self.account_id,
        "environment_id": environment_id,
        "project_id": project_id,
        "name": get_job_name(environment_id=environment_id, project_id=project_id),
        "description": "A job that runs dbt models, sources, and tests.",
        "job_type": "other",
    }
    raw_response = session.post(jobs_url, json=job_spec)
    # Job creation is expected to answer 201 rather than the default 200.
    payload = self.ensure_valid_response(raw_response, expected_code=201).json()
    return payload["data"]["id"]

def destroy_dagster_job(self, project_id: int, environment_id: int, job_id: int) -> None:
    """Delete the dbt Cloud job identified by ``job_id``.

    ``project_id`` and ``environment_id`` are accepted but not used when building
    the request; only ``job_id`` determines which job is deleted.

    Raises:
        Exception: Via ``ensure_valid_response`` when the response status is not 200.
    """
    session = self.get_session()
    job_url = f"{self.get_api_v2_url()}/jobs/{job_id}"
    delete_response = session.delete(job_url)
    self.ensure_valid_response(delete_response)

def get_job_info_by_id(self, job_id: int) -> Mapping[str, Any]:
    """Fetch the JSON payload describing the job with the given id.

    Raises:
        Exception: Via ``ensure_valid_response`` when the response status is not 200.
    """
    session = self.get_session()
    job_url = f"{self.get_api_v2_url()}/jobs/{job_id}"
    checked_response = self.ensure_valid_response(session.get(job_url))
    return checked_response.json()

def list_jobs(self, environment_id: int) -> Sequence[Mapping[str, Any]]:
    """Return the ``data`` entries of the jobs listing for one environment.

    The listing is requested through ``make_access_api_request`` with
    ``environment_id`` passed as a query parameter.
    """
    # NOTE(review): the subpath carries a leading slash; if the request helper joins it
    # to the base URL with "/", the URL gets a double slash. Kept as-is because the
    # test fake is seeded with this exact subpath -- confirm before changing.
    query = {"environment_id": environment_id}
    listing = self.make_access_api_request("/jobs/", params=query)
    return listing["data"]

def trigger_job(self, job_id: int, steps: Optional[Sequence[str]] = None) -> Mapping[str, Any]:
    """Trigger a run of the given job and return the JSON response payload.

    Args:
        job_id: Id of the job to run.
        steps: Optional list of commands sent as ``steps_override``; ``None`` is
            sent through unchanged when no override is given.

    Raises:
        Exception: Via ``ensure_valid_response`` when the response status is not 200.
    """
    session = self.get_session()
    run_url = f"{self.get_api_v2_url()}/jobs/{job_id}/run/"
    request_body = {"steps_override": steps, "cause": "Triggered by dagster."}
    checked_response = self.ensure_valid_response(session.post(run_url, json=request_body))
    return checked_response.json()

def get_job_run_info(self, job_run_id: int) -> Mapping[str, Any]:
    """Fetch the JSON payload describing the run with the given id.

    Raises:
        Exception: Via ``ensure_valid_response`` when the response status is not 200.
    """
    session = self.get_session()
    run_url = f"{self.get_api_v2_url()}/runs/{job_run_id}"
    checked_response = self.ensure_valid_response(session.get(run_url))
    return checked_response.json()

def poll_for_run_completion(
    self, job_run_id: int, timeout: int = 60, poll_interval: float = 0.1
) -> int:
    """Poll a run until it reports a terminal status or the timeout elapses.

    Args:
        job_run_id: Id of the run to poll.
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls. Defaults to the previously
            hard-coded 0.1; callers polling a remote API may want a larger value.

    Returns:
        The terminal status code read from ``data.status`` of the run info.

    Raises:
        Exception: If no terminal status is observed before ``timeout`` seconds pass.
    """
    # 10 is asserted as success by this package's live test; 20 and 30 are presumably
    # dbt Cloud's error / cancelled codes -- TODO confirm against the dbt Cloud API docs.
    terminal_statuses = frozenset({10, 20, 30})
    # Use the monotonic clock so wall-clock adjustments cannot stretch or cut the timeout.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = self.get_job_run_info(job_run_id)["data"]["status"]
        if status in terminal_statuses:
            return status
        time.sleep(poll_interval)
    raise Exception(f"Run {job_run_id} did not complete within {timeout} seconds.")

def get_run_results_json(self, job_run_id: int) -> Mapping[str, Any]:
    """Fetch and parse the ``run_results.json`` artifact of a run.

    Uses the session from ``get_artifact_session`` rather than ``get_session``.

    Raises:
        Exception: Via ``ensure_valid_response`` when the response status is not 200.
    """
    artifact_session = self.get_artifact_session()
    artifact_url = f"{self.get_api_v2_url()}/runs/{job_run_id}/artifacts/run_results.json"
    checked_response = self.ensure_valid_response(artifact_session.get(artifact_url))
    return checked_response.json()
16 changes: 16 additions & 0 deletions examples/experimental/dagster-dlift/dagster_dlift/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DbtCloudContentType,
DbtCloudProjectEnvironmentData,
)
from dagster_dlift.utils import get_job_name


def compute_environment_data(
Expand All @@ -13,6 +14,7 @@ def compute_environment_data(
return DbtCloudProjectEnvironmentData(
project_id=project_id,
environment_id=environment_id,
job_id=get_or_create_job(environment_id, project_id, client),
models_by_unique_id={
model["uniqueId"]: DbtCloudContentData(
content_type=DbtCloudContentType.MODEL,
Expand All @@ -35,3 +37,17 @@ def compute_environment_data(
for test in client.get_dbt_tests(environment_id)
},
)


def get_or_create_job(environment_id: int, project_id: int, client: DbtCloudClient) -> int:
    """Get or create a dbt Cloud job for a project environment.

    Looks for a job in the environment whose name matches ``get_job_name`` for this
    project/environment pair; if none exists, one is created through the client.

    Args:
        environment_id: Id of the dbt Cloud environment to search / create the job in.
        project_id: Id of the dbt Cloud project the job belongs to.
        client: Client used to list and create jobs.

    Returns:
        The id of the existing or newly created job.
    """
    # Pass ids by keyword: get_job_name takes (environment_id, project_id), and passing
    # them positionally in the opposite order produces a name that never matches the
    # one create_dagster_job registers, so a duplicate job would be created each call.
    expected_job_name = get_job_name(environment_id=environment_id, project_id=project_id)
    # A single listing request is enough to both detect and locate the job.
    for job in client.list_jobs(environment_id=environment_id):
        if job["name"] == expected_job_name:
            return job["id"]
    return client.create_dagster_job(project_id, environment_id)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class DbtCloudCredentials(NamedTuple):
account_id: str
account_id: int
token: str
access_url: str
discovery_api_url: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ def __hash__(self) -> int:

class ExpectedAccessApiRequest(NamedTuple):
subpath: str
params: Optional[Mapping[str, Any]] = None

def __hash__(self) -> int:
return hash(self.subpath)
return hash((self.subpath, frozenset(self.params.items() if self.params else [])))


class DbtCloudClientFake(DbtCloudClient):
Expand All @@ -30,12 +31,15 @@ def __init__(
self.access_api_responses = access_api_responses
self.discovery_api_responses = discovery_api_responses

def make_access_api_request(self, subpath: str) -> Mapping[str, Any]:
if ExpectedAccessApiRequest(subpath) not in self.access_api_responses:
def make_access_api_request(
self, subpath: str, params: Optional[Mapping[str, Any]] = None
) -> Mapping[str, Any]:
expected_request = ExpectedAccessApiRequest(subpath, params)
if expected_request not in self.access_api_responses:
raise Exception(
f"ExpectedAccessApiRequest({subpath}) not found in access_api_responses"
)
return self.access_api_responses[ExpectedAccessApiRequest(subpath)]
return self.access_api_responses[expected_request]

def make_discovery_api_query(
self, query: str, variables: Mapping[str, Any]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
class DbtCloudProjectEnvironmentData:
project_id: int
environment_id: int
# The dbt cloud job id that we'll use to kick off executions launched from a client.
job_id: int
models_by_unique_id: Mapping[str, "DbtCloudContentData"]
sources_by_unique_id: Mapping[str, "DbtCloudContentData"]
tests_by_unique_id: Mapping[str, "DbtCloudContentData"]
Expand Down
2 changes: 2 additions & 0 deletions examples/experimental/dagster-dlift/dagster_dlift/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def get_job_name(environment_id: int, project_id: int) -> str:
    """Return the name used for the dagster ad hoc job of a project environment.

    Note that the name embeds ``project_id`` before ``environment_id``, the reverse
    of the parameter order.
    """
    job_name = f"DAGSTER_ADHOC_JOB__{project_id}__{environment_id}"
    return job_name
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from dagster_dlift.project import DbtCloudCredentials
from dagster_dlift.test.client_fake import (
DbtCloudClientFake,
ExpectedAccessApiRequest,
ExpectedDiscoveryApiRequest,
build_response_for_type,
)
from dagster_dlift.test.project_fake import DbtCloudProjectEnvironmentFake
from dagster_dlift.translator import DbtCloudContentType
from dagster_dlift.utils import get_job_name


def query_per_content_type(content_type: DbtCloudContentType) -> str:
Expand Down Expand Up @@ -65,9 +67,22 @@ def jaffle_shop_contents() -> (
}


def build_dagster_job_response(environment_id: int, project_id: int) -> Mapping[str, Any]:
    """Build a fake job entry named via ``get_job_name`` with a fixed id of 1."""
    job_name = get_job_name(environment_id, project_id)
    return dict(name=job_name, id=1)


def build_expected_access_api_requests() -> Mapping[ExpectedAccessApiRequest, Any]:
    """Build the access API responses the fake client is seeded with.

    Currently contains only the jobs listing for environment 1, which holds the
    dagster job for project 1 / environment 1.
    """
    # List of jobs
    list_jobs_request = ExpectedAccessApiRequest("/jobs/", params={"environment_id": 1})
    list_jobs_response = {"data": [build_dagster_job_response(1, 1)]}
    return {list_jobs_request: list_jobs_response}


def create_seeded_jaffle_shop_client() -> DbtCloudClientFake:
return DbtCloudClientFake(
access_api_responses={},
access_api_responses=build_expected_access_api_requests(),
discovery_api_responses=build_expected_requests(dep_graph_per_type=jaffle_shop_contents()),
)

Expand All @@ -79,6 +94,6 @@ def create_jaffle_shop_project() -> DbtCloudProjectEnvironmentFake:
environment_id=1,
project_id=1,
credentials=DbtCloudCredentials(
account_id="fake", token="fake", access_url="fake", discovery_api_url="fake"
account_id=123, token="fake", access_url="fake", discovery_api_url="fake"
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def get_instance() -> DbtCloudClient:
return DbtCloudClient(
account_id=get_env_var("KS_DBT_CLOUD_ACCOUNT_ID"),
account_id=int(get_env_var("KS_DBT_CLOUD_ACCOUNT_ID")),
token=get_env_var("KS_DBT_CLOUD_TOKEN"),
access_url=get_env_var("KS_DBT_CLOUD_ACCESS_URL"),
discovery_api_url=get_env_var("KS_DBT_CLOUD_DISCOVERY_API_URL"),
Expand All @@ -15,3 +15,7 @@ def get_instance() -> DbtCloudClient:

def get_environment_id() -> int:
    """Look up the id of the environment named ``TEST_ENV_NAME`` with a fresh client."""
    client = get_instance()
    return client.get_environment_id_by_name(TEST_ENV_NAME)


def get_project_id() -> int:
    """Read the ``KS_DBT_CLOUD_PROJECT_ID`` env var and return it as an int."""
    raw_project_id = get_env_var("KS_DBT_CLOUD_PROJECT_ID")
    return int(raw_project_id)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from dagster_dlift.client import DbtCloudClient
from dlift_kitchen_sink.instance import get_environment_id, get_instance
from dlift_kitchen_sink.instance import get_environment_id, get_instance, get_project_id


@pytest.fixture
Expand All @@ -11,3 +11,8 @@ def instance() -> DbtCloudClient:
@pytest.fixture
def environment_id() -> int:
    """Fixture providing the id returned by ``get_environment_id``."""
    resolved_environment_id = get_environment_id()
    return resolved_environment_id


@pytest.fixture
def project_id() -> int:
    """Fixture providing the id returned by ``get_project_id``."""
    resolved_project_id = get_project_id()
    return resolved_project_id
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dagster_dlift.client import DbtCloudClient
from dagster_dlift.utils import get_job_name
from dlift_kitchen_sink.constants import EXPECTED_TAG
from dlift_kitchen_sink.instance import get_instance

Expand Down Expand Up @@ -84,3 +85,28 @@ def test_get_tests(instance: DbtCloudClient, environment_id: int) -> None:
"unique_stg_customers_customer_id",
"unique_stg_orders_order_id",
}


def test_cloud_job_apis(instance: DbtCloudClient, environment_id: int, project_id: int) -> None:
    """Tests that we can create, inspect, run, and destroy a dagster job.

    The job is destroyed in a ``finally`` block so a failing assertion or API error
    midway through does not leave a stray job behind in the live dbt Cloud account.
    """
    job_id = instance.create_dagster_job(project_id, environment_id)
    try:
        job_info = instance.get_job_info_by_id(job_id)
        assert job_info["data"]["name"] == get_job_name(environment_id, project_id)
        listed_job_ids = {job["id"] for job in instance.list_jobs(environment_id=environment_id)}
        assert job_id in listed_job_ids

        response = instance.trigger_job(job_id, steps=["dbt run --select tag:test"])
        run_id = response["data"]["id"]
        run_status = instance.poll_for_run_completion(run_id)
        assert run_status == 10  # Indicates success
        run_results = instance.get_run_results_json(run_id)
        assert {result["unique_id"] for result in run_results["results"]} == {
            "model.test_environment.customers",
            "model.test_environment.stg_customers",
            "model.test_environment.stg_orders",
        }
    finally:
        # Always clean up the job created above, even when a check fails.
        instance.destroy_dagster_job(
            project_id=project_id, environment_id=environment_id, job_id=job_id
        )
    remaining_job_ids = {job["id"] for job in instance.list_jobs(environment_id=environment_id)}
    assert job_id not in remaining_job_ids

0 comments on commit d4fa3bc

Please sign in to comment.