From e2f6ada3fb0bc9ee98faa89b6bcc7dcc71bff22f Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 24 Jun 2024 11:17:07 +0200 Subject: [PATCH 1/4] Fixes incorrect type check in the snippet (#1424) --- .../website/docs/dlt-ecosystem/verified-sources/sql_database.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md index de3e5f4c35..fde7a64144 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md @@ -179,7 +179,7 @@ pipeline = dlt.pipeline( def _double_as_decimal_adapter(table: sa.Table) -> None: """Return double as double, not decimals, this is mysql thing""" for column in table.columns.values(): - if isinstance(column.type, sa.Double): # type: ignore + if isinstance(column.type, sa.Float): column.type.asdecimal = False sql_alchemy_source = sql_database( From 934829bb1cc2aa31da8122bdf64bcee98c6063b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Willi=20M=C3=BCller?= Date: Mon, 24 Jun 2024 14:48:40 +0530 Subject: [PATCH 2/4] #1356 implements OAuth2 Two-legged flow (#1357) --- dlt/sources/helpers/rest_client/auth.py | 86 +++++++++- .../verified-sources/rest_api.md | 1 + .../docs/general-usage/http/rest-client.md | 54 ++++++ tests/sources/helpers/rest_client/conftest.py | 41 ++++- .../helpers/rest_client/test_client.py | 158 +++++++++++++++++- 5 files changed, 320 insertions(+), 20 deletions(-) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 29e6d8c77a..d2ca1c1ca6 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,17 +1,18 @@ -from base64 import b64encode -import dataclasses import math +import dataclasses +from abc import abstractmethod +from base64 import b64encode from typing import ( - List, + TYPE_CHECKING, + Any, Dict, Final, + Iterable, + List, Literal, Optional, Union, - Any, cast, - Iterable, - TYPE_CHECKING, ) from typing_extensions import Annotated from requests.auth import AuthBase @@ -24,7 +25,6 @@ from dlt.common.configuration.specs.exceptions import NativeValueError from dlt.common.pendulum import pendulum from dlt.common.typing import TSecretStrValue - from dlt.sources.helpers import requests if TYPE_CHECKING: @@ -144,6 +144,76 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: return request +@configspec +class OAuth2ClientCredentials(OAuth2AuthBase): + """ + This class implements OAuth2 Client Credentials flow where the autorization service + gives permission without the end user approving. + This is often used for machine-to-machine authorization. + The client sends its client ID and client secret to the authorization service which replies + with a temporary access token. + With the access token, the client can access resource services. + """ + + def __init__( + self, + access_token_url: TSecretStrValue, + client_id: TSecretStrValue, + client_secret: TSecretStrValue, + access_token_request_data: Dict[str, Any] = None, + default_token_expiration: int = 3600, + session: Annotated[BaseSession, NotResolved()] = None, + ) -> None: + super().__init__() + self.access_token_url = access_token_url + self.client_id = client_id + self.client_secret = client_secret + if access_token_request_data is None: + self.access_token_request_data = {} + else: + self.access_token_request_data = access_token_request_data + self.default_token_expiration = default_token_expiration + self.token_expiry: pendulum.DateTime = pendulum.now() + + self.session = session if session is not None else requests.client.session + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + if self.access_token is None or self.is_token_expired(): + self.obtain_token() + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + def is_token_expired(self) -> bool: + return pendulum.now() >= self.token_expiry + + def obtain_token(self) -> None: + response = self.session.post(self.access_token_url, **self.build_access_token_request()) + response.raise_for_status() + response_json = response.json() + self.parse_native_representation(self.parse_access_token(response_json)) + expires_in_seconds = self.parse_expiration_in_seconds(response_json) + self.token_expiry = pendulum.now().add(seconds=expires_in_seconds) + + def build_access_token_request(self) -> Dict[str, Any]: + return { + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "client_credentials", + **self.access_token_request_data, + }, + } + + def parse_expiration_in_seconds(self, response_json: Any) -> int: + return int(response_json.get("expires_in", self.default_token_expiration)) + + def parse_access_token(self, response_json: Any) -> str: + return str(response_json.get("access_token")) + + @configspec class OAuthJWTAuth(BearerTokenAuth): """This is a form of Bearer auth, actually there's not standard way to declare it in openAPI""" @@ -164,7 +234,7 @@ def __post_init__(self) -> None: self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) self.token = None self.token_expiry: Optional[pendulum.DateTime] = None - # use default system session is not specified + # use default system session unless specified otherwise if self.session is None: self.session = requests.client.session diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md index 98725627b9..11d09c89f7 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md @@ -416,6 +416,7 @@ Available authentication types: | [BearTokenAuth](../../general-usage/http/rest-client.md#bearer-token-authentication) | `bearer` | Bearer token authentication. | | [HTTPBasicAuth](../../general-usage/http/rest-client.md#http-basic-authentication) | `http_basic` | Basic HTTP authentication. | | [APIKeyAuth](../../general-usage/http/rest-client.md#api-key-authentication) | `api_key` | API key authentication with key defined in the query parameters or in the headers. | +| [OAuth2ClientCredentials](../../general-usage/http/rest-client.md#oauth20-authorization) | N/A | OAuth 2.0 authorization with a temporary access token obtained from the authorization server. | To specify the authentication configuration, use the `auth` field in the [client](#client) configuration: diff --git a/docs/website/docs/general-usage/http/rest-client.md b/docs/website/docs/general-usage/http/rest-client.md index 1093428b0f..3a7276a534 100644 --- a/docs/website/docs/general-usage/http/rest-client.md +++ b/docs/website/docs/general-usage/http/rest-client.md @@ -406,8 +406,11 @@ The available authentication methods are defined in the `dlt.sources.helpers.res - [BearerTokenAuth](#bearer-token-authentication) - [APIKeyAuth](#api-key-authentication) - [HttpBasicAuth](#http-basic-authentication) +- [OAuth2ClientCredentials](#oauth20-authorization) For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library. +For specific flavors of OAuth 2.0 you can [implement custom OAuth 2.0](#oauth2-authorization) +by subclassing `OAuth2ClientCredentials`. ### Bearer token authentication @@ -477,6 +480,57 @@ client = RESTClient(base_url="https://api.example.com", auth=auth) response = client.get("/protected/resource") ``` +### OAuth 2.0 authorization + +OAuth 2.0 is a common protocol for authorization. We have implemented two-legged authorization employed for server-to-server authorization because the end user (resource owner) does not need to grant approval. +The REST client acts as the OAuth client which obtains a temporary access token from the authorization server. This access token is then sent to the resource server to access protected content. If the access token is expired, the OAuth client automatically refreshes it. + +Unfortunately, most OAuth 2.0 implementations vary and thus you might need to subclass `OAuth2ClientCredentials` and implement `build_access_token_request()` to suite the requirements of the specific authorization server you want to interact with. + +**Parameters:** +- `access_token_url`: The url to obtain the temporary access token. +- `client_id`: Client credential to obtain authorization. Usually issued via a developer portal. +- `client_secret`: Client credential to obtain authorization. Usually issued via a developer portal. +- `access_token_request_data`: A dictionary with data required by the autorization server apart from the `client_id`, `client_secret`, and `"grant_type": "client_credentials"`. Defaults to `None`. +- `default_token_expiration`: The time in seconds after which the temporary access token expires. Defaults to 3600. + +**Example:** + +```py +from base64 import b64encode +from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client.auth import OAuth2ClientCredentials + +class OAuth2ClientCredentialsHTTPBasic(OAuth2ClientCredentials): + """Used e.g. by Zoom Zoom Video Communications, Inc.""" + def build_access_token_request(self) -> Dict[str, Any]: + authentication: str = b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + return { + "headers": { + "Authorization": f"Basic {authentication}", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": self.access_token_request_data, + } + +auth = OAuth2ClientCredentialsHTTPBasic( + access_token_url=dlt.secrets["sources.zoom.access_token_url"], # "https://zoom.us/oauth/token" + client_id=dlt.secrets["sources.zoom.client_id"], + client_secret=dlt.secrets["sources.zoom.client_secret"], + access_token_request_data={ + "grant_type": "account_credentials", + "account_id": dlt.secrets["sources.zoom.account_id"], + }, +) +client = RESTClient(base_url="https://api.zoom.us/v2", auth=auth) + +response = client.get("/users") +``` + + + ### Implementing custom authentication You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method: diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index 7453c63d14..08233bc3a8 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -1,8 +1,8 @@ import re -from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING, Dict, List, Any +from typing import NamedTuple, Callable, Pattern, Union, TYPE_CHECKING, Dict, List, Any import base64 -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import parse_qs, urlsplit, urlunsplit import pytest import requests_mock @@ -207,7 +207,17 @@ def protected_api_key(request, context): @router.post("/oauth/token") def oauth_token(request, context): - return {"access_token": "test-token", "expires_in": 3600} + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.post("/oauth/token-expires-now") + def oauth_token_expires_now(request, context): + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 0} + context.status_code = 401 + return {"error": "Unauthorized"} @router.post("/auth/refresh") def refresh_token(request, context): @@ -217,11 +227,36 @@ def refresh_token(request, context): context.status_code = 401 return {"error": "Invalid refresh token"} + @router.post("/custom-oauth/token") + def custom_oauth_token(request, context): + qs = parse_qs(request.text) + if ( + qs.get("grant_type")[0] == "account_credentials" + and qs.get("account_id")[0] == "test-account-id" + and request.headers["Authorization"] + == "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==" + ): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + router.register_routes(m) yield m +def oauth_authorize(request): + qs = parse_qs(request.text) + grant_type = qs.get("grant_type")[0] + if "jwt-bearer" in grant_type: + return True + if "client_credentials" in grant_type: + return ( + qs["client_secret"][0] == "test-client-secret" + and qs["client_id"][0] == "test-client-id" + ) + + def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10): assert len(pages) == total_pages for i, page in enumerate(pages): diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index bd65affe62..7196ef3436 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -1,23 +1,28 @@ import os +from base64 import b64encode +from typing import Any, Dict, cast +from unittest.mock import patch + import pytest -from typing import Any, cast -from dlt.common import logger from requests import PreparedRequest, Request, Response from requests.auth import AuthBase +from requests.exceptions import HTTPError + +from dlt.common import logger from dlt.common.typing import TSecretStrValue from dlt.sources.helpers.requests import Client from dlt.sources.helpers.rest_client import RESTClient -from dlt.sources.helpers.rest_client.client import Hooks -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator - -from dlt.sources.helpers.rest_client.auth import AuthConfigBase from dlt.sources.helpers.rest_client.auth import ( - BearerTokenAuth, APIKeyAuth, + AuthConfigBase, + BearerTokenAuth, HttpBasicAuth, + OAuth2ClientCredentials, OAuthJWTAuth, ) +from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator from .conftest import assert_pagination @@ -31,13 +36,40 @@ def load_private_key(name="private_key.pem"): TEST_PRIVATE_KEY = load_private_key() -@pytest.fixture -def rest_client() -> RESTClient: +def build_rest_client(auth=None) -> RESTClient: return RESTClient( base_url="https://api.example.com", headers={"Accept": "application/json"}, session=Client().session, + auth=auth, + ) + + +@pytest.fixture +def rest_client() -> RESTClient: + return build_rest_client() + + +@pytest.fixture +def rest_client_oauth() -> RESTClient: + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, ) + return build_rest_client(auth=auth) + + +@pytest.fixture +def rest_client_immediate_oauth_expiry(auth=None) -> RESTClient: + credentials_expiring_now = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token-expires-now"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, + ) + return build_rest_client(auth=credentials_expiring_now) @pytest.mark.usefixtures("mock_api_server") @@ -163,6 +195,114 @@ def test_api_key_auth_success(self, rest_client: RESTClient): assert response.status_code == 200 assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + def test_oauth2_client_credentials_flow_auth_success(self, rest_client_oauth: RESTClient): + response = rest_client_oauth.get("/protected/posts/bearer-token") + + assert response.status_code == 200 + assert "test-token" in response.request.headers["Authorization"] + + pages_iter = rest_client_oauth.paginate("/protected/posts/bearer-token") + + assert_pagination(list(pages_iter)) + + def test_oauth2_client_credentials_flow_wrong_client_id(self, rest_client: RESTClient): + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "invalid-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, + ) + + with pytest.raises(HTTPError) as e: + rest_client.get("/protected/posts/bearer-token", auth=auth) + assert e.type == HTTPError + assert e.match("401 Client Error") + + def test_oauth2_client_credentials_flow_wrong_client_secret(self, rest_client: RESTClient): + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "invalid-client-secret"), + session=Client().session, + ) + + with pytest.raises(HTTPError) as e: + rest_client.get( + "/protected/posts/bearer-token", + auth=auth, + ) + assert e.type == HTTPError + assert e.match("401 Client Error") + + + def test_oauth_token_expired_refresh(self, rest_client_immediate_oauth_expiry: RESTClient): + rest_client = rest_client_immediate_oauth_expiry + auth = cast(OAuth2ClientCredentials, rest_client.auth) + + with patch.object(auth, "obtain_token", wraps=auth.obtain_token) as mock_obtain_token: + assert auth.access_token is None + response = rest_client.get("/protected/posts/bearer-token") + mock_obtain_token.assert_called_once() + assert response.status_code == 200 + assert auth.access_token is not None + expiry_0 = auth.token_expiry + auth.token_expiry = auth.token_expiry.subtract(seconds=1) + expiry_1 = auth.token_expiry + assert expiry_0 > expiry_1 + assert auth.is_token_expired() + + response = rest_client.get("/protected/posts/bearer-token") + assert mock_obtain_token.call_count == 2 + assert response.status_code == 200 + expiry_2 = auth.token_expiry + assert expiry_2 > expiry_1 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + def test_oauth_customized_token_request(self, rest_client: RESTClient): + class OAuth2ClientCredentialsHTTPBasic(OAuth2ClientCredentials): + """OAuth 2.0 as required by e.g. Zoom Video Communications, Inc.""" + + def build_access_token_request(self) -> Dict[str, Any]: + authentication: str = b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + return { + "headers": { + "Authorization": f"Basic {authentication}", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "grant_type": "account_credentials", + **self.access_token_request_data, + }, + } + + auth = OAuth2ClientCredentialsHTTPBasic( + access_token_url=cast(TSecretStrValue, "https://api.example.com/custom-oauth/token"), + client_id=cast(TSecretStrValue, "test-account-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + access_token_request_data={ + "account_id": cast(TSecretStrValue, "test-account-id"), + }, + session=Client().session, + ) + + assert auth.build_access_token_request() == { + "headers": { + "Authorization": "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "grant_type": "account_credentials", + "account_id": "test-account-id", + }, + } + + rest_client.auth = auth + pages_iter = rest_client.paginate("/protected/posts/bearer-token") + + assert_pagination(list(pages_iter)) + def test_oauth_jwt_auth_success(self, rest_client: RESTClient): auth = OAuthJWTAuth( client_id="test-client-id", From 6466ce4060d8dc70c03c76812ad806426194fb9e Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 24 Jun 2024 11:31:33 +0200 Subject: [PATCH 3/4] Shorten the installation section in README (#1500) --- README.md | 10 +--------- docs/website/docs/reference/installation.md | 2 +- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ed1cc751c2..bc0f40b62f 100644 --- a/README.md +++ b/README.md @@ -30,20 +30,12 @@ Be it a Google Colab notebook, AWS Lambda function, an Airflow DAG, your local l dlt supports Python 3.8+. -**pip:** ```sh pip install dlt ``` -**pixi:** -```sh -pixi add dlt -``` +More options: [Install via Conda or Pixi](https://dlthub.com/docs/reference/installation#install-dlt-via-pixi-and-conda) -**conda:** -```sh -conda install -c conda-forge dlt -``` ## Quick Start diff --git a/docs/website/docs/reference/installation.md b/docs/website/docs/reference/installation.md index 3f40c3a545..a23ce82c97 100644 --- a/docs/website/docs/reference/installation.md +++ b/docs/website/docs/reference/installation.md @@ -110,7 +110,7 @@ You can install `dlt` in your virtual environment by running: pip install -U dlt ``` -## Install dlt via pixi and conda +## Install dlt via Pixi and Conda Install dlt using `pixi`: From 6b83ceec9dbac809f5e154c080c6d88b2218b381 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee <34739235+Pipboyguy@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:38:35 +0200 Subject: [PATCH 4/4] Add LanceDB custom destination example code (#1323) * Add LanceDB custom destination example code Signed-off-by: Marcel Coetzee * Format Signed-off-by: Marcel Coetzee * Remove Postgres credentials from example.secrets.toml Signed-off-by: Marcel Coetzee * Format Signed-off-by: Marcel Coetzee * Add typing Signed-off-by: Marcel Coetzee * Refactor code documentation and add type ignore comments Signed-off-by: Marcel Coetzee * Ignore checks Signed-off-by: Marcel Coetzee * wrap in main if statement Signed-off-by: Marcel Coetzee * Add lancedb to install dependencies in test_doc_snippets workflow Signed-off-by: Marcel Coetzee * poetry Signed-off-by: Marcel Coetzee * Update deps Signed-off-by: Marcel Coetzee * Update LanceDB version and replace Sentence-Transformers with OpenAIEmbeddings Signed-off-by: Marcel Coetzee * Poetry lock Signed-off-by: Marcel Coetzee * Format Signed-off-by: Marcel Coetzee * Update versions Signed-off-by: Marcel Coetzee * Replace OpenAI with Cohere in LanceDB custom destination example Signed-off-by: Marcel Coetzee * Format Signed-off-by: Marcel Coetzee * Add error handling to custom destination lanceDB example Signed-off-by: Marcel Coetzee * Lift config to secrets/config Signed-off-by: Marcel Coetzee * Ignore example lancedb local dir Signed-off-by: Marcel Coetzee * Why was this uncommented Signed-off-by: Marcel Coetzee * Remove unnecessary lock Signed-off-by: Marcel Coetzee * Cleanup Signed-off-by: Marcel Coetzee * Remove print statements from custom_destination_lancedb.py Signed-off-by: Marcel Coetzee * Print info Signed-off-by: Marcel Coetzee * Print info Signed-off-by: Marcel Coetzee * Use rest_client Signed-off-by: Marcel Coetzee * noqa Signed-off-by: Marcel Coetzee * Remove `cohere` dependency and add `embeddings` extra to `lancedb` Signed-off-by: Marcel Coetzee * changing secrets path for cohere to pass docs tests * fixes lock file * moves get lancedb path to run within the test * fix dependencies * fix linting * fix lancedb deps * update lock file * change source name * moved client_id to secrets * switch lancedb example to openai and small fixes * small fixes * add openai to docs deps * fix grammar gpt typing --------- Signed-off-by: Marcel Coetzee Co-authored-by: Marcin Rudolf Co-authored-by: rahuljo Co-authored-by: Dave Co-authored-by: Alena --- .../.dlt/config.toml | 2 + .../.dlt/example.secrets.toml | 7 + .../custom_destination_lancedb/.gitignore | 1 + .../custom_destination_lancedb/__init__.py | 0 .../custom_destination_lancedb.py | 155 ++++++++++++++++++ docs/tools/fix_grammar_gpt.py | 2 +- poetry.lock | 152 ++++++++++++++++- pyproject.toml | 2 + 8 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 docs/examples/custom_destination_lancedb/.dlt/config.toml create mode 100644 docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml create mode 100644 docs/examples/custom_destination_lancedb/.gitignore create mode 100644 docs/examples/custom_destination_lancedb/__init__.py create mode 100644 docs/examples/custom_destination_lancedb/custom_destination_lancedb.py diff --git a/docs/examples/custom_destination_lancedb/.dlt/config.toml b/docs/examples/custom_destination_lancedb/.dlt/config.toml new file mode 100644 index 0000000000..4fd35e1159 --- /dev/null +++ b/docs/examples/custom_destination_lancedb/.dlt/config.toml @@ -0,0 +1,2 @@ +[lancedb] +db_path = "spotify.db" \ No newline at end of file diff --git a/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml b/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml new file mode 100644 index 0000000000..9c86df320c --- /dev/null +++ b/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml @@ -0,0 +1,7 @@ +[spotify] +client_id = "" +client_secret = "" + +# provide the openai api key here +[destination.lancedb.credentials] +embedding_model_provider_api_key = "" \ No newline at end of file diff --git a/docs/examples/custom_destination_lancedb/.gitignore b/docs/examples/custom_destination_lancedb/.gitignore new file mode 100644 index 0000000000..c73564481b --- /dev/null +++ b/docs/examples/custom_destination_lancedb/.gitignore @@ -0,0 +1 @@ +spotify.db \ No newline at end of file diff --git a/docs/examples/custom_destination_lancedb/__init__.py b/docs/examples/custom_destination_lancedb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py new file mode 100644 index 0000000000..9d75d90f99 --- /dev/null +++ b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py @@ -0,0 +1,155 @@ +""" +--- +title: Custom Destination with LanceDB +description: Learn how use the custom destination to load to LanceDB. +keywords: [destination, credentials, example, lancedb, custom destination, vectorstore, AI, LLM] +--- + +This example showcases a Python script that demonstrates the integration of LanceDB, an open-source vector database, +as a custom destination within the dlt ecosystem. +The script illustrates the implementation of a custom destination as well as the population of the LanceDB vector +store with data from various sources. +This highlights the seamless interoperability between dlt and LanceDB. + +You can get a Spotify client ID and secret from https://developer.spotify.com/. + +We'll learn how to: +- Use the [custom destination](../dlt-ecosystem/destinations/destination.md) +- Delegate the embeddings to LanceDB using OpenAI Embeddings +""" + +__source_name__ = "spotify" + +import datetime # noqa: I251 +import os +from dataclasses import dataclass, fields +from pathlib import Path +from typing import Any + +import lancedb # type: ignore +from lancedb.embeddings import get_registry # type: ignore +from lancedb.pydantic import LanceModel, Vector # type: ignore + +import dlt +from dlt.common.configuration import configspec +from dlt.common.schema import TTableSchema +from dlt.common.typing import TDataItems, TSecretStrValue +from dlt.sources.helpers import requests +from dlt.sources.helpers.rest_client import RESTClient, AuthConfigBase + +# access secrets to get openai key and instantiate embedding function +openai_api_key: str = dlt.secrets.get("destination.lancedb.credentials.embedding_model_provider_api_key") +func = get_registry().get("openai").create(name="text-embedding-3-small", api_key=openai_api_key) + + +class EpisodeSchema(LanceModel): + id: str # noqa: A003 + name: str + description: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() # type: ignore[valid-type] + release_date: datetime.date + href: str + + +@dataclass(frozen=True) +class Shows: + monday_morning_data_chat: str = "3Km3lBNzJpc1nOTJUtbtMh" + latest_space_podcast: str = "2p7zZVwVF6Yk0Zsb4QmT7t" + superdatascience_podcast: str = "1n8P7ZSgfVLVJ3GegxPat1" + lex_fridman: str = "2MAi0BvDc6GTFvKFPXnkCL" + + +@configspec +class SpotifyAuth(AuthConfigBase): + client_id: str = None + client_secret: TSecretStrValue = None + + def __call__(self, request) -> Any: + if not hasattr(self, "access_token"): + self.access_token = self._get_access_token() + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + def _get_access_token(self) -> Any: + auth_url = "https://accounts.spotify.com/api/token" + auth_response = requests.post( + auth_url, + { + "grant_type": "client_credentials", + "client_id": self.client_id, + "client_secret": self.client_secret, + }, + ) + return auth_response.json()["access_token"] + + +@dlt.source +def spotify_shows( + client_id: str = dlt.secrets.value, + client_secret: str = dlt.secrets.value, +): + spotify_base_api_url = "https://api.spotify.com/v1" + client = RESTClient( + base_url=spotify_base_api_url, + auth=SpotifyAuth(client_id=client_id, client_secret=client_secret), # type: ignore[arg-type] + ) + + for show in fields(Shows): + show_name = show.name + show_id = show.default + url = f"/shows/{show_id}/episodes" + yield dlt.resource( + client.paginate(url, params={"limit": 50}), + name=show_name, + write_disposition="merge", + primary_key="id", + parallelized=True, + max_table_nesting=0, + ) + + +@dlt.destination(batch_size=250, name="lancedb") +def lancedb_destination(items: TDataItems, table: TTableSchema) -> None: + db_path = Path(dlt.config.get("lancedb.db_path")) + db = lancedb.connect(db_path) + + # since we are embedding the description field, we need to do some additional cleaning + # for openai. Openai will not accept empty strings or input with more than 8191 tokens + for item in items: + item["description"] = item.get("description") or "No Description" + item["description"] = item["description"][0:8000] + try: + tbl = db.open_table(table["name"]) + except FileNotFoundError: + tbl = db.create_table(table["name"], schema=EpisodeSchema) + tbl.add(items) + + +if __name__ == "__main__": + db_path = Path(dlt.config.get("lancedb.db_path")) + db = lancedb.connect(db_path) + + for show in fields(Shows): + db.drop_table(show.name, ignore_missing=True) + + pipeline = dlt.pipeline( + pipeline_name="spotify", + destination=lancedb_destination, + dataset_name="spotify_podcast_data", + progress="log", + ) + + load_info = pipeline.run(spotify_shows()) + load_info.raise_on_failed_jobs() + print(load_info) + + row_counts = pipeline.last_trace.last_normalize_info + print(row_counts) + + query = "French AI scientist with Lex, talking about AGI and Meta and Llama" + table_to_query = "lex_fridman" + + tbl = db.open_table(table_to_query) + + results = tbl.search(query=query).to_list() + assert results diff --git a/docs/tools/fix_grammar_gpt.py b/docs/tools/fix_grammar_gpt.py index 065b53d470..9979a92b41 100644 --- a/docs/tools/fix_grammar_gpt.py +++ b/docs/tools/fix_grammar_gpt.py @@ -120,7 +120,7 @@ def get_chunk_length(chunk: List[str]) -> int: temperature=0, ) - fixed_chunks.append(response.choices[0].message.content) + fixed_chunks.append(response.choices[0].message.content) # type: ignore with open(file_path, "w", encoding="utf-8") as f: for c in fixed_chunks: diff --git a/poetry.lock b/poetry.lock index f6a6f98c1a..5a94993c80 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2416,6 +2416,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = false +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "diff-cover" version = "7.7.0" @@ -2450,6 +2464,17 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "dnspython" version = "2.4.2" @@ -4229,6 +4254,42 @@ completion = ["shtab"] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +[[package]] +name = "lancedb" +version = "0.6.13" +description = "lancedb" +optional = false +python-versions = ">=3.8" +files = [ + {file = "lancedb-0.6.13-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:4667353ca7fa187e94cb0ca4c5f9577d65eb5160f6f3fe9e57902d86312c3869"}, + {file = "lancedb-0.6.13-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e22533fe6f6b2d7037dcdbbb4019a62402bbad4ce18395be68f4aa007bf8bc0"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837eaceafb87e3ae4c261eef45c4f73715f892a36165572c3da621dbdb45afcf"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:61af2d72b2a2f0ea419874c3f32760fe5e51530da3be2d65251a0e6ded74419b"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:31b24e57ee313f4ce6255e45d42e8bee19b90ddcd13a9e07030ac04f76e7dfde"}, + {file = "lancedb-0.6.13-cp38-abi3-win_amd64.whl", hash = "sha256:b851182d8492b1e5b57a441af64c95da65ca30b045d6618dc7d203c6d60d70fa"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +cachetools = "*" +deprecation = "*" +overrides = ">=0.7" +pydantic = ">=1.10" +pylance = "0.10.12" +ratelimiter = ">=1.0,<2.0" +requests = ">=2.31.0" +retry = ">=0.9.2" +semver = "*" +tqdm = ">=4.27.0" + +[package.extras] +azure = ["adlfs (>=2024.2.0)"] +clip = ["open-clip", "pillow", "torch"] +dev = ["pre-commit", "ruff"] +docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] +embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] +tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] + [[package]] name = "lazy-object-proxy" version = "1.9.0" @@ -5446,6 +5507,29 @@ packaging = "*" protobuf = "*" sympy = "*" +[[package]] +name = "openai" +version = "1.35.3" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.35.3-py3-none-any.whl", hash = "sha256:7b26544cef80f125431c073ffab3811d2421fbb9e30d3bd5c2436aba00b042d5"}, + {file = "openai-1.35.3.tar.gz", hash = "sha256:d6177087f150b381d49499be782d764213fdf638d391b29ca692b84dd675a389"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "openpyxl" version = "3.1.2" @@ -5659,6 +5743,17 @@ files = [ {file = "orjson-3.9.5.tar.gz", hash = "sha256:6daf5ee0b3cf530b9978cdbf71024f1c16ed4a67d05f6ec435c6e7fe7a52724c"}, ] +[[package]] +name = "overrides" +version = "7.7.0" +description = "A decorator to automatically detect mismatch when overriding a method." +optional = false +python-versions = ">=3.6" +files = [ + {file = "overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49"}, + {file = "overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a"}, +] + [[package]] name = "packaging" version = "23.1" @@ -6559,6 +6654,32 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pylance" +version = "0.10.12" +description = "python wrapper for Lance columnar format" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pylance-0.10.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:30cbcca078edeb37e11ae86cf9287d81ce6c0c07ba77239284b369a4b361497b"}, + {file = "pylance-0.10.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e558163ff6035d518706cc66848497219ccc755e2972b8f3b1706a3e1fd800fd"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75afb39f71d7f12429f9b4d380eb6cf6aed179ae5a1c5d16cc768373a1521f87"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:3de391dfc3a99bdb245fd1e27ef242be769a94853f802ef57f246e9a21358d32"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:34a5278b90f4cbcf21261353976127aa2ffbbd7d068810f0a2b0c1aa0334022a"}, + {file = "pylance-0.10.12-cp38-abi3-win_amd64.whl", hash = "sha256:6cef5975d513097fd2c22692296c9a5a138928f38d02cd34ab63a7369abc1463"}, +] + +[package.dependencies] +numpy = ">=1.22" +pyarrow = ">=12,<15.0.1" + +[package.extras] +benchmarks = ["pytest-benchmark"] +dev = ["ruff (==0.2.2)"] +ray = ["ray[data]"] +tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] +torch = ["torch"] + [[package]] name = "pymongo" version = "4.6.0" @@ -7100,6 +7221,20 @@ urllib3 = ">=1.26.14,<2.0.0" [package.extras] fastembed = ["fastembed (==0.1.1)"] +[[package]] +name = "ratelimiter" +version = "1.2.0.post0" +description = "Simple python rate limiting object" +optional = false +python-versions = "*" +files = [ + {file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"}, + {file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"}, +] + +[package.extras] +test = ["pytest (>=3.0)", "pytest-asyncio"] + [[package]] name = "redshift-connector" version = "2.0.915" @@ -7327,6 +7462,21 @@ files = [ [package.dependencies] types-setuptools = ">=57.0.0" +[[package]] +name = "retry" +version = "0.9.2" +description = "Easy to use retry decorator." +optional = false +python-versions = "*" +files = [ + {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"}, + {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"}, +] + +[package.dependencies] +decorator = ">=3.4.2" +py = ">=1.4.26,<2.0.0" + [[package]] name = "rfc3339-validator" version = "0.1.4" @@ -9088,4 +9238,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "47136cc3a6247e709dfe04a810df7309d1a2bc7fe838592dd5f58dc39c2407c8" +content-hash = "4ca5f4a7955437d6da09be909a729172b9a663cc0649227e6088dc1c2cd27e57" diff --git a/pyproject.toml b/pyproject.toml index 10e3bf47d5..b99c9e4051 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,6 +214,8 @@ pandas = ">2" alive-progress = ">=3.0.1" pyarrow = ">=14.0.0" psycopg2-binary = ">=2.9" +lancedb = ">=0.6.13" +openai = ">=1.35" [tool.black] # https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file line-length = 100