From 89b99995d568fe077b1617bc634bb21b9b5c9de8 Mon Sep 17 00:00:00 2001 From: Maxime Armstrong <46797220+maximearmstrong@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:41:32 -0400 Subject: [PATCH] [dagster-airbyte] Support `client_id` and `client_secret` in AirbyteCloudResource (#23451) ## Summary & Motivation This PR updates the `AirbyteCloudResource` to support `client_id` and `client_secret` for authentication. The `api_key` can no longer be used for authentication because Airbyte is [deprecating portal.airbyte.com](https://reference.airbyte.com/reference/portalairbytecom-deprecation). Tests and docs are updated to reflect the changes. Two main questions for reviewers: - This PR implements a pattern where the access token is refreshed before making a call to the API, if the token was fetched more than 2.5 minutes ago. Should we avoid this pattern and let users manage the resource lifecycle? - [According to Airbyte](https://reference.airbyte.com/reference/portalairbytecom-deprecation), the access token expires after 3 minutes. - The access token is initially fetched in `setup_for_execution`, then refreshed if needed. I'm concerned that for jobs including other assets, it might take more than 3 minutes before the Airbyte assets are materialized. - [portal.airbyte.com will be deprecated next week](https://reference.airbyte.com/reference/portalairbytecom-deprecation), on August 15th, so I removed the previous `api_key` attribute without a deprecation warning. Are we comfortable doing so, considering that `api_key` authentication will stop working next week? 
## How I Tested These Changes BK with updated tests Fully tested on a live cloud instance with this code: ```python from dagster import Definitions, EnvVar from dagster_airbyte import AirbyteCloudResource, build_airbyte_assets airbyte_instance = AirbyteCloudResource( client_id=EnvVar("AIRBYTE_CLIENT_ID"), client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( # Test connection - Sample Data (Faker) to Google Sheets connection_id="0bb7a00c-0b85-4fac-b8ff-67dc380f1c29", destination_tables=["products", "purchases", "users"], ) defs = Definitions(assets=airbyte_assets, resources={"airbyte": airbyte_instance}) ``` Job successful: Screenshot 2024-08-13 at 4 30 12 PM Asset graph: Screenshot 2024-08-13 at 4 30 38 PM --- docs/content/integrations/airbyte-cloud.mdx | 19 ++-- .../integrations/airbyte/airbyte.py | 15 ++- .../dagster_airbyte/resources.py | 90 +++++++++++++-- .../dagster_airbyte_tests/test_asset_defs.py | 9 +- .../test_cloud_resources.py | 106 +++++++++++++++++- .../test_load_from_instance.py | 4 +- .../dagster_airbyte_tests/test_ops.py | 23 +++- 7 files changed, 235 insertions(+), 31 deletions(-) diff --git a/docs/content/integrations/airbyte-cloud.mdx b/docs/content/integrations/airbyte-cloud.mdx index 46b31893cd3a5..f32e07f1e7575 100644 --- a/docs/content/integrations/airbyte-cloud.mdx +++ b/docs/content/integrations/airbyte-cloud.mdx @@ -45,7 +45,7 @@ To get started, you will need to install the `dagster` and `dagster-airbyte` Pyt pip install dagster dagster-airbyte ``` -You'll also need to have an Airbyte Cloud account, and have created an Airbyte API Key. For more information, see the [Airbyte API docs](https://reference.airbyte.com/reference/start). +You'll also need to have an Airbyte Cloud account, and have created an Airbyte client ID and client secret. 
For more information, see the [Airbyte API docs](https://reference.airbyte.com/reference/getting-started) and [Airbyte authentication guide](https://reference.airbyte.com/reference/authentication). --- @@ -58,11 +58,12 @@ from dagster import EnvVar from dagster_airbyte import AirbyteCloudResource airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) ``` -Here, the API key is provided using an . For more information on setting environment variables in a production setting, see [Using environment variables and secrets](/guides/dagster/using-environment-variables-and-secrets). +Here, the client ID and client secret are provided using an . For more information on setting environment variables in a production setting, see [Using environment variables and secrets](/guides/dagster/using-environment-variables-and-secrets). --- @@ -104,7 +105,8 @@ from dagster_airbyte import build_airbyte_assets, AirbyteCloudResource from dagster import Definitions, EnvVar airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", @@ -153,7 +155,8 @@ from dagster_snowflake_pandas import SnowflakePandasIOManager import pandas as pd airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", @@ -207,7 +210,8 @@ from dagster_airbyte import ( from dagster_snowflake import SnowflakeResource airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = 
build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", @@ -261,7 +265,8 @@ from dagster import ( ) airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", diff --git a/examples/docs_snippets/docs_snippets/integrations/airbyte/airbyte.py b/examples/docs_snippets/docs_snippets/integrations/airbyte/airbyte.py index 00254a2851cdd..4f69a699ae955 100644 --- a/examples/docs_snippets/docs_snippets/integrations/airbyte/airbyte.py +++ b/examples/docs_snippets/docs_snippets/integrations/airbyte/airbyte.py @@ -22,7 +22,8 @@ def scope_define_cloud_instance() -> None: from dagster_airbyte import AirbyteCloudResource airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) # end_define_cloud_instance @@ -126,7 +127,8 @@ def scope_airbyte_cloud_manual_config(): from dagster import Definitions, EnvVar airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", @@ -257,7 +259,8 @@ def scope_add_downstream_assets_cloud(): import pandas as pd airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", @@ -310,7 +313,8 @@ def scope_add_downstream_assets_cloud_with_deps(): from dagster_snowflake import SnowflakeResource airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + 
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", @@ -400,7 +404,8 @@ def scope_schedule_assets_cloud(): ) airbyte_instance = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( connection_id="43908042-8399-4a58-82f1-71a45099fff7", diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py index 9a1ac1cf81ae9..0da3f1f061c18 100644 --- a/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py @@ -5,12 +5,14 @@ import time from abc import abstractmethod from contextlib import contextmanager +from datetime import datetime, timedelta from typing import Any, Dict, List, Mapping, Optional, cast import requests from dagster import ( ConfigurableResource, Failure, + InitResourceContext, _check as check, get_dagster_logger, resource, @@ -19,13 +21,17 @@ from dagster._core.definitions.resource_definition import dagster_maintained_resource from dagster._utils.cached_method import cached_method from dagster._utils.merger import deep_merge_dicts -from pydantic import Field +from pydantic import Field, PrivateAttr from requests.exceptions import RequestException from dagster_airbyte.types import AirbyteOutput DEFAULT_POLL_INTERVAL_SECONDS = 10 +# The access token expire every 3 minutes in Airbyte Cloud. +# Refresh is needed after 2.5 minutes to avoid the "token expired" error message. 
+AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS = 150 + class AirbyteState: RUNNING = "running" @@ -94,7 +100,11 @@ def all_additional_request_params(self) -> Mapping[str, Any]: raise NotImplementedError() def make_request( - self, endpoint: str, data: Optional[Mapping[str, object]] = None, method: str = "POST" + self, + endpoint: str, + data: Optional[Mapping[str, object]] = None, + method: str = "POST", + include_additional_request_params: bool = True, ) -> Optional[Mapping[str, object]]: """Creates and sends a request to the desired Airbyte REST API endpoint. @@ -120,10 +130,11 @@ def make_request( if data: request_args["json"] = data - request_args = deep_merge_dicts( - request_args, - self.all_additional_request_params, - ) + if include_additional_request_params: + request_args = deep_merge_dicts( + request_args, + self.all_additional_request_params, + ) response = requests.request( **request_args, @@ -244,7 +255,7 @@ def sync_and_poll( class AirbyteCloudResource(BaseAirbyteResource): - """This resource allows users to programatically interface with the Airbyte Cloud API to launch + """This resource allows users to programmatically interface with the Airbyte Cloud API to launch syncs and monitor their progress. 
**Examples:** @@ -255,7 +266,8 @@ class AirbyteCloudResource(BaseAirbyteResource): from dagster_airbyte import AirbyteResource my_airbyte_resource = AirbyteCloudResource( - api_key=EnvVar("AIRBYTE_API_KEY"), + client_id=EnvVar("AIRBYTE_CLIENT_ID"), + client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"), ) airbyte_assets = build_airbyte_assets( @@ -269,7 +281,15 @@ class AirbyteCloudResource(BaseAirbyteResource): ) """ - api_key: str = Field(..., description="The Airbyte Cloud API key.") + client_id: str = Field(..., description="The Airbyte Cloud client ID.") + client_secret: str = Field(..., description="The Airbyte Cloud client secret.") + + _access_token_value: Optional[str] = PrivateAttr(default=None) + _access_token_timestamp: Optional[float] = PrivateAttr(default=None) + + def setup_for_execution(self, context: InitResourceContext) -> None: + # Refresh access token when the resource is initialized + self._refresh_access_token() @property def api_base_url(self) -> str: @@ -277,7 +297,32 @@ def api_base_url(self) -> str: @property def all_additional_request_params(self) -> Mapping[str, Any]: - return {"headers": {"Authorization": f"Bearer {self.api_key}", "User-Agent": "dagster"}} + # Make sure the access token is refreshed before using it when calling the API. + if self._needs_refreshed_access_token(): + self._refresh_access_token() + return { + "headers": { + "Authorization": f"Bearer {self._access_token_value}", + "User-Agent": "dagster", + } + } + + def make_request( + self, + endpoint: str, + data: Optional[Mapping[str, object]] = None, + method: str = "POST", + include_additional_request_params: bool = True, + ) -> Optional[Mapping[str, object]]: + # Make sure the access token is refreshed before using it when calling the API. 
+ if include_additional_request_params and self._needs_refreshed_access_token(): + self._refresh_access_token() + return super().make_request( + endpoint=endpoint, + data=data, + method=method, + include_additional_request_params=include_additional_request_params, + ) def start_sync(self, connection_id: str) -> Mapping[str, object]: job_sync = check.not_none( @@ -306,6 +351,31 @@ def _should_forward_logs(self) -> bool: # Airbyte Cloud does not support streaming logs yet return False + def _refresh_access_token(self) -> None: + response = check.not_none( + self.make_request( + endpoint="/applications/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + }, + # Must not pass the bearer access token when refreshing it. + include_additional_request_params=False, + ) + ) + self._access_token_value = str(response["access_token"]) + self._access_token_timestamp = datetime.now().timestamp() + + def _needs_refreshed_access_token(self) -> bool: + return ( + not self._access_token_value + or not self._access_token_timestamp + or self._access_token_timestamp + <= datetime.timestamp( + datetime.now() - timedelta(seconds=AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS) + ) + ) + class AirbyteResource(BaseAirbyteResource): """This resource allows users to programatically interface with the Airbyte REST API to launch diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_asset_defs.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_asset_defs.py index 87aef59373a55..024465c027a7d 100644 --- a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_asset_defs.py +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_asset_defs.py @@ -208,7 +208,9 @@ def test_assets_with_normalization( def test_assets_cloud() -> None: - ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0) + ab_resource = AirbyteCloudResource( + client_id="some_client_id", 
client_secret="some_client_secret", poll_interval=0 + ) ab_url = ab_resource.api_base_url ab_assets = build_airbyte_assets( @@ -220,6 +222,11 @@ def test_assets_cloud() -> None: ) with responses.RequestsMock() as rsps: + rsps.add( + rsps.POST, + f"{ab_url}/applications/token", + json={"access_token": "some_access_token"}, + ) rsps.add( rsps.POST, f"{ab_url}/jobs", diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_cloud_resources.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_cloud_resources.py index d93f1f353b5e4..8d48369cdb124 100644 --- a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_cloud_resources.py +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_cloud_resources.py @@ -1,4 +1,7 @@ +import datetime +import json import re +from unittest import mock import pytest import responses @@ -8,7 +11,14 @@ @responses.activate def test_trigger_connection() -> None: - ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0) + ab_resource = AirbyteCloudResource( + client_id="some_client_id", client_secret="some_client_secret", poll_interval=0 + ) + responses.add( + responses.POST, + f"{ab_resource.api_base_url}/applications/token", + json={"access_token": "some_access_token"}, + ) responses.add( method=responses.POST, url=ab_resource.api_base_url + "/jobs", @@ -19,8 +29,17 @@ def test_trigger_connection() -> None: assert resp == {"job": {"id": 1, "status": "pending"}} +@responses.activate def test_trigger_connection_fail() -> None: - ab_resource = AirbyteCloudResource(api_key="some_key") + ab_resource = AirbyteCloudResource( + client_id="some_client_id", client_secret="some_client_secret" + ) + responses.add( + responses.POST, + f"{ab_resource.api_base_url}/applications/token", + json={"access_token": "some_access_token"}, + ) + with pytest.raises( Failure, match=re.escape("Max retries (3) exceeded with url: https://api.airbyte.com/v1/jobs."), @@ -34,7 
+53,15 @@ def test_trigger_connection_fail() -> None: [AirbyteState.SUCCEEDED, AirbyteState.CANCELLED, AirbyteState.ERROR, "unrecognized"], ) def test_sync_and_poll(state) -> None: - ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0) + ab_resource = AirbyteCloudResource( + client_id="some_client_id", client_secret="some_client_secret", poll_interval=0 + ) + + responses.add( + responses.POST, + f"{ab_resource.api_base_url}/applications/token", + json={"access_token": "some_access_token"}, + ) responses.add( method=responses.POST, url=ab_resource.api_base_url + "/jobs", @@ -78,8 +105,15 @@ def test_sync_and_poll(state) -> None: @responses.activate def test_start_sync_bad_out_fail() -> None: - ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0) + ab_resource = AirbyteCloudResource( + client_id="some_client_id", client_secret="some_client_secret", poll_interval=0 + ) + responses.add( + responses.POST, + f"{ab_resource.api_base_url}/applications/token", + json={"access_token": "some_access_token"}, + ) responses.add( method=responses.POST, url=ab_resource.api_base_url + "/jobs", @@ -91,3 +125,67 @@ def test_start_sync_bad_out_fail() -> None: match=re.escape("Max retries (3) exceeded with url: https://api.airbyte.com/v1/jobs."), ): ab_resource.start_sync("some_connection") + + +@responses.activate +def test_refresh_access_token() -> None: + ab_resource = AirbyteCloudResource( + client_id="some_client_id", client_secret="some_client_secret", poll_interval=0 + ) + responses.add( + responses.POST, + f"{ab_resource.api_base_url}/applications/token", + json={"access_token": "some_access_token"}, + ) + responses.add( + method=responses.POST, + url=ab_resource.api_base_url + "/jobs", + json={"jobId": 1, "status": "pending", "jobType": "sync"}, + status=200, + ) + + test_time_first_call = datetime.datetime(2024, 1, 1, 0, 0, 0) + test_time_before_expiration = datetime.datetime(2024, 1, 1, 0, 2, 0) + test_time_after_expiration = 
datetime.datetime(2024, 1, 1, 0, 3, 0) + with mock.patch("dagster_airbyte.resources.datetime", wraps=datetime.datetime) as dt: + # Test first call, must get the access token before calling the jobs api + dt.now.return_value = test_time_first_call + ab_resource.start_sync("some_connection") + + assert len(responses.calls) == 2 + access_token_call = responses.calls[0] + jobs_api_call = responses.calls[1] + + assert "Authorization" not in access_token_call.request.headers + access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8")) + assert access_token_call_body["client_id"] == "some_client_id" + assert access_token_call_body["client_secret"] == "some_client_secret" + assert jobs_api_call.request.headers["Authorization"] == "Bearer some_access_token" + + responses.calls.reset() + + # Test second call, occurs before the access token expiration, only the jobs api is called + dt.now.return_value = test_time_before_expiration + ab_resource.start_sync("some_connection") + + assert len(responses.calls) == 1 + jobs_api_call = responses.calls[0] + + assert jobs_api_call.request.headers["Authorization"] == "Bearer some_access_token" + + responses.calls.reset() + + # Test third call, occurs after the token expiration, + # must refresh the access token before calling the jobs api + dt.now.return_value = test_time_after_expiration + ab_resource.start_sync("some_connection") + + assert len(responses.calls) == 2 + access_token_call = responses.calls[0] + jobs_api_call = responses.calls[1] + + assert "Authorization" not in access_token_call.request.headers + access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8")) + assert access_token_call_body["client_id"] == "some_client_id" + assert access_token_call_body["client_secret"] == "some_client_secret" + assert jobs_api_call.request.headers["Authorization"] == "Bearer some_access_token" diff --git 
a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_load_from_instance.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_load_from_instance.py index 285b792bec328..941b939d2222e 100644 --- a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_load_from_instance.py +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_load_from_instance.py @@ -288,7 +288,9 @@ def downstream_asset(dagster_tags): def test_load_from_instance_cloud() -> None: - airbyte_cloud_instance = AirbyteCloudResource(api_key="foo", poll_interval=0) + airbyte_cloud_instance = AirbyteCloudResource( + client_id="some_client_id", client_secret="some_client_secret", poll_interval=0 + ) with pytest.raises( DagsterInvalidInvocationError, diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_ops.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_ops.py index 929a825213040..080c3f72e7b26 100644 --- a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_ops.py +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/test_ops.py @@ -1,3 +1,4 @@ +import json from base64 import b64encode import pytest @@ -102,7 +103,9 @@ def airbyte_sync_job(): def test_airbyte_sync_op_cloud() -> None: - ab_resource = AirbyteCloudResource(api_key="some_key") + ab_resource = AirbyteCloudResource( + client_id="some_client_id", client_secret="some_client_secret" + ) ab_url = ab_resource.api_base_url @op @@ -127,6 +130,11 @@ def airbyte_sync_job() -> None: airbyte_sync_op(start_after=foo_op()) with responses.RequestsMock() as rsps: + rsps.add( + rsps.POST, + f"{ab_url}/applications/token", + json={"access_token": "some_access_token"}, + ) rsps.add( rsps.POST, f"{ab_url}/jobs", @@ -150,5 +158,14 @@ def airbyte_sync_job() -> None: connection_details={}, ) - for call in rsps.calls: - assert call.request.headers["Authorization"] == "Bearer some_key" + # The first call is to get the 
access token. + access_token_call = rsps.calls[0] + api_calls = rsps.calls[1:] + + assert "Authorization" not in access_token_call.request.headers + access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8")) + assert access_token_call_body["client_id"] == "some_client_id" + assert access_token_call_body["client_secret"] == "some_client_secret" + + for call in api_calls: + assert call.request.headers["Authorization"] == "Bearer some_access_token"