Skip to content

Commit

Permalink
Merge branch 'devel' into rfix/allows-naming-conventions
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Jun 24, 2024
2 parents ab69b76 + 6b83cee commit 0eeb21d
Show file tree
Hide file tree
Showing 16 changed files with 642 additions and 33 deletions.
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,12 @@ Be it a Google Colab notebook, AWS Lambda function, an Airflow DAG, your local l

dlt supports Python 3.8+.

**pip:**
```sh
pip install dlt
```

**pixi:**
```sh
pixi add dlt
```
More options: [Install via Conda or Pixi](https://dlthub.com/docs/reference/installation#install-dlt-via-pixi-and-conda)

**conda:**
```sh
conda install -c conda-forge dlt
```

## Quick Start

Expand Down
86 changes: 78 additions & 8 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from base64 import b64encode
import dataclasses
import math
import dataclasses
from abc import abstractmethod
from base64 import b64encode
from typing import (
List,
TYPE_CHECKING,
Any,
Dict,
Final,
Iterable,
List,
Literal,
Optional,
Union,
Any,
cast,
Iterable,
TYPE_CHECKING,
)
from typing_extensions import Annotated
from requests.auth import AuthBase
Expand All @@ -24,7 +25,6 @@
from dlt.common.configuration.specs.exceptions import NativeValueError
from dlt.common.pendulum import pendulum
from dlt.common.typing import TSecretStrValue

from dlt.sources.helpers import requests

if TYPE_CHECKING:
Expand Down Expand Up @@ -144,6 +144,76 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
return request


@configspec
class OAuth2ClientCredentials(OAuth2AuthBase):
    """
    This class implements the OAuth2 Client Credentials flow where the authorization service
    gives permission without the end user approving.
    This is often used for machine-to-machine authorization.
    The client sends its client ID and client secret to the authorization service which replies
    with a temporary access token.
    With the access token, the client can access resource services.
    """

    def __init__(
        self,
        access_token_url: TSecretStrValue,
        client_id: TSecretStrValue,
        client_secret: TSecretStrValue,
        access_token_request_data: Dict[str, Any] = None,
        default_token_expiration: int = 3600,
        session: Annotated[BaseSession, NotResolved()] = None,
    ) -> None:
        super().__init__()
        self.access_token_url = access_token_url
        self.client_id = client_id
        self.client_secret = client_secret
        # normalize the None default to a fresh dict (avoids a shared mutable default)
        if access_token_request_data is None:
            self.access_token_request_data = {}
        else:
            self.access_token_request_data = access_token_request_data
        self.default_token_expiration = default_token_expiration
        # initialized to "now" so the first __call__ treats the (absent) token as expired
        self.token_expiry: pendulum.DateTime = pendulum.now()

        # fall back to the shared dlt requests session when none is injected
        self.session = session if session is not None else requests.client.session

    def __call__(self, request: PreparedRequest) -> PreparedRequest:
        """Attach a Bearer token to the outgoing request, fetching a new one if missing or expired."""
        if self.access_token is None or self.is_token_expired():
            self.obtain_token()
        request.headers["Authorization"] = f"Bearer {self.access_token}"
        return request

    def is_token_expired(self) -> bool:
        """Return True once the locally tracked expiry timestamp has passed."""
        return pendulum.now() >= self.token_expiry

    def obtain_token(self) -> None:
        """POST the client credentials to the token endpoint and store the new token and expiry.

        Raises an HTTPError (via raise_for_status) when the token endpoint rejects the request.
        """
        response = self.session.post(self.access_token_url, **self.build_access_token_request())
        response.raise_for_status()
        response_json = response.json()
        # the base class stores the token via its native-representation parser
        self.parse_native_representation(self.parse_access_token(response_json))
        expires_in_seconds = self.parse_expiration_in_seconds(response_json)
        self.token_expiry = pendulum.now().add(seconds=expires_in_seconds)

    def build_access_token_request(self) -> Dict[str, Any]:
        """Build kwargs for the token POST: a form-encoded client-credentials grant,
        merged with any caller-supplied extra request data (extras win on key clashes)."""
        return {
            "headers": {
                "Content-Type": "application/x-www-form-urlencoded",
            },
            "data": {
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "grant_type": "client_credentials",
                **self.access_token_request_data,
            },
        }

    def parse_expiration_in_seconds(self, response_json: Any) -> int:
        """Read `expires_in` from the token response, falling back to the configured default."""
        return int(response_json.get("expires_in", self.default_token_expiration))

    def parse_access_token(self, response_json: Any) -> str:
        # NOTE(review): a response without "access_token" yields the literal string
        # "None" here instead of an error — confirm this is intended
        return str(response_json.get("access_token"))


@configspec
class OAuthJWTAuth(BearerTokenAuth):
"""This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""
Expand All @@ -164,7 +234,7 @@ def __post_init__(self) -> None:
self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
self.token = None
self.token_expiry: Optional[pendulum.DateTime] = None
# use default system session is not specified
# use default system session unless specified otherwise
if self.session is None:
self.session = requests.client.session

Expand Down
2 changes: 2 additions & 0 deletions docs/examples/custom_destination_lancedb/.dlt/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[lancedb]
db_path = "spotify.db"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[spotify]
client_id = ""
client_secret = ""

# provide the openai api key here
[destination.lancedb.credentials]
embedding_model_provider_api_key = ""
1 change: 1 addition & 0 deletions docs/examples/custom_destination_lancedb/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
spotify.db
Empty file.
155 changes: 155 additions & 0 deletions docs/examples/custom_destination_lancedb/custom_destination_lancedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""
---
title: Custom Destination with LanceDB
description: Learn how to use the custom destination to load data into LanceDB.
keywords: [destination, credentials, example, lancedb, custom destination, vectorstore, AI, LLM]
---
This example showcases a Python script that demonstrates the integration of LanceDB, an open-source vector database,
as a custom destination within the dlt ecosystem.
The script illustrates the implementation of a custom destination as well as the population of the LanceDB vector
store with data from various sources.
This highlights the seamless interoperability between dlt and LanceDB.
You can get a Spotify client ID and secret from https://developer.spotify.com/.
We'll learn how to:
- Use the [custom destination](../dlt-ecosystem/destinations/destination.md)
- Delegate the embeddings to LanceDB using OpenAI Embeddings
"""

__source_name__ = "spotify"

import datetime # noqa: I251
import os
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any

import lancedb # type: ignore
from lancedb.embeddings import get_registry # type: ignore
from lancedb.pydantic import LanceModel, Vector # type: ignore

import dlt
from dlt.common.configuration import configspec
from dlt.common.schema import TTableSchema
from dlt.common.typing import TDataItems, TSecretStrValue
from dlt.sources.helpers import requests
from dlt.sources.helpers.rest_client import RESTClient, AuthConfigBase

# access secrets to get openai key and instantiate embedding function
openai_api_key: str = dlt.secrets.get("destination.lancedb.credentials.embedding_model_provider_api_key")
func = get_registry().get("openai").create(name="text-embedding-3-small", api_key=openai_api_key)


class EpisodeSchema(LanceModel):
    """LanceDB row model for a single Spotify podcast episode.

    `description` is declared as the embedding source field and `vector` as the
    embedding target, so LanceDB computes embeddings via the module-level
    OpenAI embedding function `func` when rows are added.
    """

    id: str  # noqa: A003
    name: str
    description: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()  # type: ignore[valid-type]
    release_date: datetime.date
    href: str


@dataclass(frozen=True)
class Shows:
    """Immutable catalog of the Spotify shows to ingest.

    Each attribute name doubles as the dlt resource (and destination table)
    name; each default value is the corresponding Spotify show ID.
    """

    monday_morning_data_chat: str = "3Km3lBNzJpc1nOTJUtbtMh"
    latest_space_podcast: str = "2p7zZVwVF6Yk0Zsb4QmT7t"
    superdatascience_podcast: str = "1n8P7ZSgfVLVJ3GegxPat1"
    lex_fridman: str = "2MAi0BvDc6GTFvKFPXnkCL"


@configspec
class SpotifyAuth(AuthConfigBase):
    """Client-credentials auth for the Spotify Web API.

    Exchanges the client ID/secret for an access token on first use and
    injects it as a Bearer header on every outgoing request.

    NOTE(review): the token is cached for the lifetime of the instance and
    never refreshed on expiry — acceptable for a short-lived example run.
    """

    client_id: str = None
    client_secret: TSecretStrValue = None

    def __call__(self, request) -> Any:
        # lazily fetch the token on the first request, then reuse it
        if not hasattr(self, "access_token"):
            self.access_token = self._get_access_token()
        request.headers["Authorization"] = f"Bearer {self.access_token}"
        return request

    def _get_access_token(self) -> Any:
        """POST the client-credentials grant to Spotify's token endpoint and return the token."""
        auth_url = "https://accounts.spotify.com/api/token"
        auth_response = requests.post(
            auth_url,
            {
                "grant_type": "client_credentials",
                "client_id": self.client_id,
                "client_secret": self.client_secret,
            },
        )
        # fail with an explicit HTTP error instead of an opaque KeyError below
        # when the credentials are rejected or the endpoint is unavailable
        auth_response.raise_for_status()
        return auth_response.json()["access_token"]


@dlt.source
def spotify_shows(
    client_id: str = dlt.secrets.value,
    client_secret: str = dlt.secrets.value,
):
    """Yield one merge-disposition dlt resource per show declared in ``Shows``.

    Each resource paginates that show's episodes endpoint; resources are
    parallelized and keyed on the episode ``id``.
    """
    rest_client = RESTClient(
        base_url="https://api.spotify.com/v1",
        auth=SpotifyAuth(client_id=client_id, client_secret=client_secret),  # type: ignore[arg-type]
    )

    for show_field in fields(Shows):
        # the dataclass field default holds the Spotify show ID
        episodes_url = f"/shows/{show_field.default}/episodes"
        yield dlt.resource(
            rest_client.paginate(episodes_url, params={"limit": 50}),
            name=show_field.name,
            write_disposition="merge",
            primary_key="id",
            parallelized=True,
            max_table_nesting=0,
        )


@dlt.destination(batch_size=250, name="lancedb")
def lancedb_destination(items: TDataItems, table: TTableSchema) -> None:
    """Write a batch of items into the LanceDB table matching the dlt table name.

    The table is created with ``EpisodeSchema`` on first use, which makes
    LanceDB embed the ``description`` field via OpenAI on insert.
    """
    db = lancedb.connect(Path(dlt.config.get("lancedb.db_path")))

    # since we are embedding the description field, we need to do some additional cleaning
    # for openai. Openai will not accept empty strings or input with more than 8191 tokens
    for item in items:
        description = item.get("description") or "No Description"
        item["description"] = description[:8000]

    table_name = table["name"]
    try:
        target = db.open_table(table_name)
    except FileNotFoundError:
        # first batch for this table: create it with the embedding-aware schema
        target = db.create_table(table_name, schema=EpisodeSchema)
    target.add(items)


if __name__ == "__main__":
db_path = Path(dlt.config.get("lancedb.db_path"))
db = lancedb.connect(db_path)

for show in fields(Shows):
db.drop_table(show.name, ignore_missing=True)

pipeline = dlt.pipeline(
pipeline_name="spotify",
destination=lancedb_destination,
dataset_name="spotify_podcast_data",
progress="log",
)

load_info = pipeline.run(spotify_shows())
load_info.raise_on_failed_jobs()
print(load_info)

row_counts = pipeline.last_trace.last_normalize_info
print(row_counts)

query = "French AI scientist with Lex, talking about AGI and Meta and Llama"
table_to_query = "lex_fridman"

tbl = db.open_table(table_to_query)

results = tbl.search(query=query).to_list()
assert results
2 changes: 1 addition & 1 deletion docs/tools/fix_grammar_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_chunk_length(chunk: List[str]) -> int:
temperature=0,
)

fixed_chunks.append(response.choices[0].message.content)
fixed_chunks.append(response.choices[0].message.content) # type: ignore

with open(file_path, "w", encoding="utf-8") as f:
for c in fixed_chunks:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ Available authentication types:
| [BearerTokenAuth](../../general-usage/http/rest-client.md#bearer-token-authentication) | `bearer` | Bearer token authentication. |
| [HTTPBasicAuth](../../general-usage/http/rest-client.md#http-basic-authentication) | `http_basic` | Basic HTTP authentication. |
| [APIKeyAuth](../../general-usage/http/rest-client.md#api-key-authentication) | `api_key` | API key authentication with key defined in the query parameters or in the headers. |
| [OAuth2ClientCredentials](../../general-usage/http/rest-client.md#oauth20-authorization) | N/A | OAuth 2.0 authorization with a temporary access token obtained from the authorization server. |

To specify the authentication configuration, use the `auth` field in the [client](#client) configuration:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ pipeline = dlt.pipeline(
def _double_as_decimal_adapter(table: sa.Table) -> None:
    """Return double as double, not decimals, this is mysql thing"""
    float_columns = (
        col for col in table.columns.values() if isinstance(col.type, sa.Float)
    )
    for col in float_columns:
        col.type.asdecimal = False

sql_alchemy_source = sql_database(
Expand Down
Loading

0 comments on commit 0eeb21d

Please sign in to comment.