add external retries #19

Merged · 5 commits · Oct 7, 2024
2 changes: 1 addition & 1 deletion dbxio/__init__.py
@@ -4,4 +4,4 @@
 from dbxio.utils import *  # noqa: F403
 from dbxio.volume import *  # noqa: F403

-__version__ = '0.4.5'  # single source of truth
+__version__ = '0.4.6'  # single source of truth
7 changes: 3 additions & 4 deletions dbxio/blobs/download.py
@@ -1,16 +1,15 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional

-from tenacity import retry, stop_after_attempt, wait_fixed
-
 if TYPE_CHECKING:
+    from dbxio.core.client import DbxIOClient
     from dbxio.core.cloud.client.object_storage import ObjectStorageClient


-@retry(stop=stop_after_attempt(3), wait=wait_fixed(10))
 def download_blob_tree(
     object_storage_client: 'ObjectStorageClient',
     local_path: Path,
+    client: 'DbxIOClient',
     prefix_path: Optional[str] = None,
 ):
     for blob in object_storage_client.list_blobs(prefix=prefix_path):
@@ -32,4 +31,4 @@ def download_blob_tree(
             Path(local_path / relative_blob_path).mkdir(parents=True, exist_ok=True)
             continue

-        object_storage_client.download_blob_to_file(blob.name, local_path / relative_blob_path)
+        client.retrying(object_storage_client.download_blob_to_file, blob.name, local_path / relative_blob_path)
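The pattern this PR applies throughout: instead of baking a @retry decorator into each function, the caller holds a tenacity Retrying object and applies it at the call site. A minimal, self-contained sketch of that mechanic (the download function and the policy values below are illustrative, not dbxio's):

from tenacity import Retrying, stop_after_attempt, wait_fixed

attempts = 0

def download(name: str) -> str:
    # Stand-in for a flaky network call that succeeds on the third try.
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError('transient failure')
    return f'downloaded {name}'

# A Retrying instance is callable: retrying(fn, *args, **kwargs) runs fn
# under the configured stop/wait policy and re-raises the last error.
retrying = Retrying(stop=stop_after_attempt(3), wait=wait_fixed(0.1), reraise=True)
print(retrying(download, 'some/blob.parquet'))

Because the policy lives on an object rather than on the function, the same code path can run under different retry budgets depending on which client invokes it.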
5 changes: 4 additions & 1 deletion dbxio/blobs/parquet.py
@@ -8,6 +8,8 @@
 from dbxio.sql.types import convert_dbxio_type_to_pa_type

 if TYPE_CHECKING:
+    from tenacity import Retrying
+
     from dbxio.core.cloud.client.object_storage import ObjectStorageClient
     from dbxio.delta.table import Table
     from dbxio.delta.table_schema import TableSchema
@@ -42,14 +44,15 @@ def create_tmp_parquet(
     data: bytes,
     table_identifier: Union[str, 'Table'],
     object_storage_client: 'ObjectStorageClient',
+    retrying: 'Retrying',
 ) -> Iterator[str]:
     random_part = uuid.uuid4()
     ti = table_identifier if isinstance(table_identifier, str) else table_identifier.table_identifier
     translated_table_identifier = ti.translate(
         str.maketrans('.!"#$%&\'()*+,/:;<=>?@[\\]^`{|}~', '______________________________')
     )
     tmp_path = f'{translated_table_identifier}__dbxio_tmp__{random_part}.parquet'
-    object_storage_client.upload_blob(tmp_path, data, overwrite=True)
+    retrying(object_storage_client.upload_blob, tmp_path, data, overwrite=True)
     try:
         yield tmp_path
     finally:
3 changes: 2 additions & 1 deletion dbxio/core/__init__.py
@@ -8,7 +8,7 @@
     DefaultCredentialProvider,
 )
 from dbxio.core.exceptions import DbxIOTypeError, InsufficientCredentialsError, ReadDataError, UnavailableAuthError
-from dbxio.core.settings import CloudProvider, Settings
+from dbxio.core.settings import CloudProvider, RetryConfig, Settings

 __all__ = [
     'get_token',
@@ -28,4 +28,5 @@
     'ReadDataError',
     'Settings',
     'CloudProvider',
+    'RetryConfig',
 ]
16 changes: 13 additions & 3 deletions dbxio/core/client.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import attrs
 from databricks.sdk import StatementExecutionAPI, WorkspaceClient
@@ -11,6 +11,10 @@
 from dbxio.sql.sql_driver import SQLDriver, get_sql_driver
 from dbxio.utils.databricks import ClusterType
 from dbxio.utils.logging import get_logger
+from dbxio.utils.retries import build_retrying
+
+if TYPE_CHECKING:
+    from tenacity import Retrying

 logger = get_logger()
@@ -106,12 +110,17 @@ def workspace_api(self) -> WorkspaceClient:
     def statement_api(self) -> StatementExecutionAPI:
         return self.workspace_api.statement_execution

+    @property
+    def retrying(self) -> 'Retrying':
+        return build_retrying(self.settings.retry_config)
+
     @property
     def _sql_driver(self) -> SQLDriver:
         return get_sql_driver(
             cluster_type=self.credential_provider.cluster_type,
             cluster_credentials=self._cluster_credentials,
             statement_api=self.statement_api,
+            retrying=self.retrying,
             session_configuration=self.session_configuration,
         )
@@ -134,8 +143,7 @@ def sql_to_files(
         Execute the SQL query and save the results to the specified directory.
         Returns the path to the directory with the results including the statement ID.
         """
-
-        return self._sql_driver.sql_to_files(query, results_path, max_concurrency)
+        return self.retrying(self._sql_driver.sql_to_files, query, results_path, max_concurrency)


 class DefaultDbxIOClient(DbxIOClient):
@@ -146,6 +154,7 @@ class DefaultDbxIOClient(DbxIOClient):
     """

     def __init__(self, session_configuration: Optional[Dict[str, Any]] = None):
+        logger.info('Creating a default client for all-purpose clusters')
         super().__init__(
             credential_provider=DefaultCredentialProvider(cluster_type=ClusterType.ALL_PURPOSE),
             session_configuration=session_configuration,
@@ -161,6 +170,7 @@ class DefaultSqlDbxIOClient(DbxIOClient):
     """

     def __init__(self, session_configuration: Optional[Dict[str, Any]] = None):
+        logger.info('Creating a default client for SQL warehouses')
         super().__init__(
             credential_provider=DefaultCredentialProvider(cluster_type=ClusterType.SQL_WAREHOUSE),
             session_configuration=session_configuration,
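Note that retrying is a property: every access builds a fresh Retrying from the client's current settings, so configuration changes take effect immediately and no retry object is shared between unrelated calls. A toy sketch of the same shape (ToyClient and its field are stand-ins, not dbxio API):

from dataclasses import dataclass

from tenacity import Retrying, stop_after_attempt

@dataclass
class ToyClient:
    max_attempts: int = 7

    @property
    def retrying(self) -> Retrying:
        # Rebuilt on each access, mirroring DbxIOClient.retrying above.
        return Retrying(stop=stop_after_attempt(self.max_attempts), reraise=True)

client = ToyClient(max_attempts=3)
client.retrying(print, 'each call gets an independent retry budget')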
19 changes: 19 additions & 0 deletions dbxio/core/settings.py
@@ -1,4 +1,5 @@
 import os
+from typing import Type, Union

 import attrs

@@ -16,10 +17,28 @@ def _cloud_provider_factory() -> CloudProvider:
     return CloudProvider(os.getenv(CLOUD_PROVIDER_ENV_VAR, _DEFAULT_CLOUD_PROVIDER).lower())


+@attrs.frozen
+class RetryConfig:
+    max_attempts: int = attrs.field(default=7, validator=[attrs.validators.instance_of(int), attrs.validators.ge(1)])
+    exponential_backoff_multiplier: Union[float, int] = attrs.field(
+        default=1.0,
+        validator=[attrs.validators.instance_of((float, int)), attrs.validators.ge(0)],
+    )
+    extra_exceptions_to_retry: tuple[Type[BaseException]] = attrs.field(
+        factory=tuple,
+        validator=attrs.validators.deep_iterable(
+            member_validator=attrs.validators.instance_of(type),
+            iterable_validator=attrs.validators.instance_of(tuple),
+        ),
+    )
+
+
 @attrs.define
 class Settings:
     cloud_provider: CloudProvider = attrs.field(factory=_cloud_provider_factory)

+    retry_config: RetryConfig = attrs.field(factory=RetryConfig)
+
     @cloud_provider.validator
     def _validate_cloud_provider(self, attribute, value):
         if not isinstance(value, CloudProvider):
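build_retrying itself lives in dbxio/utils/retries.py, which this diff does not show. Assuming it maps the three RetryConfig fields onto tenacity's stop, wait, and retry predicates, a plausible sketch looks like this (the default set of exceptions dbxio retries is a guess here):

from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_exponential

from dbxio.core import RetryConfig  # exported from dbxio.core by this PR

def build_retrying_sketch(config: RetryConfig) -> Retrying:
    # Assumed defaults; the real list is defined in dbxio.utils.retries.
    retryable = (ConnectionError, TimeoutError, *config.extra_exceptions_to_retry)
    return Retrying(
        stop=stop_after_attempt(config.max_attempts),
        wait=wait_exponential(multiplier=config.exponential_backoff_multiplier),
        retry=retry_if_exception_type(retryable),
        reraise=True,
    )

retrying = build_retrying_sketch(RetryConfig(max_attempts=3))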
31 changes: 10 additions & 21 deletions dbxio/delta/table_commands.py
@@ -17,15 +17,13 @@
 from dbxio.sql.results import _FutureBaseResult
 from dbxio.utils.blobs import blobs_registries
 from dbxio.utils.logging import get_logger
-from dbxio.utils.retries import dbxio_retry

 if TYPE_CHECKING:
     from dbxio.core import DbxIOClient

 logger = get_logger()


-@dbxio_retry
 def exists_table(table: Union[str, Table], client: 'DbxIOClient') -> bool:
     """
     Checks if table exists in the catalog. Tries to read one record from the table.
@@ -39,7 +37,6 @@ def exists_table(table: Union[str, Table], client: 'DbxIOClient') -> bool:
     return False


-@dbxio_retry
 def create_table(table: Union[str, Table], client: 'DbxIOClient') -> _FutureBaseResult:
     """
     Creates a table in the catalog.
@@ -61,7 +58,6 @@ def create_table(table: Union[str, Table], client: 'DbxIOClient') -> _FutureBaseResult:
     return client.sql(query)


-@dbxio_retry
 def drop_table(table: Union[str, Table], client: 'DbxIOClient', force: bool = False) -> _FutureBaseResult:
     """
     Drops a table from the catalog.
@@ -75,7 +71,6 @@ def drop_table(table: Union[str, Table], client: 'DbxIOClient', force: bool = False) -> _FutureBaseResult:
     return client.sql(drop_sql)


-@dbxio_retry
 def read_table(
     table: Union[str, Table],
     client: 'DbxIOClient',
@@ -108,7 +103,6 @@ def read_table(
         yield record


-@dbxio_retry
 def save_table_to_files(
     table: Union[str, Table],
     client: 'DbxIOClient',
@@ -128,7 +122,6 @@ def save_table_to_files(
     return client.sql_to_files(sql_read_query, results_path=results_path, max_concurrency=max_concurrency)


-@dbxio_retry
 def write_table(
     table: Union[str, Table],
     new_records: Union[Iterator[Dict], List[Dict]],
@@ -180,7 +173,6 @@ def write_table(
     return client.sql(_sql_query)


-@dbxio_retry
 def copy_into_table(
     client: 'DbxIOClient',
     table: Table,
@@ -208,7 +200,6 @@ def copy_into_table(
     client.sql(sql_copy_into_query).wait()


-@dbxio_retry
 def bulk_write_table(
     table: Union[str, Table],
     new_records: Union[Iterator[Dict], List[Dict]],
@@ -243,7 +234,12 @@ def bulk_write_table(
         credential_provider=client.credential_provider.az_cred_provider,
     )

-    with create_tmp_parquet(pa_table_as_bytes, dbxio_table.table_identifier, object_storage) as tmp_path:
+    with create_tmp_parquet(
+        pa_table_as_bytes,
+        dbxio_table.table_identifier,
+        object_storage,
+        retrying=client.retrying,
+    ) as tmp_path:
         if not append:
             drop_table(dbxio_table, client=client, force=True).wait()
             create_table(dbxio_table, client=client).wait()
@@ -258,7 +254,6 @@ def bulk_write_table(
     )


-@dbxio_retry
 def bulk_write_local_files(
     table: Table,
     path: str,
@@ -285,10 +280,11 @@ def bulk_write_local_files(
         container_name=abs_container_name,
         credential_provider=client.credential_provider.az_cred_provider,
     )
-    with blobs_registries(object_storage_client=object_storage) as (blobs, metablobs):
+    with blobs_registries(object_storage_client=object_storage, retrying=client.retrying) as (blobs, metablobs):
         for filename in files:
-            upload_file(
-                filename,  # type: ignore
+            client.retrying(
+                upload_file,
+                filename,
                 p,
                 object_storage_client=object_storage,
                 prefix_blob_path=operation_uuid,
@@ -315,7 +311,6 @@
     )


-@dbxio_retry
 def merge_table(
     table: 'Union[str , Table]',
     new_records: 'Union[Iterator[Dict] , List[Dict]]',
@@ -354,7 +349,6 @@
     drop_table(tmp_table, client=client, force=True).wait()


-@dbxio_retry
 def set_comment_on_table(
     table: 'Union[str , Table]',
     comment: Union[str, None],
@@ -374,15 +368,13 @@ def set_comment_on_table(
     return client.sql(set_comment_query)


-@dbxio_retry
 def unset_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> _FutureBaseResult:
     """
     Unsets the comment on a table.
     """
     return set_comment_on_table(table=table, comment=None, client=client)


-@dbxio_retry
 def get_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> Union[str, None]:
     """
     Returns the comment on a table.
@@ -405,7 +397,6 @@ def get_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> Union[str, None]:
     return None


-@dbxio_retry
 def set_tags_on_table(table: 'Union[str , Table]', tags: dict[str, str], client: 'DbxIOClient') -> _FutureBaseResult:
     """
     Sets tags on a table.
@@ -422,7 +413,6 @@ def set_tags_on_table(table: 'Union[str , Table]', tags: dict[str, str], client: 'DbxIOClient') -> _FutureBaseResult:
     return client.sql(set_tags_query)


-@dbxio_retry
 def unset_tags_on_table(table: 'Union[str , Table]', tags: list[str], client: 'DbxIOClient') -> _FutureBaseResult:
     """
     Unsets tags on a table.
@@ -438,7 +428,6 @@ def unset_tags_on_table(table: 'Union[str , Table]', tags: list[str], client: 'DbxIOClient') -> _FutureBaseResult:
     return client.sql(unset_tags_query)


-@dbxio_retry
 def get_tags_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> dict[str, str]:
     """
     Returns the tags on a table.
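Putting it together: retry behavior is now configured once on the client, and every table command inherits it instead of carrying its own decorator. An end-to-end sketch; whether DbxIOClient accepts a settings keyword is an assumption, since the constructor is not part of this diff:

from dbxio.core import DbxIOClient, RetryConfig, Settings
from dbxio.delta.table_commands import exists_table

client = DbxIOClient(
    credential_provider=...,  # elided; e.g. a DefaultCredentialProvider as in the defaults above
    settings=Settings(retry_config=RetryConfig(max_attempts=3)),  # assumed keyword
)

# No @dbxio_retry on the function any more: retries come from client.retrying.
exists_table('catalog.schema.some_table', client=client)  # hypothetical table identifier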