Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dagster-airbyte] Support client_id and client_secret in AirbyteCloudResource #23451

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions docs/content/integrations/airbyte-cloud.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ To get started, you will need to install the `dagster` and `dagster-airbyte` Pyt
pip install dagster dagster-airbyte
```

You'll also need to have an Airbyte Cloud account, and have created an Airbyte API Key. For more information, see the [Airbyte API docs](https://reference.airbyte.com/reference/start).
You'll also need to have an Airbyte Cloud account, and have created an Airbyte client ID and client secret. For more information, see the [Airbyte API docs](https://reference.airbyte.com/reference/getting-started) and [Airbyte authentication guide](https://reference.airbyte.com/reference/authentication).

---

Expand All @@ -58,11 +58,12 @@ from dagster import EnvVar
from dagster_airbyte import AirbyteCloudResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
```

Here, the API key is provided using an <PyObject object="EnvVar" />. For more information on setting environment variables in a production setting, see [Using environment variables and secrets](/guides/dagster/using-environment-variables-and-secrets).
Here, the client ID and client secret are provided using an <PyObject object="EnvVar" />. For more information on setting environment variables in a production setting, see [Using environment variables and secrets](/guides/dagster/using-environment-variables-and-secrets).

---

Expand Down Expand Up @@ -104,7 +105,8 @@ from dagster_airbyte import build_airbyte_assets, AirbyteCloudResource
from dagster import Definitions, EnvVar

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -153,7 +155,8 @@ from dagster_snowflake_pandas import SnowflakePandasIOManager
import pandas as pd

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -207,7 +210,8 @@ from dagster_airbyte import (
from dagster_snowflake import SnowflakeResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -261,7 +265,8 @@ from dagster import (
)

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def scope_define_cloud_instance() -> None:
from dagster_airbyte import AirbyteCloudResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
# end_define_cloud_instance

Expand Down Expand Up @@ -126,7 +127,8 @@ def scope_airbyte_cloud_manual_config():
from dagster import Definitions, EnvVar

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -257,7 +259,8 @@ def scope_add_downstream_assets_cloud():
import pandas as pd

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -310,7 +313,8 @@ def scope_add_downstream_assets_cloud_with_deps():
from dagster_snowflake import SnowflakeResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -400,7 +404,8 @@ def scope_schedule_assets_cloud():
)

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import time
from abc import abstractmethod
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Mapping, Optional, cast

import requests
from dagster import (
ConfigurableResource,
Failure,
InitResourceContext,
_check as check,
get_dagster_logger,
resource,
Expand All @@ -19,13 +21,17 @@
from dagster._core.definitions.resource_definition import dagster_maintained_resource
from dagster._utils.cached_method import cached_method
from dagster._utils.merger import deep_merge_dicts
from pydantic import Field
from pydantic import Field, PrivateAttr
from requests.exceptions import RequestException

from dagster_airbyte.types import AirbyteOutput

DEFAULT_POLL_INTERVAL_SECONDS = 10

# The access token expire every 3 minutes in Airbyte Cloud.
# Refresh is needed after 2.5 minutes to avoid the "token expired" error message.
AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS = 150


class AirbyteState:
RUNNING = "running"
Expand Down Expand Up @@ -94,7 +100,11 @@ def all_additional_request_params(self) -> Mapping[str, Any]:
raise NotImplementedError()

def make_request(
self, endpoint: str, data: Optional[Mapping[str, object]] = None, method: str = "POST"
self,
endpoint: str,
data: Optional[Mapping[str, object]] = None,
method: str = "POST",
include_additional_request_params: bool = True,
) -> Optional[Mapping[str, object]]:
"""Creates and sends a request to the desired Airbyte REST API endpoint.

Expand All @@ -120,10 +130,11 @@ def make_request(
if data:
request_args["json"] = data

request_args = deep_merge_dicts(
request_args,
self.all_additional_request_params,
)
if include_additional_request_params:
request_args = deep_merge_dicts(
request_args,
self.all_additional_request_params,
)

response = requests.request(
**request_args,
Expand Down Expand Up @@ -244,7 +255,7 @@ def sync_and_poll(


class AirbyteCloudResource(BaseAirbyteResource):
"""This resource allows users to programatically interface with the Airbyte Cloud API to launch
"""This resource allows users to programmatically interface with the Airbyte Cloud API to launch
syncs and monitor their progress.

**Examples:**
Expand All @@ -255,7 +266,8 @@ class AirbyteCloudResource(BaseAirbyteResource):
from dagster_airbyte import AirbyteResource

my_airbyte_resource = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)

airbyte_assets = build_airbyte_assets(
Expand All @@ -269,15 +281,48 @@ class AirbyteCloudResource(BaseAirbyteResource):
)
"""

api_key: str = Field(..., description="The Airbyte Cloud API key.")
client_id: str = Field(..., description="The Airbyte Cloud client ID.")
client_secret: str = Field(..., description="The Airbyte Cloud client secret.")

_access_token_value: Optional[str] = PrivateAttr(default=None)
_access_token_timestamp: Optional[float] = PrivateAttr(default=None)

def setup_for_execution(self, context: InitResourceContext) -> None:
# Refresh access token when the resource is initialized
self._refresh_access_token()

@property
def api_base_url(self) -> str:
return "https://api.airbyte.com/v1"

@property
def all_additional_request_params(self) -> Mapping[str, Any]:
return {"headers": {"Authorization": f"Bearer {self.api_key}", "User-Agent": "dagster"}}
# Make sure the access token is refreshed before using it when calling the API.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a @property, we should avoid doing real work inside the fn body (like refreshing the access token). One different approach could be be to pull out the refresh into a separate fn like _get_or_refresh_access_token() so that it's clear that invoking it may have a cost, and call that directly in make_request

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to make_request in f2fefcb

if self._needs_refreshed_access_token():
self._refresh_access_token()
return {
"headers": {
"Authorization": f"Bearer {self._access_token_value}",
"User-Agent": "dagster",
}
}

def make_request(
self,
endpoint: str,
data: Optional[Mapping[str, object]] = None,
method: str = "POST",
include_additional_request_params: bool = True,
) -> Optional[Mapping[str, object]]:
# Make sure the access token is refreshed before using it when calling the API.
if include_additional_request_params and self._needs_refreshed_access_token():
self._refresh_access_token()
return super().make_request(
endpoint=endpoint,
data=data,
method=method,
include_additional_request_params=include_additional_request_params,
)

def start_sync(self, connection_id: str) -> Mapping[str, object]:
job_sync = check.not_none(
Expand Down Expand Up @@ -306,6 +351,31 @@ def _should_forward_logs(self) -> bool:
# Airbyte Cloud does not support streaming logs yet
return False

def _refresh_access_token(self) -> None:
response = check.not_none(
self.make_request(
endpoint="/applications/token",
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
},
# Must not pass the bearer access token when refreshing it.
include_additional_request_params=False,
)
)
self._access_token_value = str(response["access_token"])
self._access_token_timestamp = datetime.now().timestamp()

def _needs_refreshed_access_token(self) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This behavior makes sense to me, this is low enough complexity that it feels ok to abstract from the user

return (
not self._access_token_value
or not self._access_token_timestamp
or self._access_token_timestamp
<= datetime.timestamp(
datetime.now() - timedelta(seconds=AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS)
)
)


class AirbyteResource(BaseAirbyteResource):
"""This resource allows users to programatically interface with the Airbyte REST API to launch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def test_assets_with_normalization(


def test_assets_cloud() -> None:
ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0)
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)
ab_url = ab_resource.api_base_url

ab_assets = build_airbyte_assets(
Expand All @@ -220,6 +222,11 @@ def test_assets_cloud() -> None:
)

with responses.RequestsMock() as rsps:
rsps.add(
rsps.POST,
f"{ab_url}/applications/token",
json={"access_token": "some_access_token"},
)
rsps.add(
rsps.POST,
f"{ab_url}/jobs",
Expand Down
Loading