Enable S3 compatible storage for delta table format (#1586)
* handle credentials for s3 compatible storage

* fix delta table test

* add missing filesystem driver in skip message
jorritsandbrink authored Jul 18, 2024
1 parent 080478a commit 9db6650
Showing 4 changed files with 108 additions and 59 deletions.
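
Taken together, the changes below let the `delta` table format write to any S3-compatible object store, not only AWS S3. As a rough illustration (not part of the commit), and assuming dlt's standard double-underscore environment-variable layout, pointing the filesystem destination at a local MinIO server could look like this; the endpoint, bucket, and keys are placeholders:

```python
import os

# Hypothetical configuration for an S3-compatible store (MinIO); all values are placeholders.
# dlt resolves these into the AWS credentials spec changed below, whose endpoint_url now
# flows through to the `object_store` Rust crate used by delta-rs.
os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://dlt-ci-test-bucket"
os.environ["DESTINATION__FILESYSTEM__CREDENTIALS__AWS_ACCESS_KEY_ID"] = "minioadmin"
os.environ["DESTINATION__FILESYSTEM__CREDENTIALS__AWS_SECRET_ACCESS_KEY"] = "minioadmin"
# no region is required once an endpoint_url is given; plain-http endpoints also set aws_allow_http
os.environ["DESTINATION__FILESYSTEM__CREDENTIALS__ENDPOINT_URL"] = "http://localhost:9000"
```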
36 changes: 29 additions & 7 deletions dlt/common/configuration/specs/aws_credentials.py
@@ -1,13 +1,17 @@
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, cast

from dlt.common.utils import without_none
from dlt.common.exceptions import MissingDependencyException
from dlt.common.typing import TSecretStrValue, DictStrAny
from dlt.common.configuration.specs import (
CredentialsConfiguration,
CredentialsWithDefault,
configspec,
)
from dlt.common.configuration.specs.exceptions import InvalidBoto3Session
from dlt.common.configuration.specs.exceptions import (
InvalidBoto3Session,
ObjectStoreRsCredentialsException,
)
from dlt import version


@@ -47,11 +51,29 @@ def to_session_credentials(self) -> Dict[str, str]:

def to_object_store_rs_credentials(self) -> Dict[str, str]:
# https://docs.rs/object_store/latest/object_store/aws
assert self.region_name is not None, "`object_store` Rust crate requires AWS region."
creds = self.to_session_credentials()
if creds["aws_session_token"] is None:
creds.pop("aws_session_token")
return {**creds, **{"region": self.region_name}}
creds = cast(
Dict[str, str],
without_none(
dict(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
region=self.region_name,
endpoint_url=self.endpoint_url,
)
),
)

if "endpoint_url" not in creds: # AWS S3
if "region" not in creds:
raise ObjectStoreRsCredentialsException(
"`object_store` Rust crate requires AWS region when using AWS S3."
)
else: # S3-compatible, e.g. MinIO
if self.endpoint_url.startswith("http://"):
creds["aws_allow_http"] = "true"

return creds


@configspec
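A minimal sketch (not from the commit) of what the rewritten `to_object_store_rs_credentials()` returns in its two branches; keys, region, and endpoint are placeholder values:

```python
from dlt.common.configuration.specs.aws_credentials import AwsCredentialsWithoutDefaults

# AWS S3 branch: no endpoint_url, so a region is mandatory and ends up under "region".
s3_creds = AwsCredentialsWithoutDefaults(
    aws_access_key_id="AKIA...",          # placeholder
    aws_secret_access_key="...",          # placeholder
    region_name="eu-central-1",           # placeholder
).to_object_store_rs_credentials()
assert s3_creds["region"] == "eu-central-1"
assert "aws_session_token" not in s3_creds      # None values are dropped by without_none()

# S3-compatible branch: endpoint_url replaces the region requirement; plain-http
# endpoints additionally switch on aws_allow_http for the `object_store` crate.
minio_creds = AwsCredentialsWithoutDefaults(
    aws_access_key_id="minioadmin",       # placeholder
    aws_secret_access_key="minioadmin",   # placeholder
    endpoint_url="http://localhost:9000", # placeholder
).to_object_store_rs_credentials()
assert minio_creds["aws_allow_http"] == "true"
```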
4 changes: 4 additions & 0 deletions dlt/common/configuration/specs/exceptions.py
@@ -68,3 +68,7 @@ def __init__(self, spec: Type[Any], native_value: Any):
" containing credentials"
)
super().__init__(spec, native_value, msg)


class ObjectStoreRsCredentialsException(ConfigurationException):
pass
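
The new exception gives the missing-region case a typed error instead of a bare assertion. A small sketch of the failure mode, mirroring the test added below (key values are placeholders):

```python
from dlt.common.configuration.specs.aws_credentials import AwsCredentials
from dlt.common.configuration.specs.exceptions import ObjectStoreRsCredentialsException

try:
    # neither endpoint_url nor region_name is set, so the conversion cannot succeed
    AwsCredentials(
        aws_access_key_id="AKIA...",      # placeholder
        aws_secret_access_key="...",      # placeholder
    ).to_object_store_rs_credentials()
except ObjectStoreRsCredentialsException as exc:
    print(f"cannot build `object_store` credentials: {exc}")
```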
94 changes: 63 additions & 31 deletions tests/load/filesystem/test_object_store_rs_credentials.py
@@ -18,12 +18,19 @@
GcpServiceAccountCredentialsWithoutDefaults,
GcpOAuthCredentialsWithoutDefaults,
)
from dlt.common.configuration.specs.exceptions import ObjectStoreRsCredentialsException

from tests.load.utils import (
AZ_BUCKET,
AWS_BUCKET,
GCS_BUCKET,
R2_BUCKET_CONFIG,
ALL_FILESYSTEM_DRIVERS,
)

from tests.load.utils import AZ_BUCKET, AWS_BUCKET, GCS_BUCKET, ALL_FILESYSTEM_DRIVERS

if all(driver not in ALL_FILESYSTEM_DRIVERS for driver in ("az", "s3", "gs")):
if all(driver not in ALL_FILESYSTEM_DRIVERS for driver in ("az", "s3", "gs", "r2")):
pytest.skip(
"Requires at least one of `az`, `s3`, `gs` in `ALL_FILESYSTEM_DRIVERS`.",
"Requires at least one of `az`, `s3`, `gs`, `r2` in `ALL_FILESYSTEM_DRIVERS`.",
allow_module_level=True,
)

@@ -53,10 +60,10 @@ def can_connect(bucket_url: str, object_store_rs_credentials: Dict[str, str]) ->
return False


@pytest.mark.skipif(
"az" not in ALL_FILESYSTEM_DRIVERS, reason="`az` not in `ALL_FILESYSTEM_DRIVERS`"
@pytest.mark.parametrize(
"driver", [driver for driver in ALL_FILESYSTEM_DRIVERS if driver in ("az")]
)
def test_azure_object_store_rs_credentials() -> None:
def test_azure_object_store_rs_credentials(driver: str) -> None:
creds: AnyAzureCredentials

creds = AzureServicePrincipalCredentialsWithoutDefaults(
@@ -78,66 +85,91 @@ def test_azure_object_store_rs_credentials() -> None:
assert can_connect(AZ_BUCKET, creds.to_object_store_rs_credentials())


@pytest.mark.skipif(
"s3" not in ALL_FILESYSTEM_DRIVERS, reason="`s3` not in `ALL_FILESYSTEM_DRIVERS`"
@pytest.mark.parametrize(
"driver", [driver for driver in ALL_FILESYSTEM_DRIVERS if driver in ("s3", "r2")]
)
def test_aws_object_store_rs_credentials() -> None:
def test_aws_object_store_rs_credentials(driver: str) -> None:
creds: AwsCredentialsWithoutDefaults

fs_creds = FS_CREDS
if driver == "r2":
fs_creds = R2_BUCKET_CONFIG["credentials"] # type: ignore[assignment]

# AwsCredentialsWithoutDefaults: no user-provided session token
creds = AwsCredentialsWithoutDefaults(
aws_access_key_id=fs_creds["aws_access_key_id"],
aws_secret_access_key=fs_creds["aws_secret_access_key"],
region_name=fs_creds.get("region_name"),
endpoint_url=fs_creds.get("endpoint_url"),
)
assert creds.aws_session_token is None
object_store_rs_creds = creds.to_object_store_rs_credentials()
assert "aws_session_token" not in object_store_rs_creds # no auto-generated token
assert can_connect(AWS_BUCKET, object_store_rs_creds)

# AwsCredentials: no user-provided session token
creds = AwsCredentials(
aws_access_key_id=FS_CREDS["aws_access_key_id"],
aws_secret_access_key=FS_CREDS["aws_secret_access_key"],
# region_name must be configured in order for data lake to work
region_name=FS_CREDS["region_name"],
aws_access_key_id=fs_creds["aws_access_key_id"],
aws_secret_access_key=fs_creds["aws_secret_access_key"],
region_name=fs_creds.get("region_name"),
endpoint_url=fs_creds.get("endpoint_url"),
)
assert creds.aws_session_token is None
object_store_rs_creds = creds.to_object_store_rs_credentials()
assert object_store_rs_creds["aws_session_token"] is not None # auto-generated token
assert "aws_session_token" not in object_store_rs_creds # no auto-generated token
assert can_connect(AWS_BUCKET, object_store_rs_creds)

# exception should be raised if both `endpoint_url` and `region_name` are
# not provided
with pytest.raises(ObjectStoreRsCredentialsException):
AwsCredentials(
aws_access_key_id=fs_creds["aws_access_key_id"],
aws_secret_access_key=fs_creds["aws_secret_access_key"],
).to_object_store_rs_credentials()

if "endpoint_url" in object_store_rs_creds:
# TODO: make sure this case is tested on GitHub CI, e.g. by adding
# a local MinIO bucket to the set of tested buckets
if object_store_rs_creds["endpoint_url"].startswith("http://"):
assert object_store_rs_creds["aws_allow_http"] == "true"

# remainder of tests use session tokens
# we don't run them on S3 compatible storage because session tokens
# may not be available
return

# AwsCredentials: user-provided session token
# use previous credentials to create session token for new credentials
assert isinstance(creds, AwsCredentials)
sess_creds = creds.to_session_credentials()
creds = AwsCredentials(
aws_access_key_id=sess_creds["aws_access_key_id"],
aws_secret_access_key=cast(TSecretStrValue, sess_creds["aws_secret_access_key"]),
aws_session_token=cast(TSecretStrValue, sess_creds["aws_session_token"]),
region_name=FS_CREDS["region_name"],
region_name=fs_creds["region_name"],
)
assert creds.aws_session_token is not None
object_store_rs_creds = creds.to_object_store_rs_credentials()
assert object_store_rs_creds["aws_session_token"] is not None
assert can_connect(AWS_BUCKET, object_store_rs_creds)

# AwsCredentialsWithoutDefaults: no user-provided session token
creds = AwsCredentialsWithoutDefaults(
aws_access_key_id=FS_CREDS["aws_access_key_id"],
aws_secret_access_key=FS_CREDS["aws_secret_access_key"],
region_name=FS_CREDS["region_name"],
)
assert creds.aws_session_token is None
object_store_rs_creds = creds.to_object_store_rs_credentials()
assert "aws_session_token" not in object_store_rs_creds # no auto-generated token
assert can_connect(AWS_BUCKET, object_store_rs_creds)

# AwsCredentialsWithoutDefaults: user-provided session token
creds = AwsCredentialsWithoutDefaults(
aws_access_key_id=sess_creds["aws_access_key_id"],
aws_secret_access_key=cast(TSecretStrValue, sess_creds["aws_secret_access_key"]),
aws_session_token=cast(TSecretStrValue, sess_creds["aws_session_token"]),
region_name=FS_CREDS["region_name"],
region_name=fs_creds["region_name"],
)
assert creds.aws_session_token is not None
object_store_rs_creds = creds.to_object_store_rs_credentials()
assert object_store_rs_creds["aws_session_token"] is not None
assert can_connect(AWS_BUCKET, object_store_rs_creds)


@pytest.mark.skipif(
"gs" not in ALL_FILESYSTEM_DRIVERS, reason="`gs` not in `ALL_FILESYSTEM_DRIVERS`"
@pytest.mark.parametrize(
"driver", [driver for driver in ALL_FILESYSTEM_DRIVERS if driver in ("gs")]
)
def test_gcp_object_store_rs_credentials() -> None:
def test_gcp_object_store_rs_credentials(driver) -> None:
creds = GcpServiceAccountCredentialsWithoutDefaults(
project_id=FS_CREDS["project_id"],
private_key=FS_CREDS["private_key"],
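For the `r2` driver, the test above pulls its keys from `R2_BUCKET_CONFIG["credentials"]`, which lives in the test secrets rather than in this commit. A hypothetical shape of that mapping, consistent with the keys the test reads (`aws_access_key_id`, `aws_secret_access_key`, `region_name`, `endpoint_url`), might be:

```python
# Hypothetical illustration only; account id and keys are placeholders, not real values.
R2_BUCKET_CONFIG = {
    "credentials": {
        "aws_access_key_id": "<r2-access-key-id>",
        "aws_secret_access_key": "<r2-secret-access-key>",
        "endpoint_url": "https://<account-id>.r2.cloudflarestorage.com",
        # no "region_name": with an endpoint_url set, the S3-compatible branch applies
    },
}
```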
33 changes: 12 additions & 21 deletions tests/load/pipeline/test_filesystem_pipeline.py
@@ -224,10 +224,7 @@ def some_source():


@pytest.mark.essential
def test_delta_table_core(
default_buckets_env: str,
local_filesystem_pipeline: dlt.Pipeline,
) -> None:
def test_delta_table_core(default_buckets_env: str) -> None:
"""Tests core functionality for `delta` table format.
Tests all data types, all filesystems, all write dispositions.
@@ -253,8 +250,10 @@ def data_types():
nonlocal row
yield [row] * 10

pipeline = dlt.pipeline(pipeline_name="fs_pipe", destination="filesystem", dev_mode=True)

# run pipeline, this should create Delta table
info = local_filesystem_pipeline.run(data_types())
info = pipeline.run(data_types())
assert_load_info(info)

# `delta` table format should use `parquet` file format
@@ -266,37 +265,29 @@

# 10 rows should be loaded to the Delta table and the content of the first
# row should match expected values
rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[
"data_types"
]
rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"]
assert len(rows) == 10
assert_all_data_types_row(rows[0], schema=column_schemas)

# another run should append rows to the table
info = local_filesystem_pipeline.run(data_types())
info = pipeline.run(data_types())
assert_load_info(info)
rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[
"data_types"
]
rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"]
assert len(rows) == 20

# ensure "replace" write disposition is handled
# should do logical replace, increasing the table version
info = local_filesystem_pipeline.run(data_types(), write_disposition="replace")
info = pipeline.run(data_types(), write_disposition="replace")
assert_load_info(info)
client = cast(FilesystemClient, local_filesystem_pipeline.destination_client())
client = cast(FilesystemClient, pipeline.destination_client())
assert _get_delta_table(client, "data_types").version() == 2
rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[
"data_types"
]
rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"]
assert len(rows) == 10

# `merge` resolves to `append` behavior
info = local_filesystem_pipeline.run(data_types(), write_disposition="merge")
info = pipeline.run(data_types(), write_disposition="merge")
assert_load_info(info)
rows = load_tables_to_dicts(local_filesystem_pipeline, "data_types", exclude_system_cols=True)[
"data_types"
]
rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"]
assert len(rows) == 20


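The reworked test drives everything through a locally created pipeline instead of a fixture. A compressed sketch of the pattern it exercises (not runnable as-is against a real bucket, and assuming a resource that opts into the format via `table_format="delta"`; the actual `data_types()` resource definition is collapsed in this diff):

```python
import dlt

@dlt.resource(table_format="delta")   # assumed decoration; the real resource also sets column schemas
def data_types():
    yield [{"id": i} for i in range(10)]

pipeline = dlt.pipeline(pipeline_name="fs_pipe", destination="filesystem", dev_mode=True)

pipeline.run(data_types())                                 # initial load -> Delta table version 0
pipeline.run(data_types())                                 # append -> version 1
pipeline.run(data_types(), write_disposition="replace")    # logical replace -> version 2
pipeline.run(data_types(), write_disposition="merge")      # merge resolves to append behavior
```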
