From a02bfc117300ba4b2c233c29791659910331ddff Mon Sep 17 00:00:00 2001 From: Nikita Yurasov Date: Mon, 7 Oct 2024 16:00:23 +0200 Subject: [PATCH 1/5] add external retrying --- dbxio/blobs/download.py | 7 ++--- dbxio/blobs/parquet.py | 5 +++- dbxio/core/client.py | 16 +++++++++-- dbxio/core/settings.py | 18 ++++++++++++ dbxio/delta/table_commands.py | 29 ++++++------------- dbxio/sql/sql_driver.py | 21 +++++++++++--- dbxio/utils/__init__.py | 4 +-- dbxio/utils/blobs.py | 7 +++-- dbxio/utils/retries.py | 28 ++++++++++++------ dbxio/volume/volume_commands.py | 46 +++++++++++------------------- tests/conftest.py | 7 +++++ tests/test_odbc_driver.py | 12 +++++--- tests/test_statement_api_driver.py | 6 ++-- tests/test_utils_retries.py | 37 ++++++++++++------------ tests/test_volume_commands.py | 4 +-- 15 files changed, 146 insertions(+), 101 deletions(-) create mode 100644 tests/conftest.py diff --git a/dbxio/blobs/download.py b/dbxio/blobs/download.py index 8e575d4..7214e70 100644 --- a/dbxio/blobs/download.py +++ b/dbxio/blobs/download.py @@ -1,16 +1,15 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from tenacity import retry, stop_after_attempt, wait_fixed - if TYPE_CHECKING: + from dbxio.core.client import DbxIOClient from dbxio.core.cloud.client.object_storage import ObjectStorageClient -@retry(stop=stop_after_attempt(3), wait=wait_fixed(10)) def download_blob_tree( object_storage_client: 'ObjectStorageClient', local_path: Path, + client: 'DbxIOClient', prefix_path: Optional[str] = None, ): for blob in object_storage_client.list_blobs(prefix=prefix_path): @@ -32,4 +31,4 @@ def download_blob_tree( Path(local_path / relative_blob_path).mkdir(parents=True, exist_ok=True) continue - object_storage_client.download_blob_to_file(blob.name, local_path / relative_blob_path) + client.retrying(object_storage_client.download_blob_to_file, blob.name, local_path / relative_blob_path) diff --git a/dbxio/blobs/parquet.py b/dbxio/blobs/parquet.py index 
b557e61..631cfa7 100644 --- a/dbxio/blobs/parquet.py +++ b/dbxio/blobs/parquet.py @@ -8,6 +8,8 @@ from dbxio.sql.types import convert_dbxio_type_to_pa_type if TYPE_CHECKING: + from tenacity import Retrying + from dbxio.core.cloud.client.object_storage import ObjectStorageClient from dbxio.delta.table import Table from dbxio.delta.table_schema import TableSchema @@ -42,6 +44,7 @@ def create_tmp_parquet( data: bytes, table_identifier: Union[str, 'Table'], object_storage_client: 'ObjectStorageClient', + retrying: 'Retrying', ) -> Iterator[str]: random_part = uuid.uuid4() ti = table_identifier if isinstance(table_identifier, str) else table_identifier.table_identifier @@ -49,7 +52,7 @@ def create_tmp_parquet( str.maketrans('.!"#$%&\'()*+,/:;<=>?@[\\]^`{|}~', '______________________________') ) tmp_path = f'{translated_table_identifier}__dbxio_tmp__{random_part}.parquet' - object_storage_client.upload_blob(tmp_path, data, overwrite=True) + retrying(object_storage_client.upload_blob, tmp_path, data, overwrite=True) try: yield tmp_path finally: diff --git a/dbxio/core/client.py b/dbxio/core/client.py index 24ec542..8fba1f4 100644 --- a/dbxio/core/client.py +++ b/dbxio/core/client.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import attrs from databricks.sdk import StatementExecutionAPI, WorkspaceClient @@ -11,6 +11,10 @@ from dbxio.sql.sql_driver import SQLDriver, get_sql_driver from dbxio.utils.databricks import ClusterType from dbxio.utils.logging import get_logger +from dbxio.utils.retries import build_retrying + +if TYPE_CHECKING: + from tenacity import Retrying logger = get_logger() @@ -106,12 +110,17 @@ def workspace_api(self) -> WorkspaceClient: def statement_api(self) -> StatementExecutionAPI: return self.workspace_api.statement_execution + @property + def retrying(self) -> 'Retrying': + return build_retrying(self.settings.retry_config) + @property def 
_sql_driver(self) -> SQLDriver: return get_sql_driver( cluster_type=self.credential_provider.cluster_type, cluster_credentials=self._cluster_credentials, statement_api=self.statement_api, + retrying=self.retrying, session_configuration=self.session_configuration, ) @@ -134,8 +143,7 @@ def sql_to_files( Execute the SQL query and save the results to the specified directory. Returns the path to the directory with the results including the statement ID. """ - - return self._sql_driver.sql_to_files(query, results_path, max_concurrency) + return self.retrying(self._sql_driver.sql_to_files, query, results_path, max_concurrency) class DefaultDbxIOClient(DbxIOClient): @@ -146,6 +154,7 @@ class DefaultDbxIOClient(DbxIOClient): """ def __init__(self, session_configuration: Optional[Dict[str, Any]] = None): + logger.info('Creating a default client for all-purpose clusters') super().__init__( credential_provider=DefaultCredentialProvider(cluster_type=ClusterType.ALL_PURPOSE), session_configuration=session_configuration, @@ -161,6 +170,7 @@ class DefaultSqlDbxIOClient(DbxIOClient): """ def __init__(self, session_configuration: Optional[Dict[str, Any]] = None): + logger.info('Creating a default client for SQL warehouses') super().__init__( credential_provider=DefaultCredentialProvider(cluster_type=ClusterType.SQL_WAREHOUSE), session_configuration=session_configuration, diff --git a/dbxio/core/settings.py b/dbxio/core/settings.py index c2dcfde..8c58440 100644 --- a/dbxio/core/settings.py +++ b/dbxio/core/settings.py @@ -1,4 +1,5 @@ import os +from typing import Type import attrs @@ -16,10 +17,27 @@ def _cloud_provider_factory() -> CloudProvider: return CloudProvider(os.getenv(CLOUD_PROVIDER_ENV_VAR, _DEFAULT_CLOUD_PROVIDER).lower()) +@attrs.frozen +class RetryConfig: + max_attempts: int = attrs.field(default=7, validator=[attrs.validators.instance_of(int), attrs.validators.ge(1)]) + exponential_backoff_multiplier: int = attrs.field( + default=1, 
validator=[attrs.validators.instance_of(int), attrs.validators.ge(1)] + ) + extra_exceptions_to_retry: tuple[Type[BaseException]] = attrs.field( + factory=tuple, + validator=attrs.validators.deep_iterable( + member_validator=attrs.validators.instance_of(type), + iterable_validator=attrs.validators.instance_of(tuple), + ), + ) + + @attrs.define class Settings: cloud_provider: CloudProvider = attrs.field(factory=_cloud_provider_factory) + retry_config: RetryConfig = attrs.field(factory=RetryConfig) + @cloud_provider.validator def _validate_cloud_provider(self, attribute, value): if not isinstance(value, CloudProvider): diff --git a/dbxio/delta/table_commands.py b/dbxio/delta/table_commands.py index 4b0dc2c..dcd4d9a 100644 --- a/dbxio/delta/table_commands.py +++ b/dbxio/delta/table_commands.py @@ -17,7 +17,6 @@ from dbxio.sql.results import _FutureBaseResult from dbxio.utils.blobs import blobs_registries from dbxio.utils.logging import get_logger -from dbxio.utils.retries import dbxio_retry if TYPE_CHECKING: from dbxio.core import DbxIOClient @@ -25,7 +24,6 @@ logger = get_logger() -@dbxio_retry def exists_table(table: Union[str, Table], client: 'DbxIOClient') -> bool: """ Checks if table exists in the catalog. Tries to read one record from the table. @@ -39,7 +37,6 @@ def exists_table(table: Union[str, Table], client: 'DbxIOClient') -> bool: return False -@dbxio_retry def create_table(table: Union[str, Table], client: 'DbxIOClient') -> _FutureBaseResult: """ Creates a table in the catalog. @@ -61,7 +58,6 @@ def create_table(table: Union[str, Table], client: 'DbxIOClient') -> _FutureBase return client.sql(query) -@dbxio_retry def drop_table(table: Union[str, Table], client: 'DbxIOClient', force: bool = False) -> _FutureBaseResult: """ Drops a table from the catalog. 
@@ -75,7 +71,6 @@ def drop_table(table: Union[str, Table], client: 'DbxIOClient', force: bool = Fa return client.sql(drop_sql) -@dbxio_retry def read_table( table: Union[str, Table], client: 'DbxIOClient', @@ -108,7 +103,6 @@ def read_table( yield record -@dbxio_retry def save_table_to_files( table: Union[str, Table], client: 'DbxIOClient', @@ -128,7 +122,6 @@ def save_table_to_files( return client.sql_to_files(sql_read_query, results_path=results_path, max_concurrency=max_concurrency) -@dbxio_retry def write_table( table: Union[str, Table], new_records: Union[Iterator[Dict], List[Dict]], @@ -180,7 +173,6 @@ def write_table( return client.sql(_sql_query) -@dbxio_retry def copy_into_table( client: 'DbxIOClient', table: Table, @@ -208,7 +200,6 @@ def copy_into_table( client.sql(sql_copy_into_query).wait() -@dbxio_retry def bulk_write_table( table: Union[str, Table], new_records: Union[Iterator[Dict], List[Dict]], @@ -243,7 +234,12 @@ def bulk_write_table( credential_provider=client.credential_provider.az_cred_provider, ) - with create_tmp_parquet(pa_table_as_bytes, dbxio_table.table_identifier, object_storage) as tmp_path: + with create_tmp_parquet( + pa_table_as_bytes, + dbxio_table.table_identifier, + object_storage, + retrying=client.retrying, + ) as tmp_path: if not append: drop_table(dbxio_table, client=client, force=True).wait() create_table(dbxio_table, client=client).wait() @@ -258,7 +254,6 @@ def bulk_write_table( ) -@dbxio_retry def bulk_write_local_files( table: Table, path: str, @@ -285,9 +280,10 @@ def bulk_write_local_files( container_name=abs_container_name, credential_provider=client.credential_provider.az_cred_provider, ) - with blobs_registries(object_storage_client=object_storage) as (blobs, metablobs): + with blobs_registries(object_storage_client=object_storage, retrying=client.retrying) as (blobs, metablobs): for filename in files: - upload_file( + client.retrying( + upload_file, filename, # type: ignore p, object_storage_client=object_storage, 
@@ -315,7 +311,6 @@ def bulk_write_local_files( ) -@dbxio_retry def merge_table( table: 'Union[str , Table]', new_records: 'Union[Iterator[Dict] , List[Dict]]', @@ -354,7 +349,6 @@ def merge_table( drop_table(tmp_table, client=client, force=True).wait() -@dbxio_retry def set_comment_on_table( table: 'Union[str , Table]', comment: Union[str, None], @@ -374,7 +368,6 @@ def set_comment_on_table( return client.sql(set_comment_query) -@dbxio_retry def unset_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> _FutureBaseResult: """ Unsets the comment on a table. @@ -382,7 +375,6 @@ def unset_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') - return set_comment_on_table(table=table, comment=None, client=client) -@dbxio_retry def get_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> Union[str, None]: """ Returns the comment on a table. @@ -405,7 +397,6 @@ def get_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> return None -@dbxio_retry def set_tags_on_table(table: 'Union[str , Table]', tags: dict[str, str], client: 'DbxIOClient') -> _FutureBaseResult: """ Sets tags on a table. @@ -422,7 +413,6 @@ def set_tags_on_table(table: 'Union[str , Table]', tags: dict[str, str], client: return client.sql(set_tags_query) -@dbxio_retry def unset_tags_on_table(table: 'Union[str , Table]', tags: list[str], client: 'DbxIOClient') -> _FutureBaseResult: """ Unsets tags on a table. @@ -438,7 +428,6 @@ def unset_tags_on_table(table: 'Union[str , Table]', tags: list[str], client: 'D return client.sql(unset_tags_query) -@dbxio_retry def get_tags_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> dict[str, str]: """ Returns the tags on a table. 
diff --git a/dbxio/sql/sql_driver.py b/dbxio/sql/sql_driver.py index 9c8ae8b..856574e 100644 --- a/dbxio/sql/sql_driver.py +++ b/dbxio/sql/sql_driver.py @@ -1,7 +1,7 @@ from abc import abstractmethod from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import attrs from databricks import sql @@ -20,6 +20,9 @@ from dbxio.utils.databricks import ClusterType from dbxio.utils.logging import get_logger +if TYPE_CHECKING: + from tenacity import Retrying + logger = get_logger() @@ -29,6 +32,11 @@ class SQLDriver: Interface for SQL drivers. """ + @property + @abstractmethod + def retrying(self) -> 'Retrying': + raise NotImplementedError + @property @abstractmethod def cluster_type(self) -> ClusterType: @@ -107,6 +115,7 @@ class ODBCDriver(SQLDriver): cluster_type: ClusterType = attrs.field(validator=attrs.validators.instance_of(ClusterType)) cluster_credentials: ClusterCredentials = attrs.field(validator=attrs.validators.instance_of(ClusterCredentials)) + retrying: 'Retrying' session_configuration: Optional[dict[str, Any]] = attrs.field(default=None) conn: Optional[Connection] = attrs.field(default=None, init=False, repr=False) @@ -131,8 +140,7 @@ def _execute_sql_query(): if not self.conn or not self.cursor or not self.cursor.open or not self.conn.open: self.conn = sql.connect(**self.as_dict()) self.cursor = self.conn.cursor() - - self.cursor.execute(operation=built_query, parameters=built_params) + self.retrying(self.cursor.execute, operation=built_query, parameters=built_params) return self.conn, self.cursor return _FutureODBCResult(self.thread_pool.submit(_execute_sql_query)) @@ -162,6 +170,7 @@ class StatementAPIDriver(SQLDriver): cluster_type: ClusterType = attrs.field(validator=attrs.validators.instance_of(ClusterType)) cluster_credentials: ClusterCredentials = attrs.field(validator=attrs.validators.instance_of(ClusterCredentials)) statement_api: 
StatementExecutionAPI = attrs.field(validator=attrs.validators.instance_of(StatementExecutionAPI)) + retrying: 'Retrying' def _sql_impl(self, built_query: str, built_params: QUERY_PARAMS_TYPE) -> _FutureStatementApiResult: warehouse_id = self.cluster_credentials.http_path.split('/')[-1] @@ -171,7 +180,8 @@ def _sql_impl(self, built_query: str, built_params: QUERY_PARAMS_TYPE) -> _Futur isinstance(p, StatementParameterListItem) for p in built_params ), f'Invalid parameters types, got {built_params}' - statement_response = self.statement_api.execute_statement( + statement_response = self.retrying( + self.statement_api.execute_statement, statement=built_query, parameters=built_params, # type: ignore warehouse_id=warehouse_id, @@ -196,6 +206,7 @@ def get_sql_driver( cluster_type: ClusterType, cluster_credentials: ClusterCredentials, statement_api: StatementExecutionAPI, + retrying: 'Retrying', session_configuration: Optional[dict[str, Any]] = None, ) -> SQLDriver: """ @@ -206,12 +217,14 @@ def get_sql_driver( cluster_type=cluster_type, cluster_credentials=cluster_credentials, session_configuration=session_configuration, + retrying=retrying, ) elif cluster_type is ClusterType.SQL_WAREHOUSE: return StatementAPIDriver( cluster_type=cluster_type, cluster_credentials=cluster_credentials, statement_api=statement_api, + retrying=retrying, ) else: raise ValueError(f'Unsupported cluster type: {cluster_type}') diff --git a/dbxio/utils/__init__.py b/dbxio/utils/__init__.py index 01e7a80..1a75b24 100644 --- a/dbxio/utils/__init__.py +++ b/dbxio/utils/__init__.py @@ -9,7 +9,7 @@ ) from dbxio.utils.http import get_session from dbxio.utils.logging import get_logger -from dbxio.utils.retries import dbxio_retry +from dbxio.utils.retries import build_retrying __all__ = [ 'ClusterType', @@ -22,5 +22,5 @@ 'blobs_gc', 'get_session', 'get_logger', - 'dbxio_retry', + 'build_retrying', ] diff --git a/dbxio/utils/blobs.py b/dbxio/utils/blobs.py index 0675174..9ec15d1 100644 --- 
a/dbxio/utils/blobs.py +++ b/dbxio/utils/blobs.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from tenacity import Retrying + from dbxio.core.cloud.client.object_storage import ObjectStorageClient @@ -13,6 +15,7 @@ def blobs_gc(blobs: list[str], object_storage_client: 'ObjectStorageClient'): @contextmanager def blobs_registries( object_storage_client: 'ObjectStorageClient', + retrying: 'Retrying', keep_blobs: bool = False, keep_metablobs: bool = False, ): @@ -25,6 +28,6 @@ def blobs_registries( yield blobs, metablobs finally: if not keep_blobs: - blobs_gc(blobs, object_storage_client) + retrying(blobs_gc, blobs, object_storage_client) if not keep_metablobs: - blobs_gc(metablobs, object_storage_client) + retrying(blobs_gc, metablobs, object_storage_client) diff --git a/dbxio/utils/retries.py b/dbxio/utils/retries.py index 7e1b58b..af0aafd 100644 --- a/dbxio/utils/retries.py +++ b/dbxio/utils/retries.py @@ -1,12 +1,21 @@ import logging +from typing import TYPE_CHECKING +from azure.core.exceptions import AzureError +from cachetools import TTLCache, cached from databricks.sdk.errors.platform import PermissionDenied -from tenacity import RetryCallState, after_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential +from tenacity import RetryCallState, Retrying, after_log, retry_if_exception_type, stop_after_attempt, wait_exponential +from urllib3.exceptions import ReadTimeoutError from dbxio.utils.logging import get_logger +if TYPE_CHECKING: + from dbxio.core.settings import RetryConfig + logger = get_logger() +BASE_EXCEPTIONS_TO_RETRY = (PermissionDenied, ReadTimeoutError, AzureError) + def _clear_client_cache(call_state: RetryCallState) -> None: """ @@ -28,11 +37,12 @@ def _clear_client_cache(call_state: RetryCallState) -> None: return -dbxio_retry = retry( - stop=stop_after_attempt(7), - wait=wait_exponential(multiplier=1), - retry=retry_if_exception_type((PermissionDenied,)), - reraise=True, - before=_clear_client_cache, - 
after=after_log(logger, log_level=logging.INFO), -) +@cached(cache=TTLCache(maxsize=1024, ttl=60 * 15)) +def build_retrying(settings: 'RetryConfig') -> Retrying: + return Retrying( + stop=stop_after_attempt(settings.max_attempts), + wait=wait_exponential(multiplier=settings.exponential_backoff_multiplier), + retry=retry_if_exception_type(BASE_EXCEPTIONS_TO_RETRY + settings.extra_exceptions_to_retry), + before=_clear_client_cache, + after=after_log(logger, log_level=logging.INFO), + ) diff --git a/dbxio/volume/volume_commands.py b/dbxio/volume/volume_commands.py index 4e25f4f..d67b9ee 100644 --- a/dbxio/volume/volume_commands.py +++ b/dbxio/volume/volume_commands.py @@ -13,7 +13,6 @@ from dbxio.utils.blobs import blobs_registries from dbxio.utils.databricks import get_external_location_storage_url, get_volume_url from dbxio.utils.logging import get_logger -from dbxio.utils.retries import dbxio_retry if TYPE_CHECKING: from dbxio.core.client import DbxIOClient @@ -42,7 +41,6 @@ class Volume: storage_location: Union[str, None] = None @classmethod - @dbxio_retry def from_url(cls, url: str, client: 'DbxIOClient') -> 'Volume': """ Creates a Volume object from a URL. 
@@ -53,7 +51,7 @@ def from_url(cls, url: str, client: 'DbxIOClient') -> 'Volume': raise ValueError('URL must start with /Volumes/') _, catalog, schema, name, *path = url.lstrip('/').split('/') - raw_volume_info = client.workspace_api.volumes.read(f'{catalog}.{schema}.{name}') + raw_volume_info = client.retrying(client.workspace_api.volumes.read, f'{catalog}.{schema}.{name}') return cls( catalog=catalog, schema=schema, @@ -96,13 +94,12 @@ def is_external(self): return self.volume_type is VolumeType.EXTERNAL -@dbxio_retry def create_volume(volume: Volume, client: 'DbxIOClient', skip_if_exists: bool = True) -> None: if skip_if_exists and exists_volume(volume.catalog, volume.schema, volume.name, client): logger.info(f'Volume {volume.safe_full_name} already exists, skipping creation.') return - - client.workspace_api.volumes.create( + client.retrying( + client.workspace_api.volumes.create, catalog_name=volume.catalog, schema_name=volume.schema, name=volume.name, @@ -112,50 +109,47 @@ def create_volume(volume: Volume, client: 'DbxIOClient', skip_if_exists: bool = logger.info(f'Volume {volume.safe_full_name} was successfully created.') -@dbxio_retry def exists_volume(catalog_name: str, schema_name: str, volume_name: str, client: 'DbxIOClient') -> bool: - for v in client.workspace_api.volumes.list(catalog_name=catalog_name, schema_name=schema_name): + for v in client.retrying(client.workspace_api.volumes.list, catalog_name=catalog_name, schema_name=schema_name): if v.name == volume_name: return True return False -def _download_external_volume(local_path: Path, storage_location: str, volume_path: str) -> None: +def _download_external_volume(local_path: Path, storage_location: str, volume_path: str, client: 'DbxIOClient') -> None: object_storage = ObjectStorageClient.from_url(storage_location) assert object_storage.blobs_path, f'Object storage client must have a blobs path, got {object_storage=}' download_blob_tree( object_storage_client=object_storage, 
local_path=local_path, prefix_path=str(Path(object_storage.blobs_path) / Path(volume_path)), + client=client, ) -@dbxio_retry def _download_single_file_from_managed_volume(local_path: Path, file_path: str, client: 'DbxIOClient'): with open(local_path / Path(file_path).name, 'wb') as f: - response_content = client.workspace_api.files.download(file_path).contents + response_content = client.retrying(client.workspace_api.files.download, file_path).contents if response_content: f.write(response_content.read()) else: raise ValueError(f'Failed to download file {file_path}, got None') -@dbxio_retry def _check_if_path_is_remote_file(path: str, client: 'DbxIOClient') -> bool: try: - client.workspace_api.files.get_metadata(path) + client.retrying(client.workspace_api.files.get_metadata, path) return True except NotFound: return False -@dbxio_retry def _download_managed_volume(local_path: Path, volume: Volume, client: 'DbxIOClient'): if _check_if_path_is_remote_file(volume.mount_path, client): _download_single_file_from_managed_volume(local_path, volume.mount_path, client) return - for file in client.workspace_api.files.list_directory_contents(volume.mount_path): + for file in client.retrying(client.workspace_api.files.list_directory_contents, volume.mount_path): if file.name is None or file.path is None: raise ValueError(f'File {file} has no name or path') @@ -166,7 +160,6 @@ def _download_managed_volume(local_path: Path, volume: Volume, client: 'DbxIOCli _download_single_file_from_managed_volume(local_path, file.path, client) -@dbxio_retry def download_volume( path: Union[str, Path], catalog_name: str, @@ -200,6 +193,7 @@ def download_volume( local_path=path, storage_location=str(volume.storage_location), volume_path=volume.path, + client=client, ) elif volume.volume_type == VolumeType.MANAGED: _download_managed_volume(local_path=path, volume=volume, client=client) @@ -211,7 +205,6 @@ def download_volume( return path / volume.path -@dbxio_retry def _write_external_volume( 
path: Path, catalog_name: str, @@ -243,13 +236,14 @@ def _write_external_volume( object_storage_client = ObjectStorageClient.from_url(get_external_location_storage_url(catalog_name, client)) prefix_blob_path = str(Path(volume.mount_start_point) / volume_path) - with blobs_registries(object_storage_client, keep_blobs=True) as (blobs, metablobs): + with blobs_registries(object_storage_client, retrying=client.retrying, keep_blobs=True) as (blobs, metablobs): # here all files in the path, including subdirectories, are uploaded to the blob storage. # only "hidden" files (those starting with a dot) are skipped files_to_upload = path.glob('**/*') if path.is_dir() else [path] for file in files_to_upload: if file.is_file() and not file.name.startswith('.'): - upload_file( + client.retrying( + upload_file, path=file, local_path=path, prefix_blob_path=prefix_blob_path, @@ -267,7 +261,6 @@ def _write_external_volume( create_volume(volume=volume, client=client) -@dbxio_retry def _write_managed_volume( path: Path, catalog_name: str, @@ -293,14 +286,14 @@ def _write_managed_volume( if file.is_file() and not file.name.startswith('.'): file_name: Union[Path, str] = file.relative_to(path) if file != path else file.name volume_file_path = str(Path(volume.mount_path) / Path(file_name)) - client.workspace_api.files.upload( + client.retrying( + client.workspace_api.files.upload, file_path=volume_file_path, contents=file.open('rb'), overwrite=force, ) -@dbxio_retry def write_volume( path: Union[str, Path], catalog_name: str, @@ -350,7 +343,6 @@ def write_volume( ) -@dbxio_retry def set_tags_on_volume(volume: Volume, tags: dict[str, str], client: 'DbxIOClient') -> _FutureBaseResult: """ Sets tags on a volume. @@ -367,7 +359,6 @@ def set_tags_on_volume(volume: Volume, tags: dict[str, str], client: 'DbxIOClien return client.sql(set_tags_query) -@dbxio_retry def unset_tags_on_volume(volume: Volume, tags: list[str], client: 'DbxIOClient') -> _FutureBaseResult: """ Unsets tags on a volume. 
@@ -383,7 +374,6 @@ def unset_tags_on_volume(volume: Volume, tags: list[str], client: 'DbxIOClient') return client.sql(unset_tags_query) -@dbxio_retry def get_tags_on_volume(volume: Volume, client: 'DbxIOClient') -> dict[str, str]: """ Returns the tags on a volume. @@ -405,7 +395,6 @@ def get_tags_on_volume(volume: Volume, client: 'DbxIOClient') -> dict[str, str]: return tags -@dbxio_retry def set_comment_on_volume( volume: Volume, comment: Union[str, None], @@ -424,7 +413,6 @@ def set_comment_on_volume( return client.sql(set_comment_query) -@dbxio_retry def unset_comment_on_volume(volume: Volume, client: 'DbxIOClient') -> _FutureBaseResult: """ Unsets the comment on a volume. @@ -432,7 +420,6 @@ def unset_comment_on_volume(volume: Volume, client: 'DbxIOClient') -> _FutureBas return set_comment_on_volume(volume=volume, comment=None, client=client) -@dbxio_retry def get_comment_on_volume(volume: Volume, client: 'DbxIOClient') -> Union[str, None]: """ Returns the comment on a volume. @@ -453,7 +440,6 @@ def get_comment_on_volume(volume: Volume, client: 'DbxIOClient') -> Union[str, N return None -@dbxio_retry def drop_volume(volume: Volume, client: 'DbxIOClient', force: bool = False) -> None: """ Deletes a volume in Databricks. 
@@ -468,7 +454,7 @@ def drop_volume(volume: Volume, client: 'DbxIOClient', force: bool = False) -> N logger.info(f'External volume {volume.safe_full_name} was successfully cleaned up.') try: - client.workspace_api.volumes.delete(volume.full_name) + client.retrying(client.workspace_api.volumes.delete, volume.full_name) except ResourceDoesNotExist as e: if not force: raise e diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9952760 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import pytest +from tenacity import Retrying + + +@pytest.fixture +def default_retrying(): + return Retrying() diff --git a/tests/test_odbc_driver.py b/tests/test_odbc_driver.py index dad9b41..7ad4c36 100644 --- a/tests/test_odbc_driver.py +++ b/tests/test_odbc_driver.py @@ -22,11 +22,12 @@ def cluster_credentials(): ) -def test_odbc_driver_as_dict_wo_session_configuration(cluster_credentials): +def test_odbc_driver_as_dict_wo_session_configuration(cluster_credentials, default_retrying): driver = ODBCDriver( cluster_type=ClusterType.ALL_PURPOSE, cluster_credentials=cluster_credentials, session_configuration=None, + retrying=default_retrying, ) assert driver.as_dict() == { 'server_hostname': 'adb-123456789.10.azuredatabricks.net', @@ -36,11 +37,12 @@ def test_odbc_driver_as_dict_wo_session_configuration(cluster_credentials): } -def test_odbc_driver_as_dict_with_session_configuration(cluster_credentials): +def test_odbc_driver_as_dict_with_session_configuration(cluster_credentials, default_retrying): driver = ODBCDriver( cluster_type=ClusterType.ALL_PURPOSE, cluster_credentials=cluster_credentials, session_configuration={'conf_key': 'conf_value'}, + retrying=default_retrying, ) assert driver.as_dict() == { 'server_hostname': 'adb-123456789.10.azuredatabricks.net', @@ -50,22 +52,24 @@ def test_odbc_driver_as_dict_with_session_configuration(cluster_credentials): } -def test_odbc_driver_sql(cluster_credentials): +def 
test_odbc_driver_sql(cluster_credentials, default_retrying): with patch('databricks.sql.connect', side_effect=mock_dbx_connect): driver = ODBCDriver( cluster_type=ClusterType.ALL_PURPOSE, cluster_credentials=cluster_credentials, session_configuration=None, + retrying=default_retrying, ) data = list(driver.sql('select * from table')) assert all([not DeepDiff(row, MOCK_ROW.asDict()) for row in data]) -def test_odbc_driver_sql_to_files(cluster_credentials): +def test_odbc_driver_sql_to_files(cluster_credentials, default_retrying): driver = ODBCDriver( cluster_type=ClusterType.ALL_PURPOSE, cluster_credentials=cluster_credentials, session_configuration=None, + retrying=default_retrying, ) with tempfile.TemporaryDirectory() as temp_dir, patch('databricks.sql.connect', side_effect=mock_dbx_connect): path_to_files = driver.sql_to_files('select * from table', results_path=temp_dir) diff --git a/tests/test_statement_api_driver.py b/tests/test_statement_api_driver.py index ba932cb..a756a79 100644 --- a/tests/test_statement_api_driver.py +++ b/tests/test_statement_api_driver.py @@ -52,11 +52,12 @@ def mock_get_statement_result_chunk_n(*args, **kwargs): 'databricks.sdk.service.sql.StatementExecutionAPI.get_statement_result_chunk_n', side_effect=mock_get_statement_result_chunk_n, ) -def test_sapi_driver_sql(mock1, mock2, statement_api, cluster_credentials, requests_mock): +def test_sapi_driver_sql(mock1, mock2, statement_api, cluster_credentials, requests_mock, default_retrying): driver = StatementAPIDriver( cluster_type=ClusterType.ALL_PURPOSE, cluster_credentials=cluster_credentials, statement_api=statement_api, + retrying=default_retrying, ) with open('tests/resources/arrow_stream_1_plus_1.arrow', 'rb') as f: @@ -77,11 +78,12 @@ def test_sapi_driver_sql(mock1, mock2, statement_api, cluster_credentials, reque 'databricks.sdk.service.sql.StatementExecutionAPI.get_statement_result_chunk_n', side_effect=mock_get_statement_result_chunk_n, ) -def 
test_sapi_driver_sql_to_files(mock1, mock2, statement_api, cluster_credentials, requests_mock): +def test_sapi_driver_sql_to_files(mock1, mock2, statement_api, cluster_credentials, requests_mock, default_retrying): driver = StatementAPIDriver( cluster_type=ClusterType.ALL_PURPOSE, cluster_credentials=cluster_credentials, statement_api=statement_api, + retrying=default_retrying, ) with open('tests/resources/arrow_stream_1_plus_1.arrow', 'rb') as f: arrow_stream = f.read() diff --git a/tests/test_utils_retries.py b/tests/test_utils_retries.py index 4a940e2..befdab6 100644 --- a/tests/test_utils_retries.py +++ b/tests/test_utils_retries.py @@ -2,7 +2,8 @@ from databricks.sdk.errors.platform import PermissionDenied from tenacity import stop_after_attempt, wait_fixed -from dbxio import ClusterType, DbxIOClient, dbxio_retry +from dbxio import ClusterType, DbxIOClient +from dbxio.utils.retries import build_retrying from tests.mocks.azure import MockDefaultAzureCredential @@ -16,6 +17,16 @@ def mock_client(): ) +@pytest.fixture +def retrying(mock_client): + retrying = build_retrying(mock_client.settings.retry_config) + retrying.stop = stop_after_attempt(N_RETRIES) + retrying.wait = wait_fixed(0) + retrying.reraise = True + + return retrying + + N_RETRIES = 2 @@ -23,42 +34,32 @@ class UnknownException(Exception): pass -@dbxio_retry def _some_function(arg1, arg2, client: DbxIOClient, kwarg1=None, kwarg2=None): raise PermissionDenied('Permission denied') -@dbxio_retry def _some_function_with_unknown_exception(arg1, arg2, client: DbxIOClient, kwarg1=None, kwarg2=None): raise UnknownException('Unknown exception') -def test_dbxio_retry(mock_client): - func = _some_function.retry_with( - stop=stop_after_attempt(N_RETRIES), - wait=wait_fixed(0), - ) +def test_dbxio_retry(mock_client, retrying): try: - func(1, 2, mock_client, kwarg1=3, kwarg2=4) + retrying(_some_function, 1, 2, mock_client, kwarg1=3, kwarg2=4) except PermissionDenied: pass - # how to get stat: 
https://github.com/jd/tenacity/issues/486#issuecomment-2229210530 - attempt_number = func.retry.statistics.get('attempt_number') or func.statistics.get('attempt_number') + attempt_number = retrying.statistics.get('attempt_number') assert attempt_number == N_RETRIES -def test_dbxio_retry_unknown_exception(mock_client): - func = _some_function_with_unknown_exception.retry_with( - stop=stop_after_attempt(N_RETRIES * 100), - wait=wait_fixed(0), - ) +def test_dbxio_retry_unknown_exception(mock_client, retrying): + retrying.stop = stop_after_attempt(N_RETRIES) try: - func(1, 2, mock_client, kwarg1=3, kwarg2=4) + retrying(_some_function_with_unknown_exception, 1, 2, mock_client, kwarg1=3, kwarg2=4) except UnknownException: pass - attempt_number = func.retry.statistics.get('attempt_number') or func.statistics.get('attempt_number') + attempt_number = retrying.statistics.get('attempt_number') assert attempt_number == 1 diff --git a/tests/test_volume_commands.py b/tests/test_volume_commands.py index 40175f8..76844c5 100644 --- a/tests/test_volume_commands.py +++ b/tests/test_volume_commands.py @@ -68,7 +68,7 @@ def mock_list_directory_contents_return_values_filtered(): ) -def _mock_download_blob_tree(object_storage_client, local_path: Path, prefix_path): +def _mock_download_blob_tree(object_storage_client, local_path: Path, client, prefix_path): """ File structure: path/to/blobs dir/file @@ -218,7 +218,7 @@ def test_get_tags_on_volume(mock_sql): def test_download_external_volume(mock_download_blob_tree): storage_location = 'abfss://container@storage_account.dfs.core.windows.net/dir' with TemporaryDirectory() as temp_dir: - _download_external_volume(Path(temp_dir), storage_location, '') + _download_external_volume(Path(temp_dir), storage_location, '', client=client) assert sorted(Path(temp_dir).glob('**/*')) == sorted( [ From f88fff1fe0284699f07e68365114beb46081cc6e Mon Sep 17 00:00:00 2001 From: Nikita Yurasov Date: Mon, 7 Oct 2024 16:01:28 +0200 Subject: [PATCH 2/5] update 
version --- dbxio/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbxio/__init__.py b/dbxio/__init__.py index 78abcd7..3bbc900 100644 --- a/dbxio/__init__.py +++ b/dbxio/__init__.py @@ -4,4 +4,4 @@ from dbxio.utils import * # noqa: F403 from dbxio.volume import * # noqa: F403 -__version__ = '0.4.5' # single source of truth +__version__ = '0.4.6' # single source of truth From 8d85bbfeb4afcbfc44c695aba7eaecd9dd65a417 Mon Sep 17 00:00:00 2001 From: Nikita Yurasov Date: Mon, 7 Oct 2024 16:05:16 +0200 Subject: [PATCH 3/5] refactoring (mypy) --- dbxio/delta/table_commands.py | 2 +- dbxio/sql/sql_driver.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dbxio/delta/table_commands.py b/dbxio/delta/table_commands.py index dcd4d9a..045fc3d 100644 --- a/dbxio/delta/table_commands.py +++ b/dbxio/delta/table_commands.py @@ -284,7 +284,7 @@ def bulk_write_local_files( for filename in files: client.retrying( upload_file, - filename, # type: ignore + filename, p, object_storage_client=object_storage, prefix_blob_path=operation_uuid, diff --git a/dbxio/sql/sql_driver.py b/dbxio/sql/sql_driver.py index 856574e..3aea12f 100644 --- a/dbxio/sql/sql_driver.py +++ b/dbxio/sql/sql_driver.py @@ -183,7 +183,7 @@ def _sql_impl(self, built_query: str, built_params: QUERY_PARAMS_TYPE) -> _Futur statement_response = self.retrying( self.statement_api.execute_statement, statement=built_query, - parameters=built_params, # type: ignore + parameters=built_params, warehouse_id=warehouse_id, disposition=Disposition.EXTERNAL_LINKS, format=Format.ARROW_STREAM, From d2f56ef9a912db44632e0faedaa03d38dff28c12 Mon Sep 17 00:00:00 2001 From: Nikita Yurasov Date: Mon, 7 Oct 2024 16:11:32 +0200 Subject: [PATCH 4/5] refactoring --- dbxio/core/__init__.py | 3 ++- dbxio/core/settings.py | 7 ++++--- docs/README.md | 24 ++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/dbxio/core/__init__.py b/dbxio/core/__init__.py index 
a70a878..01a49de 100644 --- a/dbxio/core/__init__.py +++ b/dbxio/core/__init__.py @@ -8,7 +8,7 @@ DefaultCredentialProvider, ) from dbxio.core.exceptions import DbxIOTypeError, InsufficientCredentialsError, ReadDataError, UnavailableAuthError -from dbxio.core.settings import CloudProvider, Settings +from dbxio.core.settings import CloudProvider, RetryConfig, Settings __all__ = [ 'get_token', @@ -28,4 +28,5 @@ 'ReadDataError', 'Settings', 'CloudProvider', + 'RetryConfig', ] diff --git a/dbxio/core/settings.py b/dbxio/core/settings.py index 8c58440..de89c5b 100644 --- a/dbxio/core/settings.py +++ b/dbxio/core/settings.py @@ -1,5 +1,5 @@ import os -from typing import Type +from typing import Type, Union import attrs @@ -20,8 +20,9 @@ def _cloud_provider_factory() -> CloudProvider: @attrs.frozen class RetryConfig: max_attempts: int = attrs.field(default=7, validator=[attrs.validators.instance_of(int), attrs.validators.ge(1)]) - exponential_backoff_multiplier: int = attrs.field( - default=1, validator=[attrs.validators.instance_of(int), attrs.validators.ge(1)] + exponential_backoff_multiplier: Union[float, int] = attrs.field( + default=1.0, + validator=[attrs.validators.instance_of((float, int)), attrs.validators.ge(0)], ) extra_exceptions_to_retry: tuple[Type[BaseException]] = attrs.field( factory=tuple, diff --git a/docs/README.md b/docs/README.md index 6958427..17c698f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -118,6 +118,30 @@ client = dbxio.DbxIOClient.from_cluster_settings( ) ``` +### Customize retry settings + +Sometimes it's vital to retry on certain exceptions, depending on your setup. +You can customize the retry settings by passing them into the settings object.
+ +```python +import dbxio + +settings = dbxio.Settings( + cloud_provider=dbxio.CloudProvider.AZURE, + retry_config=dbxio.RetryConfig( + max_attempts=20, + exponential_backoff_multiplier=1.5, + extra_exceptions_to_retry=(MyCustomException,), + ), +) + +client = dbxio.DbxIOClient.from_cluster_settings( + # ..., + settings=settings, + # ..., +) +``` + ## Basic read/write table operations > [!NOTE] From fecb5e6038fa59ff00b29a3d8f89ef9c413d96b8 Mon Sep 17 00:00:00 2001 From: Nikita Yurasov Date: Mon, 7 Oct 2024 16:32:19 +0200 Subject: [PATCH 5/5] refactoring --- dbxio/utils/retries.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dbxio/utils/retries.py b/dbxio/utils/retries.py index af0aafd..59f3a5c 100644 --- a/dbxio/utils/retries.py +++ b/dbxio/utils/retries.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING from azure.core.exceptions import AzureError -from cachetools import TTLCache, cached from databricks.sdk.errors.platform import PermissionDenied from tenacity import RetryCallState, Retrying, after_log, retry_if_exception_type, stop_after_attempt, wait_exponential from urllib3.exceptions import ReadTimeoutError @@ -37,7 +36,6 @@ def _clear_client_cache(call_state: RetryCallState) -> None: return -@cached(cache=TTLCache(maxsize=1024, ttl=60 * 15)) def build_retrying(settings: 'RetryConfig') -> Retrying: return Retrying( stop=stop_after_attempt(settings.max_attempts),