diff --git a/.github/workflows/test_destination_motherduck.yml b/.github/workflows/test_destination_motherduck.yml new file mode 100644 index 0000000000..a51fb3cc8f --- /dev/null +++ b/.github/workflows/test_destination_motherduck.yml @@ -0,0 +1,80 @@ + +name: dest | motherduck + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + schedule: + - cron: '0 2 * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + + ACTIVE_DESTINATIONS: "[\"motherduck\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + +jobs: + get_docs_changes: + name: docs changes + uses: ./.github/workflows/get_docs_changes.yml + if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} + + run_loader: + name: dest | motherduck tests + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + defaults: + run: + shell: bash + runs-on: "ubuntu-latest" + + steps: + + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.10.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-motherduck + + - name: Install dependencies + run: poetry install --no-interaction -E motherduck -E s3 -E gs -E az -E parquet --with sentry-sdk --with pipeline + + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + - run: | + poetry run pytest tests/load -m "essential" + name: Run essential tests Linux + if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} + + - run: | + poetry run pytest tests/load + name: Run all tests Linux + if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index e75cd6c780..a034ac7eb0 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -28,7 +28,7 @@ env: RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} # Test redshift and filesystem with all buckets # postgres runs again here so we can test on mac/windows - ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\", \"motherduck\"]" + ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\"]" jobs: get_docs_changes: diff --git a/Makefile b/Makefile index 15fb895a9f..f47047a3fe 100644 --- a/Makefile +++ b/Makefile @@ -52,7 +52,7 @@ lint: poetry run mypy --config-file mypy.ini dlt tests poetry run flake8 --max-line-length=200 dlt poetry run flake8 --max-line-length=200 tests --exclude tests/reflection/module_cases - poetry run black dlt docs tests --diff --extend-exclude=".*syntax_error.py" + poetry run black dlt docs tests --check --diff --color --extend-exclude=".*syntax_error.py" # poetry run isort ./ --diff # $(MAKE) lint-security diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index ca9d6a2d94..3af7dcff13 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -27,6 +27,7 @@ from dlt.common import logger from dlt.common.configuration.specs.base_configuration import extract_inner_hint from dlt.common.destination.utils import verify_schema_capabilities +from dlt.common.exceptions import TerminalValueError from dlt.common.normalizers.naming import NamingConvention from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.utils import ( @@ -42,6 +43,8 @@ InvalidDestinationReference, UnknownDestinationModule, DestinationSchemaTampered, + DestinationTransientException, + DestinationTerminalException, ) from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage @@ -187,6 +190,8 @@ class DestinationClientDwhConfiguration(DestinationClientConfiguration): """How to handle replace disposition for this destination, can be classic or staging""" staging_dataset_name_layout: str = "%s_staging" """Layout for staging dataset, where %s is replaced with dataset name. placeholder is optional""" + enable_dataset_name_normalization: bool = True + """Whether to normalize the dataset name. Affects staging dataset as well.""" def _bind_dataset_name( self: TDestinationDwhClient, dataset_name: str, default_schema_name: str = None @@ -205,11 +210,14 @@ def normalize_dataset_name(self, schema: Schema) -> str: If default schema name is None or equals schema.name, the schema suffix is skipped. """ dataset_name = self._make_dataset_name(schema.name) - return ( - dataset_name - if not dataset_name - else schema.naming.normalize_table_identifier(dataset_name) - ) + if not dataset_name: + return dataset_name + else: + return ( + schema.naming.normalize_table_identifier(dataset_name) + if self.enable_dataset_name_normalization + else dataset_name + ) def normalize_staging_dataset_name(self, schema: Schema) -> str: """Builds staging dataset name out of dataset_name and staging_dataset_name_layout.""" @@ -224,7 +232,11 @@ def normalize_staging_dataset_name(self, schema: Schema) -> str: # no placeholder, then layout is a full name. so you can have a single staging dataset dataset_name = self.staging_dataset_name_layout - return schema.naming.normalize_table_identifier(dataset_name) + return ( + schema.naming.normalize_table_identifier(dataset_name) + if self.enable_dataset_name_normalization + else dataset_name + ) def _make_dataset_name(self, schema_name: str) -> str: if not schema_name: @@ -258,11 +270,45 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura """configuration of the staging, if present, injected at runtime""" -TLoadJobState = Literal["running", "failed", "retry", "completed"] +TLoadJobState = Literal["ready", "running", "failed", "retry", "completed"] + + +class LoadJob(ABC): + """ + A stateful load job, represents one job file + """ + + def __init__(self, file_path: str) -> None: + self._file_path = file_path + self._file_name = FileStorage.get_file_name_from_file_path(file_path) + # NOTE: we only accept a full filepath in the constructor + assert self._file_name != self._file_path + self._parsed_file_name = ParsedLoadJobFileName.parse(self._file_name) + + def job_id(self) -> str: + """The job id that is derived from the file name and does not changes during job lifecycle""" + return self._parsed_file_name.job_id() + + def file_name(self) -> str: + """A name of the job file""" + return self._file_name + + def job_file_info(self) -> ParsedLoadJobFileName: + return self._parsed_file_name + + @abstractmethod + def state(self) -> TLoadJobState: + """Returns current state. Should poll external resource if necessary.""" + pass + + @abstractmethod + def exception(self) -> str: + """The exception associated with failed or retry states""" + pass -class LoadJob: - """Represents a job that loads a single file +class RunnableLoadJob(LoadJob, ABC): + """Represents a runnable job that loads a single file Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. @@ -273,39 +319,80 @@ class LoadJob: immediately transition job into "failed" or "retry" state respectively. """ - def __init__(self, file_name: str) -> None: + def __init__(self, file_path: str) -> None: """ File name is also a job id (or job id is deterministically derived) so it must be globally unique """ # ensure file name - assert file_name == FileStorage.get_file_name_from_file_path(file_name) - self._file_name = file_name - self._parsed_file_name = ParsedLoadJobFileName.parse(file_name) + super().__init__(file_path) + self._state: TLoadJobState = "ready" + self._exception: Exception = None - @abstractmethod - def state(self) -> TLoadJobState: - """Returns current state. Should poll external resource if necessary.""" - pass + # variables needed by most jobs, set by the loader in set_run_vars + self._schema: Schema = None + self._load_table: TTableSchema = None + self._load_id: str = None + self._job_client: "JobClientBase" = None - def file_name(self) -> str: - """A name of the job file""" - return self._file_name + def set_run_vars(self, load_id: str, schema: Schema, load_table: TTableSchema) -> None: + """ + called by the loader right before the job is run + """ + self._load_id = load_id + self._schema = schema + self._load_table = load_table - def job_id(self) -> str: - """The job id that is derived from the file name and does not changes during job lifecycle""" - return self._parsed_file_name.job_id() + @property + def load_table_name(self) -> str: + return self._load_table["name"] - def job_file_info(self) -> ParsedLoadJobFileName: - return self._parsed_file_name + def run_managed( + self, + job_client: "JobClientBase", + ) -> None: + """ + wrapper around the user implemented run method + """ + # only jobs that are not running or have not reached a final state + # may be started + assert self._state in ("ready", "retry") + self._job_client = job_client + + # filepath is now moved to running + try: + self._state = "running" + self._job_client.prepare_load_job_execution(self) + self.run() + self._state = "completed" + except (DestinationTerminalException, TerminalValueError) as e: + self._state = "failed" + self._exception = e + except (DestinationTransientException, Exception) as e: + self._state = "retry" + self._exception = e + finally: + # sanity check + assert self._state in ("completed", "retry", "failed") @abstractmethod + def run(self) -> None: + """ + run the actual job, this will be executed on a thread and should be implemented by the user + exception will be handled outside of this function + """ + raise NotImplementedError() + + def state(self) -> TLoadJobState: + """Returns current state. Should poll external resource if necessary.""" + return self._state + def exception(self) -> str: """The exception associated with failed or retry states""" - pass + return str(self._exception) -class NewLoadJob(LoadJob): - """Adds a trait that allows to save new job file""" +class FollowupJob: + """Base class for follow up jobs that should be created""" @abstractmethod def new_file_path(self) -> str: @@ -313,35 +400,14 @@ def new_file_path(self) -> str: pass -class FollowupJob: - """Adds a trait that allows to create a followup job""" +class HasFollowupJobs: + """Adds a trait that allows to create single or table chain followup jobs""" - def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: """Return list of new jobs. `final_state` is state to which this job transits""" return [] -class DoNothingJob(LoadJob): - """The most lazy class of dlt""" - - def __init__(self, file_path: str) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - - def state(self) -> TLoadJobState: - # this job is always done - return "completed" - - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - - -class DoNothingFollowupJob(DoNothingJob, FollowupJob): - """The second most lazy class of dlt""" - - pass - - class JobClientBase(ABC): def __init__( self, @@ -394,13 +460,16 @@ def update_stored_schema( return expected_update @abstractmethod - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - """Creates and starts a load job for a particular `table` with content in `file_path`""" + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + """Creates a load job for a particular `table` with content in `file_path`""" pass - @abstractmethod - def restore_file_load(self, file_path: str) -> LoadJob: - """Finds and restores already started loading job identified by `file_path` if destination supports it.""" + def prepare_load_job_execution( # noqa: B027, optional override + self, job: RunnableLoadJob + ) -> None: + """Prepare the connected job client for the execution of a load job (used for query tags in sql clients)""" pass def should_truncate_table_before_load(self, table: TTableSchema) -> bool: @@ -410,7 +479,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] diff --git a/dlt/common/destination/utils.py b/dlt/common/destination/utils.py index 2c5e97df14..931413126c 100644 --- a/dlt/common/destination/utils.py +++ b/dlt/common/destination/utils.py @@ -6,7 +6,6 @@ from dlt.common.schema.exceptions import ( SchemaIdentifierNormalizationCollision, ) -from dlt.common.schema.utils import is_complete_column from dlt.common.typing import DictStrStr from .capabilities import DestinationCapabilitiesContext @@ -25,7 +24,6 @@ def verify_schema_capabilities( * Checks if schema has collisions due to case sensitivity of the identifiers """ - log = logger.warning if warnings else logger.info # collect all exceptions to show all problems in the schema exception_log: List[Exception] = [] # combined casing function @@ -79,7 +77,7 @@ def verify_schema_capabilities( ) column_name_lookup: DictStrStr = {} - for column_name, column in dict(table["columns"]).items(): + for column_name in dict(table["columns"]): # detect table name conflict cased_column_name = case_identifier(column_name) if cased_column_name in column_name_lookup: @@ -105,11 +103,4 @@ def verify_schema_capabilities( capabilities.max_column_identifier_length, ) ) - if not is_complete_column(column): - log( - f"A column {column_name} in table {table_name} in schema" - f" {schema.name} is incomplete. It was not bound to the data during" - " normalizations stage and its data type is unknown. Did you add this" - " column manually in code ie. as a merge key?" - ) return exception_log diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py index 04100b0c6c..d98795d07c 100644 --- a/dlt/common/libs/deltalake.py +++ b/dlt/common/libs/deltalake.py @@ -1,13 +1,15 @@ -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, List from pathlib import Path -from dlt import version +from dlt import version, Pipeline from dlt.common import logger from dlt.common.libs.pyarrow import pyarrow as pa from dlt.common.libs.pyarrow import cast_arrow_schema_types from dlt.common.schema.typing import TWriteDisposition from dlt.common.exceptions import MissingDependencyException from dlt.common.storages import FilesystemConfiguration +from dlt.common.utils import assert_min_pkg_version +from dlt.destinations.impl.filesystem.filesystem import FilesystemClient try: from deltalake import write_deltalake, DeltaTable @@ -41,6 +43,13 @@ def ensure_delta_compatible_arrow_data( Casts `data` schema to replace data types not supported by Delta. """ + # RecordBatchReader.cast() requires pyarrow>=17.0.0 + # cast() got introduced in 16.0.0, but with bug + assert_min_pkg_version( + pkg_name="pyarrow", + version="17.0.0", + msg="`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination.", + ) schema = ensure_delta_compatible_arrow_schema(data.schema) return data.cast(schema) @@ -62,9 +71,13 @@ def write_delta_table( table_or_uri: Union[str, Path, DeltaTable], data: Union[pa.Table, pa.RecordBatchReader], write_disposition: TWriteDisposition, + partition_by: Optional[Union[List[str], str]] = None, storage_options: Optional[Dict[str, str]] = None, ) -> None: - """Writes in-memory Arrow table to on-disk Delta table.""" + """Writes in-memory Arrow table to on-disk Delta table. + + Thin wrapper around `deltalake.write_deltalake`. + """ # throws warning for `s3` protocol: https://github.com/delta-io/delta-rs/issues/2460 # TODO: upgrade `deltalake` lib after https://github.com/delta-io/delta-rs/pull/2500 @@ -72,6 +85,7 @@ def write_delta_table( write_deltalake( # type: ignore[call-overload] table_or_uri=table_or_uri, data=ensure_delta_compatible_arrow_data(data), + partition_by=partition_by, mode=get_delta_write_mode(write_disposition), schema_mode="merge", # enable schema evolution (adding new columns) storage_options=storage_options, @@ -79,6 +93,41 @@ def write_delta_table( ) +def get_delta_tables(pipeline: Pipeline, *tables: str) -> Dict[str, DeltaTable]: + """Returns Delta tables in `pipeline.default_schema` as `deltalake.DeltaTable` objects. + + Returned object is a dictionary with table names as keys and `DeltaTable` objects as values. + Optionally filters dictionary by table names specified as `*tables*`. + Raises ValueError if table name specified as `*tables` is not found. + """ + from dlt.common.schema.utils import get_table_format + + with pipeline.destination_client() as client: + assert isinstance( + client, FilesystemClient + ), "The `get_delta_tables` function requires a `filesystem` destination." + + schema_delta_tables = [ + t["name"] + for t in pipeline.default_schema.tables.values() + if get_table_format(pipeline.default_schema.tables, t["name"]) == "delta" + ] + if len(tables) > 0: + invalid_tables = set(tables) - set(schema_delta_tables) + if len(invalid_tables) > 0: + raise ValueError( + "Schema does not contain Delta tables with these names: " + f"{', '.join(invalid_tables)}." + ) + schema_delta_tables = [t for t in schema_delta_tables if t in tables] + table_dirs = client.get_table_dirs(schema_delta_tables, remote=True) + storage_options = _deltalake_storage_options(client.config) + return { + name: DeltaTable(_dir, storage_options=storage_options) + for name, _dir in zip(schema_delta_tables, table_dirs) + } + + def _deltalake_storage_options(config: FilesystemConfiguration) -> Dict[str, str]: """Returns dict that can be passed as `storage_options` in `deltalake` library.""" creds = {} diff --git a/dlt/common/runtime/collector.py b/dlt/common/runtime/collector.py index 95117b70cc..be5453cdd3 100644 --- a/dlt/common/runtime/collector.py +++ b/dlt/common/runtime/collector.py @@ -170,6 +170,7 @@ def update( total=total, ) self.messages[counter_key] = None + self.last_log_time = None self.counters[counter_key] += inc if message is not None: diff --git a/dlt/common/runtime/signals.py b/dlt/common/runtime/signals.py index 8d1cb3803e..a8fa70936e 100644 --- a/dlt/common/runtime/signals.py +++ b/dlt/common/runtime/signals.py @@ -32,6 +32,11 @@ def raise_if_signalled() -> None: raise SignalReceivedException(_received_signal) +def signal_received() -> bool: + """check if a signal was received""" + return True if _received_signal else False + + def sleep(sleep_seconds: float) -> None: """A signal-aware version of sleep function. Will raise SignalReceivedException if signal was received during sleep period.""" # do not allow sleeping if signal was received diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 2f016577ce..2e75b4b3a1 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -8,6 +8,7 @@ TSchemaEvolutionMode, ) from dlt.common.normalizers.naming import NamingConvention +from dlt.common.schema.typing import TColumnSchema, TColumnSchemaBase class SchemaException(DltException): @@ -231,3 +232,29 @@ def __init__( class ColumnNameConflictException(SchemaException): pass + + +class UnboundColumnException(SchemaException): + def __init__(self, schema_name: str, table_name: str, column: TColumnSchemaBase) -> None: + self.column = column + self.schema_name = schema_name + self.table_name = table_name + nullable: bool = column.get("nullable", False) + key_type: str = "" + if column.get("merge_key"): + key_type = "merge key" + elif column.get("primary_key"): + key_type = "primary key" + + msg = ( + f"The column {column['name']} in table {table_name} did not receive any data during" + " this load. " + ) + if key_type or not nullable: + msg += f"It is marked as non-nullable{' '+key_type} and it must have values. " + + msg += ( + "This can happen if you specify the column manually, for example using the 'merge_key'," + " 'primary_key' or 'columns' argument but it does not exist in the data." + ) + super().__init__(schema_name, msg) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 39db0e42ae..da9e581637 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -568,7 +568,7 @@ def data_tables( def data_table_names( self, seen_data_only: bool = False, include_incomplete: bool = False ) -> List[str]: - """Returns list of table table names. Excludes dlt table names.""" + """Returns list of table names. Excludes dlt table names.""" return [ t["name"] for t in self.data_tables( diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 9a4dd51d4b..284c55caac 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -187,6 +187,7 @@ class TMergeDispositionDict(TWriteDispositionDict, total=False): strategy: Optional[TLoaderMergeStrategy] validity_column_names: Optional[List[str]] active_record_timestamp: Optional[TAnyDateTime] + boundary_timestamp: Optional[TAnyDateTime] row_version_column_name: Optional[str] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index aa5de9611c..8b87a7e5fe 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -352,6 +352,21 @@ def is_complete_column(col: TColumnSchemaBase) -> bool: return bool(col.get("name")) and bool(col.get("data_type")) +def is_nullable_column(col: TColumnSchemaBase) -> bool: + """Returns true if column is nullable""" + return col.get("nullable", True) + + +def find_incomplete_columns( + tables: List[TTableSchema], +) -> Iterable[Tuple[str, TColumnSchemaBase, bool]]: + """Yields (table_name, column, nullable) for all incomplete columns in `tables`""" + for table in tables: + for col in table["columns"].values(): + if not is_complete_column(col): + yield table["name"], col, is_nullable_column(col) + + def compare_complete_columns(a: TColumnSchema, b: TColumnSchema) -> bool: """Compares mandatory fields of complete columns""" assert is_complete_column(a) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 4d84094427..b0ed93f734 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -723,19 +723,12 @@ def build_job_file_name( @staticmethod def is_package_partially_loaded(package_info: LoadPackageInfo) -> bool: - """Checks if package is partially loaded - has jobs that are not new.""" - if package_info.state == "normalized": - pending_jobs: Sequence[TJobState] = ["new_jobs"] - else: - pending_jobs = ["completed_jobs", "failed_jobs"] - return ( - sum( - len(package_info.jobs[job_state]) - for job_state in WORKING_FOLDERS - if job_state not in pending_jobs - ) - > 0 - ) + """Checks if package is partially loaded - has jobs that are completed and jobs that are not.""" + all_jobs_count = sum(len(package_info.jobs[job_state]) for job_state in WORKING_FOLDERS) + completed_jobs_count = len(package_info.jobs["completed_jobs"]) + if completed_jobs_count and all_jobs_count - completed_jobs_count > 0: + return True + return False @staticmethod def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: diff --git a/dlt/common/typing.py b/dlt/common/typing.py index fdd27161f7..ee11a77965 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -106,7 +106,7 @@ VARIANT_FIELD_FORMAT = "v_%s" TFileOrPath = Union[str, PathLike, IO[Any]] TSortOrder = Literal["asc", "desc"] -TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv"] +TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference"] """known loader file formats""" diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 4225d63fe7..371c1bae22 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -45,10 +45,10 @@ ) from dlt.common.schema.utils import table_schema_has_type from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import LoadJob, DoNothingFollowupJob, DoNothingJob -from dlt.common.destination.reference import NewLoadJob, SupportsStagingDestination +from dlt.common.destination.reference import LoadJob +from dlt.common.destination.reference import FollowupJob, SupportsStagingDestination from dlt.common.data_writers.escape import escape_hive_identifier -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob +from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob from dlt.destinations.typing import DBApi, DBTransaction from dlt.destinations.exceptions import ( @@ -65,6 +65,7 @@ ) from dlt.destinations.typing import DBApiCursor from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration from dlt.destinations.type_mapping import TypeMapper from dlt.destinations import path_utils @@ -160,7 +161,7 @@ def __init__(self) -> None: DLTAthenaFormatter._INSTANCE = self -class AthenaMergeJob(SqlMergeJob): +class AthenaMergeJob(SqlMergeFollowupJob): @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: # reproducible name so we know which table to drop @@ -468,7 +469,9 @@ def _get_table_update_sql( LOCATION '{location}';""") return sql - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" if table_schema_has_type(table, "time"): raise LoadJobTerminalException( @@ -476,32 +479,38 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> "Athena cannot load TIME columns from parquet tables. Please convert" " `datetime.time` objects in your data to `str` or `datetime.datetime`.", ) - job = super().start_file_load(table, file_path, load_id) + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = ( - DoNothingFollowupJob(file_path) + FinalizedLoadJobWithFollowupJobs(file_path) if self._is_iceberg_table(self.prepare_load_table(table["name"])) - else DoNothingJob(file_path) + else FinalizedLoadJob(file_path) ) return job - def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_append_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJob]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False}) + SqlStagingCopyFollowupJob.from_table_chain( + table_chain, self.sql_client, {"replace": False} + ) ] return super()._create_append_followup_jobs(table_chain) def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) + SqlStagingCopyFollowupJob.from_table_chain( + table_chain, self.sql_client, {"replace": True} + ) ] return super()._create_replace_followup_jobs(table_chain) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [AthenaMergeJob.from_table_chain(table_chain, self.sql_client)] def _is_iceberg_table(self, table: TTableSchema) -> bool: diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 095974d186..c6bf2e7654 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -1,6 +1,7 @@ import functools import os from pathlib import Path +import time from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, cast import google.cloud.bigquery as bigquery # noqa: I250 @@ -10,14 +11,16 @@ from google.cloud.bigquery.retry import _RETRYABLE_REASONS from dlt.common import logger +from dlt.common.runtime.signals import sleep from dlt.common.json import json from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( + HasFollowupJobs, FollowupJob, - NewLoadJob, TLoadJobState, - LoadJob, + RunnableLoadJob, SupportsStagingDestination, + LoadJob, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat @@ -33,7 +36,7 @@ DatabaseUndefinedRelation, DestinationSchemaWillNotUpdate, DestinationTerminalException, - LoadJobNotExistsException, + DatabaseTerminalException, LoadJobTerminalException, ) from dlt.destinations.impl.bigquery.bigquery_adapter import ( @@ -48,8 +51,8 @@ from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.utils import parse_db_data_type_str_with_precision @@ -104,60 +107,95 @@ def from_db_type( return super().from_db_type(*parse_db_data_type_str_with_precision(db_type)) -class BigQueryLoadJob(LoadJob, FollowupJob): +class BigQueryLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, - file_name: str, - bq_load_job: bigquery.LoadJob, + file_path: str, http_timeout: float, retry_deadline: float, ) -> None: - self.bq_load_job = bq_load_job - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) - self.http_timeout = http_timeout - super().__init__(file_name) - - def state(self) -> TLoadJobState: - if not self.bq_load_job.done(retry=self.default_retry, timeout=self.http_timeout): - return "running" - if self.bq_load_job.output_rows is not None and self.bq_load_job.error_result is None: - return "completed" - reason = self.bq_load_job.error_result.get("reason") - if reason in BQ_TERMINAL_REASONS: - # the job permanently failed for the reason above - return "failed" - elif reason in ["internalError"]: - logger.warning( - f"Got reason {reason} for job {self.file_name}, job considered still" - f" running. ({self.bq_load_job.error_result})" - ) - # the status of the job couldn't be obtained, job still running. - return "running" - else: - # retry on all other reasons, including `backendError` which requires retry when the job is done. - return "retry" - - def bigquery_job_id(self) -> str: - return BigQueryLoadJob.get_job_id_from_file_path(super().file_name()) + super().__init__(file_path) + self._default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) + self._http_timeout = http_timeout + self._job_client: "BigQueryClient" = None + self._bq_load_job: bigquery.LoadJob = None + # vars only used for testing + self._created_job = False + self._resumed_job = False + + def run(self) -> None: + # start the job (or retrieve in case it already exists) + try: + self._bq_load_job = self._job_client._create_load_job(self._load_table, self._file_path) + self._created_job = True + except api_core_exceptions.GoogleAPICallError as gace: + reason = BigQuerySqlClient._get_reason_from_errors(gace) + if reason == "notFound": + # google.api_core.exceptions.NotFound: 404 – table not found + raise DatabaseUndefinedRelation(gace) from gace + elif ( + reason == "duplicate" + ): # google.api_core.exceptions.Conflict: 409 PUT – already exists + self._bq_load_job = self._job_client._retrieve_load_job(self._file_path) + self._resumed_job = True + logger.info( + f"Found existing bigquery job for job {self._file_name}, will resume job." + ) + elif reason in BQ_TERMINAL_REASONS: + # google.api_core.exceptions.BadRequest - will not be processed ie bad job name + raise LoadJobTerminalException( + self._file_path, f"The server reason was: {reason}" + ) from gace + else: + raise DatabaseTransientException(gace) from gace + + # we loop on the job thread until we detect a status change + while True: + sleep(1) + # not done yet + if not self._bq_load_job.done(retry=self._default_retry, timeout=self._http_timeout): + continue + # done, break loop and go to completed state + if self._bq_load_job.output_rows is not None and self._bq_load_job.error_result is None: + break + reason = self._bq_load_job.error_result.get("reason") + if reason in BQ_TERMINAL_REASONS: + # the job permanently failed for the reason above + raise DatabaseTerminalException( + Exception( + f"Bigquery Load Job failed, reason reported from bigquery: '{reason}'" + ) + ) + elif reason in ["internalError"]: + logger.warning( + f"Got reason {reason} for job {self._file_name}, job considered still" + f" running. ({self._bq_load_job.error_result})" + ) + continue + else: + raise DatabaseTransientException( + Exception( + f"Bigquery Job needs to be retried, reason reported from bigquer '{reason}'" + ) + ) def exception(self) -> str: - exception: str = json.dumps( + return json.dumps( { - "error_result": self.bq_load_job.error_result, - "errors": self.bq_load_job.errors, - "job_start": self.bq_load_job.started, - "job_end": self.bq_load_job.ended, - "job_id": self.bq_load_job.job_id, + "error_result": self._bq_load_job.error_result, + "errors": self._bq_load_job.errors, + "job_start": self._bq_load_job.started, + "job_end": self._bq_load_job.ended, + "job_id": self._bq_load_job.job_id, } ) - return exception @staticmethod def get_job_id_from_file_path(file_path: str) -> str: return Path(file_path).name.replace(".", "_") -class BigQueryMergeJob(SqlMergeJob): +class BigQueryMergeJob(SqlMergeFollowupJob): @classmethod def gen_key_table_clauses( cls, @@ -187,6 +225,7 @@ def __init__( config.credentials, capabilities, config.get_location(), + config.project_id, config.http_timeout, config.retry_deadline, ) @@ -195,97 +234,46 @@ def __init__( self.sql_client: BigQuerySqlClient = sql_client # type: ignore self.type_mapper = BigQueryTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)] - def restore_file_load(self, file_path: str) -> LoadJob: - """Returns a completed SqlLoadJob or restored BigQueryLoadJob - - See base class for details on SqlLoadJob. - BigQueryLoadJob is restored with a job ID derived from `file_path`. - - Args: - file_path (str): a path to a job file. - - Returns: - LoadJob: completed SqlLoadJob or restored BigQueryLoadJob - """ - job = super().restore_file_load(file_path) - if not job: - try: - job = BigQueryLoadJob( - FileStorage.get_file_name_from_file_path(file_path), - self._retrieve_load_job(file_path), - self.config.http_timeout, - self.config.retry_deadline, - ) - except api_core_exceptions.GoogleAPICallError as gace: - reason = BigQuerySqlClient._get_reason_from_errors(gace) - if reason == "notFound": - raise LoadJobNotExistsException(file_path) from gace - elif reason in BQ_TERMINAL_REASONS: - raise LoadJobTerminalException( - file_path, f"The server reason was: {reason}" - ) from gace - else: - raise DatabaseTransientException(gace) from gace - return job - - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id) if not job: insert_api = table.get("x-insert-api", "default") - try: - if insert_api == "streaming": - if table["write_disposition"] != "append": - raise DestinationTerminalException( - "BigQuery streaming insert can only be used with `append`" - " write_disposition, while the given resource has" - f" `{table['write_disposition']}`." - ) - if file_path.endswith(".jsonl"): - job_cls = DestinationJsonlLoadJob - elif file_path.endswith(".parquet"): - job_cls = DestinationParquetLoadJob # type: ignore - else: - raise ValueError( - f"Unsupported file type for BigQuery streaming inserts: {file_path}" - ) - - job = job_cls( - table, - file_path, - self.config, # type: ignore - self.schema, - destination_state(), - functools.partial(_streaming_load, self.sql_client), - [], + if insert_api == "streaming": + if table["write_disposition"] != "append": + raise DestinationTerminalException( + "BigQuery streaming insert can only be used with `append`" + " write_disposition, while the given resource has" + f" `{table['write_disposition']}`." ) + if file_path.endswith(".jsonl"): + job_cls = DestinationJsonlLoadJob + elif file_path.endswith(".parquet"): + job_cls = DestinationParquetLoadJob # type: ignore else: - job = BigQueryLoadJob( - FileStorage.get_file_name_from_file_path(file_path), - self._create_load_job(table, file_path), - self.config.http_timeout, - self.config.retry_deadline, + raise ValueError( + f"Unsupported file type for BigQuery streaming inserts: {file_path}" ) - except api_core_exceptions.GoogleAPICallError as gace: - reason = BigQuerySqlClient._get_reason_from_errors(gace) - if reason == "notFound": - # google.api_core.exceptions.NotFound: 404 – table not found - raise DatabaseUndefinedRelation(gace) from gace - elif ( - reason == "duplicate" - ): # google.api_core.exceptions.Conflict: 409 PUT – already exists - return self.restore_file_load(file_path) - elif reason in BQ_TERMINAL_REASONS: - # google.api_core.exceptions.BadRequest - will not be processed ie bad job name - raise LoadJobTerminalException( - file_path, f"The server reason was: {reason}" - ) from gace - else: - raise DatabaseTransientException(gace) from gace + job = job_cls( + file_path, + self.config, # type: ignore + destination_state(), + _streaming_load, # type: ignore + [], + callable_requires_job_client_args=True, + ) + else: + job = BigQueryLoadJob( + file_path, + self.config.http_timeout, + self.config.retry_deadline, + ) return job def _get_table_update_sql( @@ -445,8 +433,8 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load # determine whether we load from local or uri bucket_path = None ext: str = os.path.splitext(file_path)[1][1:] - if NewReferenceJob.is_reference_job(file_path): - bucket_path = NewReferenceJob.resolve_reference(file_path) + if ReferenceFollowupJob.is_reference_job(file_path): + bucket_path = ReferenceFollowupJob.resolve_reference(file_path) ext = os.path.splitext(bucket_path)[1][1:] # Select a correct source format @@ -515,7 +503,7 @@ def _should_autodetect_schema(self, table_name: str) -> bool: def _streaming_load( - sql_client: SqlClientBase[BigQueryClient], items: List[Dict[Any, Any]], table: Dict[str, Any] + items: List[Dict[Any, Any]], table: Dict[str, Any], job_client: BigQueryClient ) -> None: """ Upload the given items into BigQuery table, using streaming API. @@ -542,6 +530,8 @@ def _should_retry(exc: api_core_exceptions.GoogleAPICallError) -> bool: reason = exc.errors[0]["reason"] return reason in _RETRYABLE_REASONS + sql_client = job_client.sql_client + full_name = sql_client.make_qualified_table_name(table["name"], escape=False) bq_client = sql_client._client diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index 47cc997a4a..3d71b0c8ea 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -1,6 +1,6 @@ import dataclasses import warnings -from typing import ClassVar, List, Final +from typing import ClassVar, List, Final, Optional from dlt.common.configuration import configspec from dlt.common.configuration.specs import GcpServiceAccountCredentials @@ -14,6 +14,8 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): destination_type: Final[str] = dataclasses.field(default="bigquery", init=False, repr=False, compare=False) # type: ignore credentials: GcpServiceAccountCredentials = None location: str = "US" + project_id: Optional[str] = None + """Note, that this is BigQuery project_id which could be different from credentials.project_id""" has_case_sensitive_identifiers: bool = True """If True then dlt expects to load data into case sensitive dataset""" should_set_case_sensitivity_on_new_dataset: bool = False diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index dfc4094e7b..c56742f1ff 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -82,14 +82,16 @@ def __init__( credentials: GcpServiceAccountCredentialsWithoutDefaults, capabilities: DestinationCapabilitiesContext, location: str = "US", + project_id: Optional[str] = None, http_timeout: float = 15.0, retry_deadline: float = 60.0, ) -> None: self._client: bigquery.Client = None self.credentials: GcpServiceAccountCredentialsWithoutDefaults = credentials self.location = location + self.project_id = project_id or self.credentials.project_id self.http_timeout = http_timeout - super().__init__(credentials.project_id, dataset_name, staging_dataset_name, capabilities) + super().__init__(self.project_id, dataset_name, staging_dataset_name, capabilities) self._default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) self._default_query = bigquery.QueryJobConfig( @@ -100,7 +102,7 @@ def __init__( @raise_open_connection_error def open_connection(self) -> bigquery.Client: self._client = bigquery.Client( - self.credentials.project_id, + self.project_id, credentials=self.credentials.to_native_credentials(), location=self.location, ) @@ -240,7 +242,7 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB conn.close() def catalog_name(self, escape: bool = True) -> Optional[str]: - project_id = self.capabilities.casefold_identifier(self.credentials.project_id) + project_id = self.capabilities.casefold_identifier(self.project_id) if escape: project_id = self.capabilities.escape_identifier(project_id) return project_id diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 148fca3f1e..5bd34e0e0d 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -18,9 +18,10 @@ from dlt.common.destination.reference import ( SupportsStagingDestination, TLoadJobState, + HasFollowupJobs, + RunnableLoadJob, FollowupJob, LoadJob, - NewLoadJob, ) from dlt.common.schema import Schema, TColumnSchema from dlt.common.schema.typing import ( @@ -51,8 +52,8 @@ SqlJobClientBase, SqlJobClientWithStaging, ) -from dlt.destinations.job_impl import NewReferenceJob, EmptyLoadJob -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.job_impl import ReferenceFollowupJob, FinalizedLoadJobWithFollowupJobs +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper @@ -123,22 +124,25 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class ClickHouseLoadJob(LoadJob, FollowupJob): +class ClickHouseLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, file_path: str, - table_name: str, - client: ClickHouseSqlClient, staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) + super().__init__(file_path) + self._job_client: "ClickHouseClient" = None + self._staging_credentials = staging_credentials + + def run(self) -> None: + client = self._job_client.sql_client - qualified_table_name = client.make_qualified_table_name(table_name) + qualified_table_name = client.make_qualified_table_name(self.load_table_name) bucket_path = None + file_name = self._file_name - if NewReferenceJob.is_reference_job(file_path): - bucket_path = NewReferenceJob.resolve_reference(file_path) + if ReferenceFollowupJob.is_reference_job(self._file_path): + bucket_path = ReferenceFollowupJob.resolve_reference(self._file_path) file_name = FileStorage.get_file_name_from_file_path(bucket_path) bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme @@ -152,7 +156,7 @@ def __init__( if not bucket_path: # Local filesystem. if ext == "jsonl": - compression = "gz" if FileStorage.is_gzipped(file_path) else "none" + compression = "gz" if FileStorage.is_gzipped(self._file_path) else "none" try: with clickhouse_connect.create_client( host=client.credentials.host, @@ -165,7 +169,7 @@ def __init__( insert_file( clickhouse_connect_client, qualified_table_name, - file_path, + self._file_path, fmt=clickhouse_format, settings={ "allow_experimental_lightweight_delete": 1, @@ -176,7 +180,7 @@ def __init__( ) except clickhouse_connect.driver.exceptions.Error as e: raise LoadJobTerminalException( - file_path, + self._file_path, f"ClickHouse connection failed due to {e}.", ) from e return @@ -188,9 +192,9 @@ def __init__( compression = "none" if config.get("data_writer.disable_compression") else "gz" if bucket_scheme in ("s3", "gs", "gcs"): - if not isinstance(staging_credentials, AwsCredentialsWithoutDefaults): + if not isinstance(self._staging_credentials, AwsCredentialsWithoutDefaults): raise LoadJobTerminalException( - file_path, + self._file_path, dedent( """ Google Cloud Storage buckets must be configured using the S3 compatible access pattern. @@ -201,10 +205,10 @@ def __init__( ) bucket_http_url = convert_storage_to_http_scheme( - bucket_url, endpoint=staging_credentials.endpoint_url + bucket_url, endpoint=self._staging_credentials.endpoint_url ) - access_key_id = staging_credentials.aws_access_key_id - secret_access_key = staging_credentials.aws_secret_access_key + access_key_id = self._staging_credentials.aws_access_key_id + secret_access_key = self._staging_credentials.aws_secret_access_key auth = "NOSIGN" if access_key_id and secret_access_key: auth = f"'{access_key_id}','{secret_access_key}'" @@ -214,24 +218,22 @@ def __init__( ) elif bucket_scheme in ("az", "abfs"): - if not isinstance(staging_credentials, AzureCredentialsWithoutDefaults): + if not isinstance(self._staging_credentials, AzureCredentialsWithoutDefaults): raise LoadJobTerminalException( - file_path, + self._file_path, "Unsigned Azure Blob Storage access from ClickHouse isn't supported as yet.", ) # Authenticated access. - account_name = staging_credentials.azure_storage_account_name - storage_account_url = ( - f"https://{staging_credentials.azure_storage_account_name}.blob.core.windows.net" - ) - account_key = staging_credentials.azure_storage_account_key + account_name = self._staging_credentials.azure_storage_account_name + storage_account_url = f"https://{self._staging_credentials.azure_storage_account_name}.blob.core.windows.net" + account_key = self._staging_credentials.azure_storage_account_key # build table func table_function = f"azureBlobStorage('{storage_account_url}','{bucket_url.netloc}','{bucket_url.path}','{account_name}','{account_key}','{clickhouse_format}','{compression}')" else: raise LoadJobTerminalException( - file_path, + self._file_path, f"ClickHouse loader does not support '{bucket_scheme}' filesystem.", ) @@ -239,14 +241,8 @@ def __init__( with client.begin_transaction(): client.execute_sql(statement) - def state(self) -> TLoadJobState: - return "completed" - def exception(self) -> str: - raise NotImplementedError() - - -class ClickHouseMergeJob(SqlMergeJob): +class ClickHouseMergeJob(SqlMergeFollowupJob): @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: return f"CREATE TABLE {temp_table_name} ENGINE = Memory AS {select_sql};" @@ -292,7 +288,7 @@ def __init__( self.active_hints = deepcopy(HINT_TO_CLICKHOUSE_ATTR) self.type_mapper = ClickHouseTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [ClickHouseMergeJob.from_table_chain(table_chain, self.sql_client)] def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: @@ -319,11 +315,11 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non .strip() ) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - return super().start_file_load(table, file_path, load_id) or ClickHouseLoadJob( + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + return super().create_load_job(table, file_path, load_id, restore) or ClickHouseLoadJob( file_path, - table["name"], - self.sql_client, staging_credentials=( self.config.staging_config.credentials if self.config.staging_config else None ), @@ -374,6 +370,3 @@ def _from_db_type( self, ch_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: return self.type_mapper.from_db_type(ch_t, precision, scale) - - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index fbe7fa4c6b..0a203c21b6 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -4,12 +4,13 @@ from dlt import config from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( + HasFollowupJobs, FollowupJob, - NewLoadJob, TLoadJobState, - LoadJob, + RunnableLoadJob, CredentialsConfiguration, SupportsStagingDestination, + LoadJob, ) from dlt.common.configuration.specs import ( AwsCredentialsWithoutDefaults, @@ -25,12 +26,12 @@ from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient -from dlt.destinations.sql_jobs import SqlMergeJob -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper @@ -103,30 +104,31 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class DatabricksLoadJob(LoadJob, FollowupJob): +class DatabricksLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, - table: TTableSchema, file_path: str, - table_name: str, - load_id: str, - client: DatabricksSqlClient, staging_config: FilesystemConfiguration, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) - staging_credentials = staging_config.credentials + super().__init__(file_path) + self._staging_config = staging_config + self._job_client: "DatabricksClient" = None - qualified_table_name = client.make_qualified_table_name(table_name) + def run(self) -> None: + self._sql_client = self._job_client.sql_client + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) + staging_credentials = self._staging_config.credentials # extract and prepare some vars bucket_path = orig_bucket_path = ( - NewReferenceJob.resolve_reference(file_path) - if NewReferenceJob.is_reference_job(file_path) + ReferenceFollowupJob.resolve_reference(self._file_path) + if ReferenceFollowupJob.is_reference_job(self._file_path) else "" ) file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + FileStorage.get_file_name_from_file_path(bucket_path) + if bucket_path + else self._file_name ) from_clause = "" credentials_clause = "" @@ -166,13 +168,13 @@ def __init__( from_clause = f"FROM '{bucket_path}'" else: raise LoadJobTerminalException( - file_path, + self._file_path, f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and" " azure buckets are supported", ) else: raise LoadJobTerminalException( - file_path, + self._file_path, "Cannot load from local file. Databricks does not support loading from local files." " Configure staging with an s3 or azure storage bucket.", ) @@ -183,32 +185,32 @@ def __init__( elif file_name.endswith(".jsonl"): if not config.get("data_writer.disable_compression"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader does not support gzip compressed JSON files. Please disable" " compression in the data writer configuration:" " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) - if table_schema_has_type(table, "decimal"): + if table_schema_has_type(self._load_table, "decimal"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load DECIMAL type columns from json files. Switch to" " parquet format to load decimals.", ) - if table_schema_has_type(table, "binary"): + if table_schema_has_type(self._load_table, "binary"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load BINARY type columns from json files. Switch to" " parquet format to load byte values.", ) - if table_schema_has_type(table, "complex"): + if table_schema_has_type(self._load_table, "complex"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load complex columns (lists and dicts) from json" " files. Switch to parquet format to load complex types.", ) - if table_schema_has_type(table, "date"): + if table_schema_has_type(self._load_table, "date"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load DATE type columns from json files. Switch to" " parquet format to load dates.", ) @@ -216,7 +218,7 @@ def __init__( source_format = "JSON" format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" # Databricks fails when trying to load empty json files, so we have to check the file size - fs, _ = fsspec_from_config(staging_config) + fs, _ = fsspec_from_config(self._staging_config) file_size = fs.size(orig_bucket_path) if file_size == 0: # Empty file, do nothing return @@ -227,16 +229,10 @@ def __init__( FILEFORMAT = {source_format} {format_options_clause} """ - client.execute_sql(statement) + self._sql_client.execute_sql(statement) - def state(self) -> TLoadJobState: - return "completed" - def exception(self) -> str: - raise NotImplementedError() - - -class DatabricksMergeJob(SqlMergeJob): +class DatabricksMergeJob(SqlMergeFollowupJob): @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: return f"CREATE TEMPORARY VIEW {temp_table_name} AS {select_sql};" @@ -271,24 +267,19 @@ def __init__( self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment] self.type_mapper = DatabricksTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = DatabricksLoadJob( - table, file_path, - table["name"], - load_id, - self.sql_client, staging_config=cast(FilesystemConfiguration, self.config.staging_config), ) return job - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index 976dfa4fb5..0c4da81471 100644 --- a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -2,6 +2,7 @@ from types import TracebackType from typing import ClassVar, Optional, Type, Iterable, cast, List +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob from dlt.common.destination.reference import LoadJob from dlt.common.typing import AnyFun from dlt.common.storages.load_package import destination_state @@ -10,12 +11,10 @@ from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - LoadJob, - DoNothingJob, JobClientBase, + LoadJob, ) -from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration from dlt.destinations.job_impl import ( DestinationJsonlLoadJob, @@ -56,44 +55,49 @@ def update_stored_schema( ) -> Optional[TSchemaTables]: return super().update_stored_schema(only_tables, expected_update) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: # skip internal tables and remove columns from schema if so configured - skipped_columns: List[str] = [] if self.config.skip_dlt_columns_and_tables: if table["name"].startswith(self.schema._dlt_tables_prefix): - return DoNothingJob(file_path) - table = deepcopy(table) - for column in list(table["columns"].keys()): + return FinalizedLoadJob(file_path) + + skipped_columns: List[str] = [] + if self.config.skip_dlt_columns_and_tables: + for column in list(self.schema.tables[table["name"]]["columns"].keys()): if column.startswith(self.schema._dlt_tables_prefix): - table["columns"].pop(column) skipped_columns.append(column) # save our state in destination name scope load_state = destination_state() if file_path.endswith("parquet"): return DestinationParquetLoadJob( - table, file_path, self.config, - self.schema, load_state, self.destination_callable, skipped_columns, ) if file_path.endswith("jsonl"): return DestinationJsonlLoadJob( - table, file_path, self.config, - self.schema, load_state, self.destination_callable, skipped_columns, ) return None - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") + def prepare_load_table( + self, table_name: str, prepare_for_staging: bool = False + ) -> TTableSchema: + table = super().prepare_load_table(table_name, prepare_for_staging) + if self.config.skip_dlt_columns_and_tables: + for column in list(table["columns"].keys()): + if column.startswith(self.schema._dlt_tables_prefix): + table["columns"].pop(column) + return table def complete_load(self, load_id: str) -> None: ... diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index bea18cdea5..3611665f6c 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -3,11 +3,12 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - FollowupJob, + HasFollowupJobs, TLoadJobState, - LoadJob, + RunnableLoadJob, SupportsStagingDestination, - NewLoadJob, + FollowupJob, + LoadJob, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat, TColumnSchemaBase @@ -17,9 +18,9 @@ from dlt.destinations.impl.dremio.configuration import DremioClientConfiguration from dlt.destinations.impl.dremio.sql_client import DremioSqlClient from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import EmptyLoadJob -from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs +from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.sql_client import SqlClientBase @@ -69,7 +70,7 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class DremioMergeJob(SqlMergeJob): +class DremioMergeJob(SqlMergeFollowupJob): @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: return sql_client.make_qualified_table_name(f"_temp_{name_prefix}_{uniq_id()}") @@ -83,23 +84,25 @@ def default_order_by(cls) -> str: return "NULL" -class DremioLoadJob(LoadJob, FollowupJob): +class DremioLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, file_path: str, - table_name: str, - client: DremioSqlClient, stage_name: Optional[str] = None, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) + super().__init__(file_path) + self._stage_name = stage_name + self._job_client: "DremioClient" = None - qualified_table_name = client.make_qualified_table_name(table_name) + def run(self) -> None: + self._sql_client = self._job_client.sql_client + + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) # extract and prepare some vars bucket_path = ( - NewReferenceJob.resolve_reference(file_path) - if NewReferenceJob.is_reference_job(file_path) + ReferenceFollowupJob.resolve_reference(self._file_path) + if ReferenceFollowupJob.is_reference_job(self._file_path) else "" ) @@ -107,33 +110,29 @@ def __init__( raise RuntimeError("Could not resolve bucket path.") file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + FileStorage.get_file_name_from_file_path(bucket_path) + if bucket_path + else self._file_name ) bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - if bucket_scheme == "s3" and stage_name: + if bucket_scheme == "s3" and self._stage_name: from_clause = ( - f"FROM '@{stage_name}/{bucket_url.hostname}/{bucket_url.path.lstrip('/')}'" + f"FROM '@{self._stage_name}/{bucket_url.hostname}/{bucket_url.path.lstrip('/')}'" ) else: raise LoadJobTerminalException( - file_path, "Only s3 staging currently supported in Dremio destination" + self._file_path, "Only s3 staging currently supported in Dremio destination" ) source_format = file_name.split(".")[-1] - client.execute_sql(f"""COPY INTO {qualified_table_name} + self._sql_client.execute_sql(f"""COPY INTO {qualified_table_name} {from_clause} FILE_FORMAT '{source_format}' """) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class DremioClient(SqlJobClientWithStaging, SupportsStagingDestination): def __init__( @@ -153,21 +152,18 @@ def __init__( self.sql_client: DremioSqlClient = sql_client # type: ignore self.type_mapper = DremioTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = DremioLoadJob( file_path=file_path, - table_name=table["name"], - client=self.sql_client, stage_name=self.config.staging_data_source, ) return job - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - def _get_table_update_sql( self, table_name: str, @@ -205,7 +201,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 10d4fc13de..3d5905ff40 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -5,7 +5,7 @@ from dlt.common.data_types import TDataType from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState +from dlt.common.destination.reference import RunnableLoadJob, HasFollowupJobs, LoadJob from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import maybe_context @@ -19,10 +19,6 @@ HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} -# duckdb cannot load PARQUET to the same table in parallel. so serialize it per table -PARQUET_TABLE_LOCK = threading.Lock() -TABLES_LOCKS: Dict[str, threading.Lock] = {} - class DuckDbTypeMapper(TypeMapper): sct_to_unbound_dbt = { @@ -113,40 +109,30 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class DuckDbCopyJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - - qualified_table_name = sql_client.make_qualified_table_name(table_name) - if file_path.endswith("parquet"): - source_format = "PARQUET" - options = "" - # lock when creating a new lock - with PARQUET_TABLE_LOCK: - # create or get lock per table name - lock: threading.Lock = TABLES_LOCKS.setdefault( - qualified_table_name, threading.Lock() - ) - elif file_path.endswith("jsonl"): - # NOTE: loading JSON does not work in practice on duckdb: the missing keys fail the load instead of being interpreted as NULL - source_format = "JSON" # newline delimited, compression auto - options = ", COMPRESSION GZIP" if FileStorage.is_gzipped(file_path) else "" - lock = None - else: - raise ValueError(file_path) +class DuckDbCopyJob(RunnableLoadJob, HasFollowupJobs): + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: "DuckDbClient" = None - with maybe_context(lock): - with sql_client.begin_transaction(): - sql_client.execute_sql( - f"COPY {qualified_table_name} FROM '{file_path}' ( FORMAT" - f" {source_format} {options});" - ) + def run(self) -> None: + self._sql_client = self._job_client.sql_client - def state(self) -> TLoadJobState: - return "completed" + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) + if self._file_path.endswith("parquet"): + source_format = "read_parquet" + options = ", union_by_name=true" + elif self._file_path.endswith("jsonl"): + # NOTE: loading JSON does not work in practice on duckdb: the missing keys fail the load instead of being interpreted as NULL + source_format = "read_json" # newline delimited, compression auto + options = ", COMPRESSION=GZIP" if FileStorage.is_gzipped(self._file_path) else "" + else: + raise ValueError(self._file_path) - def exception(self) -> str: - raise NotImplementedError() + with self._sql_client.begin_transaction(): + self._sql_client.execute_sql( + f"INSERT INTO {qualified_table_name} BY NAME SELECT * FROM" + f" {source_format}('{self._file_path}' {options});" + ) class DuckDbClient(InsertValuesJobClient): @@ -168,10 +154,12 @@ def __init__( self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = DuckDbTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: - job = DuckDbCopyJob(table["name"], file_path, self.sql_client) + job = DuckDbCopyJob(file_path) return job def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index a9fdb1f47d..7bc1d9e943 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -21,13 +21,29 @@ class DummyClientConfiguration(DestinationClientConfiguration): loader_file_format: TLoaderFileFormat = "jsonl" fail_schema_update: bool = False fail_prob: float = 0.0 + """probability of terminal fail""" retry_prob: float = 0.0 + """probability of job retry""" completed_prob: float = 0.0 + """probablibitly of successful job completion""" exception_prob: float = 0.0 - """probability of exception when checking job status""" + """probability of exception transient exception when running job""" timeout: float = 10.0 - fail_in_init: bool = True + """timeout time""" + fail_terminally_in_init: bool = False + """raise terminal exception in job init""" + fail_transiently_in_init: bool = False + """raise transient exception in job init""" + # new jobs workflows create_followup_jobs: bool = False - + """create followup job for individual jobs""" + fail_followup_job_creation: bool = False + """Raise generic exception during followupjob creation""" + fail_table_chain_followup_job_creation: bool = False + """Raise generic exception during tablechain followupjob creation""" + create_followup_table_chain_sql_jobs: bool = False + """create a table chain merge job which is guaranteed to fail""" + create_followup_table_chain_reference_jobs: bool = False + """create table chain jobs which succeed """ credentials: DummyClientCredentials = None diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index c41b7dca61..7d406c969f 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -12,7 +12,8 @@ Iterable, List, ) - +import os +import time from dlt.common.pendulum import pendulum from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.storages import FileStorage @@ -23,79 +24,88 @@ DestinationTransientException, ) from dlt.common.destination.reference import ( + HasFollowupJobs, FollowupJob, - NewLoadJob, SupportsStagingDestination, TLoadJobState, - LoadJob, + RunnableLoadJob, JobClientBase, WithStagingDataset, + LoadJob, ) +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.exceptions import ( LoadJobNotExistsException, LoadJobInvalidStateTransitionException, ) from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob -class LoadDummyBaseJob(LoadJob): +class LoadDummyBaseJob(RunnableLoadJob): def __init__(self, file_name: str, config: DummyClientConfiguration) -> None: + super().__init__(file_name) self.config = copy(config) - self._status: TLoadJobState = "running" - self._exception: str = None self.start_time: float = pendulum.now().timestamp() - super().__init__(file_name) - if config.fail_in_init: - s = self.state() - if s == "failed": - raise DestinationTerminalException(self._exception) - if s == "retry": - raise DestinationTransientException(self._exception) - - def state(self) -> TLoadJobState: - # this should poll the server for a job status, here we simulate various outcomes - if self._status == "running": + + if self.config.fail_terminally_in_init: + raise DestinationTerminalException(self._exception) + if self.config.fail_transiently_in_init: + raise Exception(self._exception) + + def run(self) -> None: + while True: + # simulate generic exception (equals retry) c_r = random.random() if self.config.exception_prob >= c_r: - raise DestinationTransientException("Dummy job status raised exception") + # this will make the job go to a retry state with a generic exception + raise Exception("Dummy job status raised exception") + + # timeout condition (terminal) n = pendulum.now().timestamp() if n - self.start_time > self.config.timeout: - self._status = "failed" - self._exception = "failed due to timeout" - else: - c_r = random.random() - if self.config.completed_prob >= c_r: - self._status = "completed" - else: - c_r = random.random() - if self.config.retry_prob >= c_r: - self._status = "retry" - self._exception = "a random retry occured" - else: - c_r = random.random() - if self.config.fail_prob >= c_r: - self._status = "failed" - self._exception = "a random fail occured" - - return self._status - - def exception(self) -> str: - # this will typically call server for error messages - return self._exception - - def retry(self) -> None: - if self._status != "retry": - raise LoadJobInvalidStateTransitionException(self._status, "retry") - self._status = "retry" - - -class LoadDummyJob(LoadDummyBaseJob, FollowupJob): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + # this will make the the job go to a failed state + raise DestinationTerminalException("failed due to timeout") + + # success + c_r = random.random() + if self.config.completed_prob >= c_r: + # this will make the run function exit and the job go to a completed state + break + + # retry prob + c_r = random.random() + if self.config.retry_prob >= c_r: + # this will make the job go to a retry state + raise DestinationTransientException("a random retry occured") + + # fail prob + c_r = random.random() + if self.config.fail_prob >= c_r: + # this will make the the job go to a failed state + raise DestinationTerminalException("a random fail occured") + + time.sleep(0.1) + + +class DummyFollowupJob(ReferenceFollowupJob): + def __init__( + self, original_file_name: str, remote_paths: List[str], config: DummyClientConfiguration + ) -> None: + self.config = config + if config.fail_followup_job_creation: + raise Exception("Failed to create followup job") + super().__init__(original_file_name=original_file_name, remote_paths=remote_paths) + + +class LoadDummyJob(LoadDummyBaseJob, HasFollowupJobs): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: if self.config.create_followup_jobs and final_state == "completed": - new_job = NewReferenceJob( - file_name=self.file_name(), status="running", remote_path=self._file_name + new_job = DummyFollowupJob( + original_file_name=self.file_name(), + remote_paths=[self._file_name], + config=self.config, ) CREATED_FOLLOWUP_JOBS[new_job.job_id()] = new_job return [new_job] @@ -103,7 +113,9 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: JOBS: Dict[str, LoadDummyBaseJob] = {} -CREATED_FOLLOWUP_JOBS: Dict[str, NewLoadJob] = {} +CREATED_FOLLOWUP_JOBS: Dict[str, FollowupJob] = {} +CREATED_TABLE_CHAIN_FOLLOWUP_JOBS: Dict[str, FollowupJob] = {} +RETRIED_JOBS: Dict[str, LoadDummyBaseJob] = {} class DummyClient(JobClientBase, SupportsStagingDestination, WithStagingDataset): @@ -140,31 +152,41 @@ def update_stored_schema( ) return applied_update - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: job_id = FileStorage.get_file_name_from_file_path(file_path) - file_name = FileStorage.get_file_name_from_file_path(file_path) + if restore and job_id not in JOBS: + raise LoadJobNotExistsException(job_id) # return existing job if already there if job_id not in JOBS: - JOBS[job_id] = self._create_job(file_name) + JOBS[job_id] = self._create_job(file_path) else: job = JOBS[job_id] - if job.state == "retry": - job.retry() + # update config of existing job in case it was changed in tests + job.config = self.config + RETRIED_JOBS[job_id] = job return JOBS[job_id] - def restore_file_load(self, file_path: str) -> LoadJob: - job_id = FileStorage.get_file_name_from_file_path(file_path) - if job_id not in JOBS: - raise LoadJobNotExistsException(job_id) - return JOBS[job_id] - def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" + + # if sql job follow up is configure we schedule a merge job that will always fail + if self.config.fail_table_chain_followup_job_creation: + raise Exception("Failed to create table chain followup job") + if self.config.create_followup_table_chain_sql_jobs: + return [SqlMergeFollowupJob.from_table_chain(table_chain, self)] # type: ignore + if self.config.create_followup_table_chain_reference_jobs: + table_job_paths = [job.file_path for job in completed_table_chain_jobs] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + job = ReferenceFollowupJob(file_name, table_job_paths) + CREATED_TABLE_CHAIN_FOLLOWUP_JOBS[job.job_id()] = job + return [job] return [] def complete_load(self, load_id: str) -> None: @@ -190,7 +212,7 @@ def __exit__( pass def _create_job(self, job_id: str) -> LoadDummyBaseJob: - if NewReferenceJob.is_reference_job(job_id): + if ReferenceFollowupJob.is_reference_job(job_id): return LoadDummyBaseJob(job_id, config=self.config) else: return LoadDummyJob(job_id, config=self.config) diff --git a/dlt/destinations/impl/dummy/factory.py b/dlt/destinations/impl/dummy/factory.py index c2792fc432..8cf0408ec1 100644 --- a/dlt/destinations/impl/dummy/factory.py +++ b/dlt/destinations/impl/dummy/factory.py @@ -60,7 +60,9 @@ def adjust_capabilities( ) -> DestinationCapabilitiesContext: caps = super().adjust_capabilities(caps, config, naming) additional_formats: t.List[TLoaderFileFormat] = ( - ["reference"] if config.create_followup_jobs else [] # type:ignore[list-item] + ["reference"] + if (config.create_followup_jobs or config.create_followup_table_chain_reference_jobs) + else [] ) caps.preferred_loader_file_format = config.loader_file_format caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index 31b61c6cb1..ff3c8a59e1 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -7,6 +7,7 @@ from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations.impl.filesystem.typing import TCurrentDateTime, TExtraPlaceholders +from dlt.common.normalizers.naming.naming import NamingConvention if t.TYPE_CHECKING: from dlt.destinations.impl.filesystem.filesystem import FilesystemClient @@ -28,7 +29,7 @@ class filesystem(Destination[FilesystemDestinationClientConfiguration, "Filesyst spec = FilesystemDestinationClientConfiguration def _raw_capabilities(self) -> DestinationCapabilitiesContext: - return DestinationCapabilitiesContext.generic_capabilities( + caps = DestinationCapabilitiesContext.generic_capabilities( preferred_loader_file_format="jsonl", loader_file_format_adapter=loader_file_format_adapter, supported_table_formats=["delta"], @@ -37,6 +38,10 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: # loader file format) supported_merge_strategies=["upsert"], ) + caps.supported_loader_file_formats = list(caps.supported_loader_file_formats) + [ + "reference", + ] + return caps @property def client_class(self) -> t.Type["FilesystemClient"]: diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index ef4702b17d..f2466f25a2 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -9,7 +9,6 @@ import dlt from dlt.common import logger, time, json, pendulum -from dlt.common.utils import assert_min_pkg_version from dlt.common.storages.fsspec_filesystem import glob_files from dlt.common.typing import DictStrAny from dlt.common.schema import Schema, TSchemaTables, TTableSchema @@ -17,28 +16,29 @@ from dlt.common.storages import FileStorage, fsspec_from_config from dlt.common.storages.load_package import ( LoadJobInfo, - ParsedLoadJobFileName, TPipelineStateDoc, load_package as current_load_package, ) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - NewLoadJob, + FollowupJob, TLoadJobState, - LoadJob, + RunnableLoadJob, JobClientBase, - FollowupJob, + HasFollowupJobs, WithStagingDataset, WithStateSync, StorageSchemaInfo, StateInfo, - DoNothingJob, - DoNothingFollowupJob, + LoadJob, ) from dlt.common.destination.exceptions import DestinationUndefinedEntity -from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob +from dlt.destinations.job_impl import ( + ReferenceFollowupJob, + FinalizedLoadJob, + FinalizedLoadJobWithFollowupJobs, +) from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration -from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations import path_utils from dlt.destinations.fs_client import FSClientBase @@ -46,31 +46,27 @@ FILENAME_SEPARATOR = "__" -class LoadFilesystemJob(LoadJob): +class FilesystemLoadJob(RunnableLoadJob): def __init__( self, - client: "FilesystemClient", - local_path: str, - load_id: str, - table: TTableSchema, + file_path: str, ) -> None: - self.client = client - self.table = table - self.is_local_filesystem = client.config.protocol == "file" + super().__init__(file_path) + self._job_client: FilesystemClient = None + + def run(self) -> None: # pick local filesystem pathlib or posix for buckets + self.is_local_filesystem = self._job_client.config.protocol == "file" self.pathlib = os.path if self.is_local_filesystem else posixpath - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.destination_file_name = path_utils.create_path( - client.config.layout, - file_name, - client.schema.name, - load_id, - current_datetime=client.config.current_datetime, + self._job_client.config.layout, + self._file_name, + self._job_client.schema.name, + self._load_id, + current_datetime=self._job_client.config.current_datetime, load_package_timestamp=dlt.current.load_package()["state"]["created_at"], - extra_placeholders=client.config.extra_placeholders, + extra_placeholders=self._job_client.config.extra_placeholders, ) # We would like to avoid failing for local filesystem where # deeply nested directory will not exist before writing a file. @@ -79,48 +75,26 @@ def __init__( # remote_path = f"{client.config.protocol}://{posixpath.join(dataset_path, destination_file_name)}" remote_path = self.make_remote_path() if self.is_local_filesystem: - client.fs_client.makedirs(self.pathlib.dirname(remote_path), exist_ok=True) - client.fs_client.put_file(local_path, remote_path) + self._job_client.fs_client.makedirs(self.pathlib.dirname(remote_path), exist_ok=True) + self._job_client.fs_client.put_file(self._file_path, remote_path) def make_remote_path(self) -> str: """Returns path on the remote filesystem to which copy the file, without scheme. For local filesystem a native path is used""" # path.join does not normalize separators and available # normalization functions are very invasive and may string the trailing separator return self.pathlib.join( # type: ignore[no-any-return] - self.client.dataset_path, + self._job_client.dataset_path, path_utils.normalize_path_sep(self.pathlib, self.destination_file_name), ) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - - -class DeltaLoadFilesystemJob(NewReferenceJob): - def __init__( - self, - client: "FilesystemClient", - table: TTableSchema, - table_jobs: Sequence[LoadJobInfo], - ) -> None: - self.client = client - self.table = table - self.table_jobs = table_jobs - ref_file_name = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" - ).file_name() +class DeltaLoadFilesystemJob(FilesystemLoadJob): + def __init__(self, file_path: str) -> None: super().__init__( - file_name=ref_file_name, - status="running", - remote_path=self.client.make_remote_uri(self.make_remote_path()), + file_path=file_path, ) - self.write() - - def write(self) -> None: + def run(self) -> None: from dlt.common.libs.pyarrow import pyarrow as pa from dlt.common.libs.deltalake import ( DeltaTable, @@ -130,21 +104,20 @@ def write(self) -> None: try_get_deltatable, ) - assert_min_pkg_version( - pkg_name="pyarrow", - version="17.0.0", - msg="`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination.", - ) - # create Arrow dataset from Parquet files - file_paths = [job.file_path for job in self.table_jobs] + file_paths = ReferenceFollowupJob.resolve_references(self._file_path) arrow_ds = pa.dataset.dataset(file_paths) # create Delta table object - dt_path = self.client.make_remote_uri(self.make_remote_path()) - storage_options = _deltalake_storage_options(self.client.config) + dt_path = self._job_client.make_remote_uri( + self._job_client.get_table_dir(self.load_table_name) + ) + storage_options = _deltalake_storage_options(self._job_client.config) dt = try_get_deltatable(dt_path, storage_options=storage_options) + # get partition columns + part_cols = get_columns_names_with_prop(self._load_table, "partition") + # explicitly check if there is data # (https://github.com/delta-io/delta-rs/issues/2686) if arrow_ds.head(1).num_rows == 0: @@ -154,20 +127,22 @@ def write(self) -> None: table_uri=dt_path, schema=ensure_delta_compatible_arrow_schema(arrow_ds.schema), mode="overwrite", + partition_by=part_cols, + storage_options=storage_options, ) return arrow_rbr = arrow_ds.scanner().to_reader() # RecordBatchReader - if self.table["write_disposition"] == "merge" and dt is not None: - assert self.table["x-merge-strategy"] in self.client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] + if self._load_table["write_disposition"] == "merge" and dt is not None: + assert self._load_table["x-merge-strategy"] in self._job_client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] - if self.table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item] - if "parent" in self.table: - unique_column = get_first_column_name_with_prop(self.table, "unique") + if self._load_table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item] + if "parent" in self._load_table: + unique_column = get_first_column_name_with_prop(self._load_table, "unique") predicate = f"target.{unique_column} = source.{unique_column}" else: - primary_keys = get_columns_names_with_prop(self.table, "primary_key") + primary_keys = get_columns_names_with_prop(self._load_table, "primary_key") predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys]) qry = ( @@ -187,26 +162,22 @@ def write(self) -> None: write_delta_table( table_or_uri=dt_path if dt is None else dt, data=arrow_rbr, - write_disposition=self.table["write_disposition"], + write_disposition=self._load_table["write_disposition"], + partition_by=part_cols, storage_options=storage_options, ) - def make_remote_path(self) -> str: - # directory path, not file path - return self.client.get_table_dir(self.table["name"]) - - def state(self) -> TLoadJobState: - return "completed" - -class FollowupFilesystemJob(FollowupJob, LoadFilesystemJob): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: +class FilesystemLoadJobWithFollowup(HasFollowupJobs, FilesystemLoadJob): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: jobs = super().create_followup_jobs(final_state) - if final_state == "completed": - ref_job = NewReferenceJob( - file_name=self.file_name(), - status="running", - remote_path=self.client.make_remote_uri(self.make_remote_path()), + if self._load_table.get("table_format") == "delta": + # delta table jobs only require table chain followup jobs + pass + elif final_state == "completed": + ref_job = ReferenceFollowupJob( + original_file_name=self.file_name(), + remote_paths=[self._job_client.make_remote_uri(self.make_remote_path())], ) jobs.append(ref_job) return jobs @@ -287,7 +258,7 @@ def drop_tables(self, *tables: str, delete_schema: bool = True) -> None: self._delete_file(filename) def truncate_tables(self, table_names: List[str]) -> None: - """Truncate a set of tables with given `table_names`""" + """Truncate a set of regular tables with given `table_names`""" table_dirs = set(self.get_table_dirs(table_names)) table_prefixes = [self.get_table_prefix(t) for t in table_names] for table_dir in table_dirs: @@ -332,10 +303,13 @@ def update_stored_schema( return expected_update - def get_table_dir(self, table_name: str) -> str: + def get_table_dir(self, table_name: str, remote: bool = False) -> str: # dlt tables do not respect layout (for now) table_prefix = self.get_table_prefix(table_name) - return self.pathlib.dirname(table_prefix) # type: ignore[no-any-return] + table_dir: str = self.pathlib.dirname(table_prefix) + if remote: + table_dir = self.make_remote_uri(table_dir) + return table_dir def get_table_prefix(self, table_name: str) -> str: # dlt tables do not respect layout (for now) @@ -351,9 +325,9 @@ def get_table_prefix(self, table_name: str) -> str: self.dataset_path, path_utils.normalize_path_sep(self.pathlib, table_prefix) ) - def get_table_dirs(self, table_names: Iterable[str]) -> List[str]: + def get_table_dirs(self, table_names: Iterable[str], remote: bool = False) -> List[str]: """Gets directories where table data is stored.""" - return [self.get_table_dir(t) for t in table_names] + return [self.get_table_dir(t, remote=remote) for t in table_names] def list_table_files(self, table_name: str) -> List[str]: """gets list of files associated with one table""" @@ -383,22 +357,25 @@ def list_files_with_prefixes(self, table_dir: str, prefixes: List[str]) -> List[ def is_storage_initialized(self) -> bool: return self.fs_client.exists(self.pathlib.join(self.dataset_path, INIT_FILE_NAME)) # type: ignore[no-any-return] - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: # skip the state table, we create a jsonl file in the complete_load step # this does not apply to scenarios where we are using filesystem as staging # where we want to load the state the regular way if table["name"] == self.schema.state_table_name and not self.config.as_staging: - return DoNothingJob(file_path) + return FinalizedLoadJob(file_path) if table.get("table_format") == "delta": import dlt.common.libs.deltalake # assert dependencies are installed - return DoNothingFollowupJob(file_path) - - cls = FollowupFilesystemJob if self.config.as_staging else LoadFilesystemJob - return cls(self, file_path, load_id, table) + # a reference job for a delta table indicates a table chain followup job + if ReferenceFollowupJob.is_reference_job(file_path): + return DeltaLoadFilesystemJob(file_path) + # otherwise just continue + return FinalizedLoadJobWithFollowupJobs(file_path) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") + cls = FilesystemLoadJobWithFollowup if self.config.as_staging else FilesystemLoadJob + return cls(file_path) def make_remote_uri(self, remote_path: str) -> str: """Returns uri to the remote filesystem to which copy the file""" @@ -601,26 +578,18 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: - def get_table_jobs( - table_jobs: Sequence[LoadJobInfo], table_name: str - ) -> Sequence[LoadJobInfo]: - return [job for job in table_jobs if job.job_file_info.table_name == table_name] - + ) -> List[FollowupJob]: assert completed_table_chain_jobs is not None jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) - table_format = table_chain[0].get("table_format") - if table_format == "delta": - delta_jobs = [ - DeltaLoadFilesystemJob( - self, - table=self.prepare_load_table(table["name"]), - table_jobs=get_table_jobs(completed_table_chain_jobs, table["name"]), - ) - for table in table_chain - ] - jobs.extend(delta_jobs) - + if table_chain[0].get("table_format") == "delta": + for table in table_chain: + table_job_paths = [ + job.file_path + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table["name"] + ] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + jobs.append(ReferenceFollowupJob(file_name, table_job_paths)) return jobs diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8265e50fbf..78a37952b9 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -34,10 +34,11 @@ from dlt.common.destination.reference import ( JobClientBase, WithStateSync, - LoadJob, + RunnableLoadJob, StorageSchemaInfo, StateInfo, TLoadJobState, + LoadJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -69,7 +70,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -686,17 +687,12 @@ def complete_load(self, load_id: str) -> None: write_disposition=write_disposition, ) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - return LoadLanceDBJob( - self.schema, - table, - file_path, + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + return LanceDBLoadJob( + file_path=file_path, type_mapper=self.type_mapper, - db_client=self.db_client, - client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name(table["name"]), ) @@ -705,66 +701,56 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LoadLanceDBJob(LoadJob): +class LanceDBLoadJob(RunnableLoadJob): arrow_schema: TArrowSchema def __init__( self, - schema: Schema, - table_schema: TTableSchema, - local_path: str, + file_path: str, type_mapper: LanceDBTypeMapper, - db_client: DBConnection, - client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.schema: Schema = schema - self.table_schema: TTableSchema = table_schema - self.db_client: DBConnection = db_client - self.type_mapper: TypeMapper = type_mapper - self.table_name: str = table_schema["name"] - self.fq_table_name: str = fq_table_name - self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) - self.embedding_model_func: TextEmbeddingFunction = model_func - self.embedding_model_dimensions: int = client_config.embedding_model_dimensions - self.id_field_name: str = client_config.id_field_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition", "append") + super().__init__(file_path) + self._type_mapper: TypeMapper = type_mapper + self._fq_table_name: str = fq_table_name + self._model_func = model_func + self._job_client: "LanceDBClient" = None + + def run(self) -> None: + self._db_client: DBConnection = self._job_client.db_client + self._embedding_model_func: TextEmbeddingFunction = self._model_func + self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions + self._id_field_name: str = self._job_client.config.id_field_name + + unique_identifiers: Sequence[str] = list_merge_identifiers(self._load_table) + write_disposition: TWriteDisposition = cast( + TWriteDisposition, self._load_table.get("write_disposition", "append") ) - with FileStorage.open_zipsafe_ro(local_path) as f: + with FileStorage.open_zipsafe_ro(self._file_path) as f: records: List[DictStrAny] = [json.loads(line) for line in f] - if self.table_schema not in self.schema.dlt_tables(): + if self._load_table not in self._schema.dlt_tables(): for record in records: # Add reserved ID fields. uuid_id = ( - generate_uuid(record, self.unique_identifiers, self.fq_table_name) - if self.unique_identifiers + generate_uuid(record, unique_identifiers, self._fq_table_name) + if unique_identifiers else str(uuid.uuid4()) ) - record.update({self.id_field_name: uuid_id}) + record.update({self._id_field_name: uuid_id}) # LanceDB expects all fields in the target arrow table to be present in the data payload. # We add and set these missing fields, that are fields not present in the target schema, to NULL. - missing_fields = set(self.table_schema["columns"]) - set(record) + missing_fields = set(self._load_table["columns"]) - set(record) for field in missing_fields: record[field] = None upload_batch( records, - db_client=db_client, - table_name=self.fq_table_name, - write_disposition=self.write_disposition, - id_field_name=self.id_field_name, + db_client=self._db_client, + table_name=self._fq_table_name, + write_disposition=write_disposition, + id_field_name=self._id_field_name, ) - - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() diff --git a/dlt/destinations/impl/motherduck/factory.py b/dlt/destinations/impl/motherduck/factory.py index a9bab96d08..0f4218f7cb 100644 --- a/dlt/destinations/impl/motherduck/factory.py +++ b/dlt/destinations/impl/motherduck/factory.py @@ -33,10 +33,11 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.is_max_query_length_in_bytes = True caps.max_text_data_type_length = 1024 * 1024 * 1024 caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False + caps.supports_ddl_transactions = True caps.alter_add_multi_column = False caps.supports_truncate_command = False caps.supported_merge_strategies = ["delete-insert", "scd2"] + caps.max_parallel_load_jobs = 8 return caps diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index ec4a54d6f7..a67423a873 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,12 +1,12 @@ from typing import Dict, Optional, Sequence, List, Any from dlt.common.exceptions import TerminalValueError -from dlt.common.destination.reference import NewLoadJob +from dlt.common.destination.reference import FollowupJob from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient @@ -85,7 +85,7 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class MsSqlStagingCopyJob(SqlStagingCopyJob): +class MsSqlStagingCopyJob(SqlStagingCopyFollowupJob): @classmethod def generate_sql( cls, @@ -110,7 +110,7 @@ def generate_sql( return sql -class MsSqlMergeJob(SqlMergeJob): +class MsSqlMergeJob(SqlMergeFollowupJob): @classmethod def gen_key_table_clauses( cls, @@ -127,7 +127,7 @@ def gen_key_table_clauses( f" {staging_root_table_name} WHERE" f" {' OR '.join([c.format(d=root_table_name,s=staging_root_table_name) for c in key_clauses])})" ] - return SqlMergeJob.gen_key_table_clauses( + return SqlMergeFollowupJob.gen_key_table_clauses( root_table_name, staging_root_table_name, key_clauses, for_delete ) @@ -137,7 +137,7 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: - name = SqlMergeJob._new_temp_table_name(name_prefix, sql_client) + name = SqlMergeFollowupJob._new_temp_table_name(name_prefix, sql_client) return "#" + name @@ -160,7 +160,7 @@ def __init__( self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {} self.type_mapper = MsSqlTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( @@ -189,7 +189,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: if self.config.replace_strategy == "staging-optimized": return [MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index f47549fc4f..5ae5f27a6e 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -6,14 +6,20 @@ DestinationInvalidFileFormat, DestinationTerminalException, ) -from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState +from dlt.common.destination.reference import ( + HasFollowupJobs, + RunnableLoadJob, + FollowupJob, + LoadJob, + TLoadJobState, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration @@ -85,7 +91,7 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class PostgresStagingCopyJob(SqlStagingCopyJob): +class PostgresStagingCopyJob(SqlStagingCopyFollowupJob): @classmethod def generate_sql( cls, @@ -110,21 +116,24 @@ def generate_sql( return sql -class PostgresCsvCopyJob(LoadJob, FollowupJob): - def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient") -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - config = client.config - sql_client = client.sql_client - csv_format = config.csv_format or CsvFormatConfiguration() - table_name = table["name"] +class PostgresCsvCopyJob(RunnableLoadJob, HasFollowupJobs): + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: PostgresClient = None + + def run(self) -> None: + self._config = self._job_client.config + sql_client = self._job_client.sql_client + csv_format = self._config.csv_format or CsvFormatConfiguration() + table_name = self.load_table_name sep = csv_format.delimiter if csv_format.on_error_continue: logger.warning( - f"When processing {file_path} on table {table_name} Postgres csv reader does not" - " support on_error_continue" + f"When processing {self._file_path} on table {table_name} Postgres csv reader does" + " not support on_error_continue" ) - with FileStorage.open_zipsafe_ro(file_path, "rb") as f: + with FileStorage.open_zipsafe_ro(self._file_path, "rb") as f: if csv_format.include_header: # all headers in first line headers_row: str = f.readline().decode(csv_format.encoding).strip() @@ -132,12 +141,12 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" else: # read first row to figure out the headers split_first_row: str = f.readline().decode(csv_format.encoding).strip().split(sep) - split_headers = list(client.schema.get_table_columns(table_name).keys()) + split_headers = list(self._job_client.schema.get_table_columns(table_name).keys()) if len(split_first_row) > len(split_headers): raise DestinationInvalidFileFormat( "postgres", "csv", - file_path, + self._file_path, f"First row {split_first_row} has more rows than columns {split_headers} in" f" table {table_name}", ) @@ -158,7 +167,7 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" split_columns = [] # detect columns with NULL to use in FORCE NULL # detect headers that are not in columns - for col in client.schema.get_table_columns(table_name).values(): + for col in self._job_client.schema.get_table_columns(table_name).values(): norm_col = sql_client.escape_column_name(col["name"], escape=True) split_columns.append(norm_col) if norm_col in split_headers and col.get("nullable", True): @@ -168,7 +177,7 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" raise DestinationInvalidFileFormat( "postgres", "csv", - file_path, + self._file_path, f"Following headers {split_unknown_headers} cannot be matched to columns" f" {split_columns} of table {table_name}.", ) @@ -196,12 +205,6 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" with sql_client.native_connection.cursor() as cursor: cursor.copy_expert(copy_sql, f, size=8192) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class PostgresClient(InsertValuesJobClient): def __init__( @@ -222,10 +225,12 @@ def __init__( self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = PostgresTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job and file_path.endswith("csv"): - job = PostgresCsvCopyJob(table, file_path, self) + job = PostgresCsvCopyJob(file_path) return job def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: @@ -241,7 +246,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: if self.config.replace_strategy == "staging-optimized": return [PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 28d7388701..65019c6626 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -13,12 +13,19 @@ version_table, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync +from dlt.common.destination.reference import ( + TLoadJobState, + RunnableLoadJob, + JobClientBase, + WithStateSync, + LoadJob, +) from dlt.common.destination.exceptions import DestinationUndefinedEntity + from dlt.common.storages import FileStorage from dlt.common.time import precise_time -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo from dlt.destinations.utils import get_pipeline_state_query_columns @@ -30,49 +37,49 @@ from qdrant_client.http.exceptions import UnexpectedResponse -class LoadQdrantJob(LoadJob): +class QDrantLoadJob(RunnableLoadJob): def __init__( self, - table_schema: TTableSchema, - local_path: str, - db_client: QC, + file_path: str, client_config: QdrantClientConfiguration, collection_name: str, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.db_client = db_client - self.collection_name = collection_name - self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) - self.unique_identifiers = self._list_unique_identifiers(table_schema) - self.config = client_config - - with FileStorage.open_zipsafe_ro(local_path) as f: + super().__init__(file_path) + self._collection_name = collection_name + self._config = client_config + self._job_client: "QdrantClient" = None + + def run(self) -> None: + embedding_fields = get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) + unique_identifiers = self._list_unique_identifiers(self._load_table) + with FileStorage.open_zipsafe_ro(self._file_path) as f: ids: List[str] docs, payloads, ids = [], [], [] for line in f: data = json.loads(line) point_id = ( - self._generate_uuid(data, self.unique_identifiers, self.collection_name) - if self.unique_identifiers + self._generate_uuid(data, unique_identifiers, self._collection_name) + if unique_identifiers else str(uuid.uuid4()) ) payloads.append(data) ids.append(point_id) - if len(self.embedding_fields) > 0: - docs.append(self._get_embedding_doc(data)) + if len(embedding_fields) > 0: + docs.append(self._get_embedding_doc(data, embedding_fields)) - if len(self.embedding_fields) > 0: - embedding_model = db_client._get_or_init_model(db_client.embedding_model_name) + if len(embedding_fields) > 0: + embedding_model = self._job_client.db_client._get_or_init_model( + self._job_client.db_client.embedding_model_name + ) embeddings = list( embedding_model.embed( docs, - batch_size=self.config.embedding_batch_size, - parallel=self.config.embedding_parallelism, + batch_size=self._config.embedding_batch_size, + parallel=self._config.embedding_parallelism, ) ) - vector_name = db_client.get_vector_field_name() + vector_name = self._job_client.db_client.get_vector_field_name() embeddings = [{vector_name: embedding.tolist()} for embedding in embeddings] else: embeddings = [{}] * len(ids) @@ -80,7 +87,7 @@ def __init__( self._upload_data(vectors=embeddings, ids=ids, payloads=payloads) - def _get_embedding_doc(self, data: Dict[str, Any]) -> str: + def _get_embedding_doc(self, data: Dict[str, Any], embedding_fields: List[str]) -> str: """Returns a document to generate embeddings for. Args: @@ -89,7 +96,7 @@ def _get_embedding_doc(self, data: Dict[str, Any]) -> str: Returns: str: A concatenated string of all the fields intended for embedding. """ - doc = "\n".join(str(data[key]) for key in self.embedding_fields) + doc = "\n".join(str(data[key]) for key in embedding_fields) return doc def _list_unique_identifiers(self, table_schema: TTableSchema) -> Sequence[str]: @@ -117,14 +124,14 @@ def _upload_data( vectors (Iterable[Any]): Embeddings to be uploaded to the collection payloads (Iterable[Any]): Payloads to be uploaded to the collection """ - self.db_client.upload_collection( - self.collection_name, + self._job_client.db_client.upload_collection( + self._collection_name, ids=ids, payload=payloads, vectors=vectors, - parallel=self.config.upload_parallelism, - batch_size=self.config.upload_batch_size, - max_retries=self.config.upload_max_retries, + parallel=self._config.upload_parallelism, + batch_size=self._config.upload_batch_size, + max_retries=self._config.upload_max_retries, ) def _generate_uuid( @@ -143,12 +150,6 @@ def _generate_uuid( data_id = "_".join(str(data[key]) for key in unique_identifiers) return str(uuid.uuid5(uuid.NAMESPACE_DNS, collection_name + data_id)) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class QdrantClient(JobClientBase, WithStateSync): """Qdrant Destination Handler""" @@ -438,18 +439,15 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI return None raise - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - return LoadQdrantJob( - table, + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + return QDrantLoadJob( file_path, - db_client=self.db_client, client_config=self.config, collection_name=self._make_qualified_collection_name(table["name"]), ) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - def complete_load(self, load_id: str) -> None: values = [load_id, self.schema.name, 0, str(pendulum.now()), self.schema.version_hash] assert len(values) == len(self.loads_collection_properties) diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 8eacc76d11..81abd57803 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -14,9 +14,10 @@ from dlt.common.destination.reference import ( - NewLoadJob, + FollowupJob, CredentialsConfiguration, SupportsStagingDestination, + LoadJob, ) from dlt.common.data_types import TDataType from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -27,12 +28,12 @@ from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException -from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob +from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.redshift.configuration import RedshiftClientConfiguration -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -123,16 +124,16 @@ def _maybe_make_terminal_exception_from_data_error( class RedshiftCopyFileLoadJob(CopyRemoteFileLoadJob): def __init__( self, - table: TTableSchema, file_path: str, - sql_client: SqlClientBase[Any], staging_credentials: Optional[CredentialsConfiguration] = None, staging_iam_role: str = None, ) -> None: + super().__init__(file_path, staging_credentials) self._staging_iam_role = staging_iam_role - super().__init__(table, file_path, sql_client, staging_credentials) + self._job_client: "RedshiftClient" = None - def execute(self, table: TTableSchema, bucket_path: str) -> None: + def run(self) -> None: + self._sql_client = self._job_client.sql_client # we assume s3 credentials where provided for the staging credentials = "" if self._staging_iam_role: @@ -148,11 +149,11 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: ) # get format - ext = os.path.splitext(bucket_path)[1][1:] + ext = os.path.splitext(self._bucket_path)[1][1:] file_type = "" dateformat = "" compression = "" - if table_schema_has_type(table, "time"): + if table_schema_has_type(self._load_table, "time"): raise LoadJobTerminalException( self.file_name(), f"Redshift cannot load TIME columns from {ext} files. Switch to direct INSERT file" @@ -160,7 +161,7 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: " `datetime.datetime`", ) if ext == "jsonl": - if table_schema_has_type(table, "binary"): + if table_schema_has_type(self._load_table, "binary"): raise LoadJobTerminalException( self.file_name(), "Redshift cannot load VARBYTE columns from json files. Switch to parquet to" @@ -170,7 +171,7 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: dateformat = "dateformat 'auto' timeformat 'auto'" compression = "GZIP" elif ext == "parquet": - if table_schema_has_type_with_precision(table, "binary"): + if table_schema_has_type_with_precision(self._load_table, "binary"): raise LoadJobTerminalException( self.file_name(), f"Redshift cannot load fixed width VARBYTE columns from {ext} files. Switch to" @@ -179,7 +180,7 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: file_type = "PARQUET" # if table contains complex types then SUPER field will be used. # https://docs.aws.amazon.com/redshift/latest/dg/ingest-super.html - if table_schema_has_type(table, "complex"): + if table_schema_has_type(self._load_table, "complex"): file_type += " SERIALIZETOJSON" else: raise ValueError(f"Unsupported file type {ext} for Redshift.") @@ -187,19 +188,15 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: with self._sql_client.begin_transaction(): # TODO: if we ever support csv here remember to add column names to COPY self._sql_client.execute_sql(f""" - COPY {self._sql_client.make_qualified_table_name(table['name'])} - FROM '{bucket_path}' + COPY {self._sql_client.make_qualified_table_name(self.load_table_name)} + FROM '{self._bucket_path}' {file_type} {dateformat} {compression} {credentials} MAXERROR 0;""") - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - -class RedshiftMergeJob(SqlMergeJob): +class RedshiftMergeJob(SqlMergeFollowupJob): @classmethod def gen_key_table_clauses( cls, @@ -218,7 +215,7 @@ def gen_key_table_clauses( f" {staging_root_table_name} WHERE" f" {' OR '.join([c.format(d=root_table_name,s=staging_root_table_name) for c in key_clauses])})" ] - return SqlMergeJob.gen_key_table_clauses( + return SqlMergeFollowupJob.gen_key_table_clauses( root_table_name, staging_root_table_name, key_clauses, for_delete ) @@ -241,7 +238,7 @@ def __init__( self.config: RedshiftClientConfiguration = config self.type_mapper = RedshiftTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: @@ -255,17 +252,17 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" - job = super().start_file_load(table, file_path, load_id) + job = super().create_load_job(table, file_path, load_id, restore) if not job: - assert NewReferenceJob.is_reference_job( + assert ReferenceFollowupJob.is_reference_job( file_path ), "Redshift must use staging to load files" job = RedshiftCopyFileLoadJob( - table, file_path, - self.sql_client, staging_credentials=self.config.staging_config.credentials, staging_iam_role=self.config.staging_iam_role, ) diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index bf175ba911..904b524791 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -4,10 +4,9 @@ from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - FollowupJob, - NewLoadJob, - TLoadJobState, + HasFollowupJobs, LoadJob, + RunnableLoadJob, CredentialsConfiguration, SupportsStagingDestination, ) @@ -24,13 +23,13 @@ from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.typing import TLoaderFileFormat from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper @@ -79,63 +78,68 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class SnowflakeLoadJob(LoadJob, FollowupJob): +class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, file_path: str, - table_name: str, - load_id: str, - client: SnowflakeSqlClient, config: SnowflakeClientConfiguration, stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) + super().__init__(file_path) + self._keep_staged_files = keep_staged_files + self._staging_credentials = staging_credentials + self._config = config + self._stage_name = stage_name + self._job_client: "SnowflakeClient" = None + + def run(self) -> None: + self._sql_client = self._job_client.sql_client + # resolve reference - is_local_file = not NewReferenceJob.is_reference_job(file_path) - file_url = file_path if is_local_file else NewReferenceJob.resolve_reference(file_path) + is_local_file = not ReferenceFollowupJob.is_reference_job(self._file_path) + file_url = ( + self._file_path + if is_local_file + else ReferenceFollowupJob.resolve_reference(self._file_path) + ) # take file name file_name = FileStorage.get_file_name_from_file_path(file_url) file_format = file_name.rsplit(".", 1)[-1] - qualified_table_name = client.make_qualified_table_name(table_name) + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) # this means we have a local file stage_file_path: str = "" if is_local_file: - if not stage_name: + if not self._stage_name: # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - stage_name = client.make_qualified_table_name("%" + table_name) - stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' + self._stage_name = self._sql_client.make_qualified_table_name( + "%" + self.load_table_name + ) + stage_file_path = f'@{self._stage_name}/"{self._load_id}"/{file_name}' copy_sql = self.gen_copy_sql( file_url, qualified_table_name, file_format, # type: ignore[arg-type] - client.capabilities.generates_case_sensitive_identifiers(), - stage_name, + self._sql_client.capabilities.generates_case_sensitive_identifiers(), + self._stage_name, stage_file_path, - staging_credentials, - config.csv_format, + self._staging_credentials, + self._config.csv_format, ) - with client.begin_transaction(): + with self._sql_client.begin_transaction(): # PUT and COPY in one tx if local file, otherwise only copy if is_local_file: - client.execute_sql( - f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' - " AUTO_COMPRESS = FALSE" + self._sql_client.execute_sql( + f'PUT file://{self._file_path} @{self._stage_name}/"{self._load_id}" OVERWRITE' + " = TRUE, AUTO_COMPRESS = FALSE" ) - client.execute_sql(copy_sql) - if stage_file_path and not keep_staged_files: - client.execute_sql(f"REMOVE {stage_file_path}") - - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() + self._sql_client.execute_sql(copy_sql) + if stage_file_path and not self._keep_staged_files: + self._sql_client.execute_sql(f"REMOVE {stage_file_path}") @classmethod def gen_copy_sql( @@ -267,15 +271,14 @@ def __init__( self.sql_client: SnowflakeSqlClient = sql_client # type: ignore self.type_mapper = SnowflakeTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = SnowflakeLoadJob( file_path, - table["name"], - load_id, - self.sql_client, self.config, stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, @@ -285,9 +288,6 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> ) return job - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - def _make_add_column_sql( self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None ) -> List[str]: diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 408bfc2b53..d1b38f73bd 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -5,10 +5,7 @@ from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( - SupportsStagingDestination, - NewLoadJob, -) +from dlt.common.destination.reference import SupportsStagingDestination, FollowupJob, LoadJob from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint from dlt.common.schema.utils import ( @@ -22,9 +19,12 @@ AzureServicePrincipalCredentialsWithoutDefaults, ) -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob +from dlt.destinations.job_client_impl import ( + SqlJobClientBase, + CopyRemoteFileLoadJob, +) from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.mssql.mssql import ( @@ -131,7 +131,7 @@ def _get_columstore_valid_column(self, c: TColumnSchema) -> TColumnSchema: def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: return SqlJobClientBase._create_replace_followup_jobs(self, table_chain) def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: @@ -158,16 +158,16 @@ def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSc table[TABLE_INDEX_TYPE_HINT] = self.config.default_table_index_type # type: ignore[typeddict-unknown-key] return table - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: - assert NewReferenceJob.is_reference_job( + assert ReferenceFollowupJob.is_reference_job( file_path ), "Synapse must use staging to load files" job = SynapseCopyFileLoadJob( - table, file_path, - self.sql_client, self.config.staging_config.credentials, # type: ignore[arg-type] self.config.staging_use_msi, ) @@ -177,22 +177,21 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> class SynapseCopyFileLoadJob(CopyRemoteFileLoadJob): def __init__( self, - table: TTableSchema, file_path: str, - sql_client: SqlClientBase[Any], staging_credentials: Optional[ Union[AzureCredentialsWithoutDefaults, AzureServicePrincipalCredentialsWithoutDefaults] ] = None, staging_use_msi: bool = False, ) -> None: self.staging_use_msi = staging_use_msi - super().__init__(table, file_path, sql_client, staging_credentials) + super().__init__(file_path, staging_credentials) - def execute(self, table: TTableSchema, bucket_path: str) -> None: + def run(self) -> None: + self._sql_client = self._job_client.sql_client # get format - ext = os.path.splitext(bucket_path)[1][1:] + ext = os.path.splitext(self._bucket_path)[1][1:] if ext == "parquet": - if table_schema_has_type(table, "time"): + if table_schema_has_type(self._load_table, "time"): # Synapse interprets Parquet TIME columns as bigint, resulting in # an incompatibility error. raise LoadJobTerminalException( @@ -216,8 +215,8 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: (AzureCredentialsWithoutDefaults, AzureServicePrincipalCredentialsWithoutDefaults), ) azure_storage_account_name = staging_credentials.azure_storage_account_name - https_path = self._get_https_path(bucket_path, azure_storage_account_name) - table_name = table["name"] + https_path = self._get_https_path(self._bucket_path, azure_storage_account_name) + table_name = self._load_table["name"] if self.staging_use_msi: credential = "IDENTITY = 'Managed Identity'" @@ -252,10 +251,6 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: """) self._sql_client.execute_sql(sql) - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - def _get_https_path(self, bucket_path: str, storage_account_name: str) -> str: """ Converts a path in the form of az:/// to diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index dfbf83d7e5..b8bf3d62c6 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -38,11 +38,17 @@ version_table, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync +from dlt.common.destination.reference import ( + TLoadJobState, + RunnableLoadJob, + JobClientBase, + WithStateSync, + LoadJob, +) from dlt.common.storages import FileStorage from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo from dlt.destinations.impl.weaviate.configuration import WeaviateClientConfiguration from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateGrpcError @@ -143,34 +149,31 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: return _wrap # type: ignore -class LoadWeaviateJob(LoadJob): +class LoadWeaviateJob(RunnableLoadJob): def __init__( self, - schema: Schema, - table_schema: TTableSchema, - local_path: str, - db_client: weaviate.Client, - client_config: WeaviateClientConfiguration, + file_path: str, class_name: str, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.client_config = client_config - self.db_client = db_client - self.table_name = table_schema["name"] - self.class_name = class_name - self.unique_identifiers = self.list_unique_identifiers(table_schema) + super().__init__(file_path) + self._job_client: WeaviateClient = None + self._class_name = class_name + + def run(self) -> None: + self._db_client = self._job_client.db_client + self._client_config = self._job_client.config + self.unique_identifiers = self.list_unique_identifiers(self._load_table) self.complex_indices = [ i - for i, field in schema.get_table_columns(self.table_name).items() + for i, field in self._schema.get_table_columns(self.load_table_name).items() if field["data_type"] == "complex" ] self.date_indices = [ i - for i, field in schema.get_table_columns(self.table_name).items() + for i, field in self._schema.get_table_columns(self.load_table_name).items() if field["data_type"] == "date" ] - with FileStorage.open_zipsafe_ro(local_path) as f: + with FileStorage.open_zipsafe_ro(self._file_path) as f: self.load_batch(f) @wrap_weaviate_error @@ -188,15 +191,15 @@ def check_batch_result(results: List[StrAny]) -> None: if "error" in result["result"]["errors"]: raise WeaviateGrpcError(result["result"]["errors"]) - with self.db_client.batch( - batch_size=self.client_config.batch_size, - timeout_retries=self.client_config.batch_retries, - connection_error_retries=self.client_config.batch_retries, + with self._db_client.batch( + batch_size=self._client_config.batch_size, + timeout_retries=self._client_config.batch_retries, + connection_error_retries=self._client_config.batch_retries, weaviate_error_retries=weaviate.WeaviateErrorRetryConf( - self.client_config.batch_retries + self._client_config.batch_retries ), - consistency_level=weaviate.ConsistencyLevel[self.client_config.batch_consistency], - num_workers=self.client_config.batch_workers, + consistency_level=weaviate.ConsistencyLevel[self._client_config.batch_consistency], + num_workers=self._client_config.batch_workers, callback=check_batch_result, ) as batch: for line in f: @@ -209,11 +212,11 @@ def check_batch_result(results: List[StrAny]) -> None: if key in data: data[key] = ensure_pendulum_datetime(data[key]).isoformat() if self.unique_identifiers: - uuid = self.generate_uuid(data, self.unique_identifiers, self.class_name) + uuid = self.generate_uuid(data, self.unique_identifiers, self._class_name) else: uuid = None - batch.add_data_object(data, self.class_name, uuid=uuid) + batch.add_data_object(data, self._class_name, uuid=uuid) def list_unique_identifiers(self, table_schema: TTableSchema) -> Sequence[str]: if table_schema.get("write_disposition") == "merge": @@ -228,12 +231,6 @@ def generate_uuid( data_id = "_".join([str(data[key]) for key in unique_identifiers]) return generate_uuid5(data_id, class_name) # type: ignore - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class WeaviateClient(JobClientBase, WithStateSync): """Weaviate client implementation.""" @@ -677,19 +674,14 @@ def _make_property_schema( **extra_kv, } - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: return LoadWeaviateJob( - self.schema, - table, file_path, - db_client=self.db_client, - client_config=self.config, class_name=self.make_qualified_class_name(table["name"]), ) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - @wrap_weaviate_error def complete_load(self, load_id: str) -> None: # corresponds to order of the columns in loads_table() diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 652d13f556..6ccc65705b 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -2,35 +2,31 @@ import abc from typing import Any, Iterator, List -from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState +from dlt.common.destination.reference import RunnableLoadJob, HasFollowupJobs, LoadJob from dlt.common.schema.typing import TTableSchema from dlt.common.storages import FileStorage from dlt.common.utils import chunks from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.job_impl import EmptyLoadJob -from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs +from dlt.destinations.job_client_impl import SqlJobClientWithStaging, SqlJobClientBase -class InsertValuesLoadJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: SqlClientBase[Any]) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - self._sql_client = sql_client +class InsertValuesLoadJob(RunnableLoadJob, HasFollowupJobs): + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: "SqlJobClientBase" = None + + def run(self) -> None: # insert file content immediately + self._sql_client = self._job_client.sql_client + with self._sql_client.begin_transaction(): for fragments in self._insert( - sql_client.make_qualified_table_name(table_name), file_path + self._sql_client.make_qualified_table_name(self.load_table_name), self._file_path ): self._sql_client.execute_fragments(fragments) - def state(self) -> TLoadJobState: - # this job is always done - return "completed" - - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]: # WARNING: maximum redshift statement is 16MB https://docs.aws.amazon.com/redshift/latest/dg/c_redshift-sql.html # the procedure below will split the inserts into max_query_length // 2 packs @@ -101,27 +97,12 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st class InsertValuesJobClient(SqlJobClientWithStaging): - def restore_file_load(self, file_path: str) -> LoadJob: - """Returns a completed SqlLoadJob or InsertValuesJob - - Returns completed jobs as SqlLoadJob and InsertValuesJob executed atomically in start_file_load so any jobs that should be recreated are already completed. - Obviously the case of asking for jobs that were never created will not be handled. With correctly implemented loader that cannot happen. - - Args: - file_path (str): a path to a job file - - Returns: - LoadJob: Always a restored job completed - """ - job = super().restore_file_load(file_path) - if not job: - job = EmptyLoadJob.from_file_path(file_path, "completed") - return job - - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: # this is using sql_client internally and will raise a right exception if file_path.endswith("insert_values"): - job = InsertValuesLoadJob(table["name"], file_path, self.sql_client) + job = InsertValuesLoadJob(file_path) return job diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index dd0e783414..7fdd979c5d 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -42,18 +42,20 @@ WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, - NewLoadJob, + FollowupJob, WithStagingDataset, - TLoadJobState, + RunnableLoadJob, LoadJob, JobClientBase, - FollowupJob, + HasFollowupJobs, CredentialsConfiguration, ) from dlt.destinations.exceptions import DatabaseUndefinedRelation -from dlt.destinations.job_impl import EmptyLoadJobWithoutFollowup, NewReferenceJob -from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob +from dlt.destinations.job_impl import ( + ReferenceFollowupJob, +) +from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlStagingCopyFollowupJob from dlt.destinations.typing import TNativeConn from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.utils import ( @@ -66,36 +68,32 @@ DDL_COMMANDS = ["ALTER", "CREATE", "DROP"] -class SqlLoadJob(LoadJob): +class SqlLoadJob(RunnableLoadJob): """A job executing sql statement, without followup trait""" - def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: "SqlJobClientBase" = None + + def run(self) -> None: + self._sql_client = self._job_client.sql_client # execute immediately if client present - with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro(self._file_path, "r", encoding="utf-8") as f: sql = f.read() # Some clients (e.g. databricks) do not support multiple statements in one execute call - if not sql_client.capabilities.supports_multiple_statements: - sql_client.execute_many(self._split_fragments(sql)) + if not self._sql_client.capabilities.supports_multiple_statements: + self._sql_client.execute_many(self._split_fragments(sql)) # if we detect ddl transactions, only execute transaction if supported by client elif ( not self._string_contains_ddl_queries(sql) - or sql_client.capabilities.supports_ddl_transactions + or self._sql_client.capabilities.supports_ddl_transactions ): # with sql_client.begin_transaction(): - sql_client.execute_sql(sql) + self._sql_client.execute_sql(sql) else: # sql_client.execute_sql(sql) - sql_client.execute_many(self._split_fragments(sql)) - - def state(self) -> TLoadJobState: - # this job is always done - return "completed" - - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() + self._sql_client.execute_many(self._split_fragments(sql)) def _string_contains_ddl_queries(self, sql: str) -> bool: for cmd in DDL_COMMANDS: @@ -111,27 +109,16 @@ def is_sql_job(file_path: str) -> bool: return os.path.splitext(file_path)[1][1:] == "sql" -class CopyRemoteFileLoadJob(LoadJob, FollowupJob): +class CopyRemoteFileLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, - table: TTableSchema, file_path: str, - sql_client: SqlClientBase[Any], staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - self._sql_client = sql_client + super().__init__(file_path) + self._job_client: "SqlJobClientBase" = None self._staging_credentials = staging_credentials - - self.execute(table, NewReferenceJob.resolve_reference(file_path)) - - def execute(self, table: TTableSchema, bucket_path: str) -> None: - # implement in child implementations - raise NotImplementedError() - - def state(self) -> TLoadJobState: - # this job is always done - return "completed" + self._bucket_path = ReferenceFollowupJob.resolve_reference(file_path) class SqlJobClientBase(JobClientBase, WithStateSync): @@ -227,19 +214,23 @@ def should_truncate_table_before_load(self, table: TTableSchema) -> bool: and self.config.replace_strategy == "truncate-and-insert" ) - def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_append_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJob]: return [] - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: - return [SqlMergeJob.from_table_chain(table_chain, self.sql_client)] + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + return [SqlMergeFollowupJob.from_table_chain(table_chain, self.sql_client)] def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: - jobs: List[NewLoadJob] = [] + ) -> List[FollowupJob]: + jobs: List[FollowupJob] = [] if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]: jobs.append( - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) + SqlStagingCopyFollowupJob.from_table_chain( + table_chain, self.sql_client, {"replace": True} + ) ) return jobs @@ -247,7 +238,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: """Creates a list of followup jobs for merge write disposition and staging replace strategies""" jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs @@ -261,28 +252,13 @@ def create_table_chain_completed_followup_jobs( jobs.extend(self._create_replace_followup_jobs(table_chain)) return jobs - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" - self._set_query_tags_for_job(load_id, table) if SqlLoadJob.is_sql_job(file_path): - # execute sql load job - return SqlLoadJob(file_path, self.sql_client) - return None - - def restore_file_load(self, file_path: str) -> LoadJob: - """Returns a completed SqlLoadJob or None to let derived classes to handle their specific jobs - - Returns completed jobs as SqlLoadJob is executed atomically in start_file_load so any jobs that should be recreated are already completed. - Obviously the case of asking for jobs that were never created will not be handled. With correctly implemented loader that cannot happen. - - Args: - file_path (str): a path to a job file - - Returns: - LoadJob: A restored job or none - """ - if SqlLoadJob.is_sql_job(file_path): - return EmptyLoadJobWithoutFollowup.from_file_path(file_path, "completed") + # create sql load job + return SqlLoadJob(file_path) return None def complete_load(self, load_id: str) -> None: @@ -678,6 +654,9 @@ def _verify_schema(self) -> None: logger.error(str(exception)) raise exceptions[0] + def prepare_load_job_execution(self, job: RunnableLoadJob) -> None: + self._set_query_tags_for_job(load_id=job._load_id, table=job._load_table) + def _set_query_tags_for_job(self, load_id: str, table: TTableSchema) -> None: """Sets query tags in sql_client for a job in package `load_id`, starting for a particular `table`""" from dlt.common.pipeline import current_pipeline diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 9a8f7277b7..41c939f482 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -1,14 +1,22 @@ from abc import ABC, abstractmethod import os import tempfile # noqa: 251 -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List from dlt.common.json import json -from dlt.common.destination.reference import NewLoadJob, FollowupJob, TLoadJobState, LoadJob +from dlt.common.destination.reference import ( + HasFollowupJobs, + TLoadJobState, + RunnableLoadJob, + JobClientBase, + FollowupJob, + LoadJob, +) from dlt.common.storages.load_package import commit_load_package_state from dlt.common.schema import Schema, TTableSchema from dlt.common.storages import FileStorage from dlt.common.typing import TDataItems +from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.destinations.impl.destination.configuration import ( CustomDestinationClientConfiguration, @@ -16,17 +24,26 @@ ) -class EmptyLoadJobWithoutFollowup(LoadJob): - def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) -> None: +class FinalizedLoadJob(LoadJob): + """ + Special Load Job that should never get started and just indicates a job being in a final state. + May also be used to indicate that nothing needs to be done. + """ + + def __init__( + self, file_path: str, status: TLoadJobState = "completed", exception: str = None + ) -> None: self._status = status self._exception = exception - super().__init__(file_name) + self._file_path = file_path + assert self._status in ("completed", "failed", "retry") + super().__init__(file_path) @classmethod def from_file_path( - cls, file_path: str, status: TLoadJobState, message: str = None - ) -> "EmptyLoadJobWithoutFollowup": - return cls(FileStorage.get_file_name_from_file_path(file_path), status, exception=message) + cls, file_path: str, status: TLoadJobState = "completed", message: str = None + ) -> "FinalizedLoadJob": + return cls(file_path, status, exception=message) def state(self) -> TLoadJobState: return self._status @@ -35,101 +52,107 @@ def exception(self) -> str: return self._exception -class EmptyLoadJob(EmptyLoadJobWithoutFollowup, FollowupJob): +class FinalizedLoadJobWithFollowupJobs(FinalizedLoadJob, HasFollowupJobs): pass -class NewLoadJobImpl(EmptyLoadJobWithoutFollowup, NewLoadJob): +class FollowupJobImpl(FollowupJob): + """ + Class to create a new loadjob, not stateful and not runnable + """ + + def __init__(self, file_name: str) -> None: + self._file_path = os.path.join(tempfile.gettempdir(), file_name) + self._parsed_file_name = ParsedLoadJobFileName.parse(file_name) + # we only accept jobs that we can scheduleas new or mark as failed.. + def _save_text_file(self, data: str) -> None: - temp_file = os.path.join(tempfile.gettempdir(), self._file_name) - with open(temp_file, "w", encoding="utf-8") as f: + with open(self._file_path, "w", encoding="utf-8") as f: f.write(data) - self._new_file_path = temp_file def new_file_path(self) -> str: """Path to a newly created temporary job file""" - return self._new_file_path + return self._file_path + def job_id(self) -> str: + """The job id that is derived from the file name and does not changes during job lifecycle""" + return self._parsed_file_name.job_id() -class NewReferenceJob(NewLoadJobImpl): - def __init__( - self, file_name: str, status: TLoadJobState, exception: str = None, remote_path: str = None - ) -> None: - file_name = os.path.splitext(file_name)[0] + ".reference" - super().__init__(file_name, status, exception) - self._remote_path = remote_path - self._save_text_file(remote_path) + +class ReferenceFollowupJob(FollowupJobImpl): + def __init__(self, original_file_name: str, remote_paths: List[str]) -> None: + file_name = os.path.splitext(original_file_name)[0] + "." + "reference" + self._remote_paths = remote_paths + super().__init__(file_name) + self._save_text_file("\n".join(remote_paths)) @staticmethod def is_reference_job(file_path: str) -> bool: return os.path.splitext(file_path)[1][1:] == "reference" @staticmethod - def resolve_reference(file_path: str) -> str: + def resolve_references(file_path: str) -> List[str]: with open(file_path, "r+", encoding="utf-8") as f: # Reading from a file - return f.read() + return f.read().split("\n") + + @staticmethod + def resolve_reference(file_path: str) -> str: + refs = ReferenceFollowupJob.resolve_references(file_path) + assert len(refs) == 1 + return refs[0] -class DestinationLoadJob(LoadJob, ABC): +class DestinationLoadJob(RunnableLoadJob, ABC): def __init__( self, - table: TTableSchema, file_path: str, config: CustomDestinationClientConfiguration, - schema: Schema, destination_state: Dict[str, int], destination_callable: TDestinationCallable, skipped_columns: List[str], + callable_requires_job_client_args: bool = False, ) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - self._file_path = file_path + super().__init__(file_path) self._config = config - self._table = table - self._schema = schema - # we create pre_resolved callable here self._callable = destination_callable - self._state: TLoadJobState = "running" self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" - self.skipped_columns = skipped_columns + self._skipped_columns = skipped_columns + self._destination_state = destination_state + self._callable_requires_job_client_args = callable_requires_job_client_args + + def run(self) -> None: + # update filepath, it will be in running jobs now try: if self._config.batch_size == 0: # on batch size zero we only call the callable with the filename self.call_callable_with_items(self._file_path) else: - current_index = destination_state.get(self._storage_id, 0) - for batch in self.run(current_index): + current_index = self._destination_state.get(self._storage_id, 0) + for batch in self.get_batches(current_index): self.call_callable_with_items(batch) current_index += len(batch) - destination_state[self._storage_id] = current_index - - self._state = "completed" - except Exception as e: - self._state = "retry" - raise e + self._destination_state[self._storage_id] = current_index finally: # save progress commit_load_package_state() - @abstractmethod - def run(self, start_index: int) -> Iterable[TDataItems]: - pass - def call_callable_with_items(self, items: TDataItems) -> None: if not items: return # call callable - self._callable(items, self._table) - - def state(self) -> TLoadJobState: - return self._state + if self._callable_requires_job_client_args: + self._callable(items, self._load_table, job_client=self._job_client) # type: ignore + else: + self._callable(items, self._load_table) - def exception(self) -> str: - raise NotImplementedError() + @abstractmethod + def get_batches(self, start_index: int) -> Iterable[TDataItems]: + pass class DestinationParquetLoadJob(DestinationLoadJob): - def run(self, start_index: int) -> Iterable[TDataItems]: + def get_batches(self, start_index: int) -> Iterable[TDataItems]: # stream items from dlt.common.libs.pyarrow import pyarrow @@ -140,7 +163,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: # on record batches we cannot drop columns, we need to # select the ones we want to keep - keep_columns = list(self._table["columns"].keys()) + keep_columns = list(self._load_table["columns"].keys()) start_batch = start_index / self._config.batch_size with pyarrow.parquet.ParquetFile(self._file_path) as reader: for record_batch in reader.iter_batches( @@ -153,7 +176,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: class DestinationJsonlLoadJob(DestinationLoadJob): - def run(self, start_index: int) -> Iterable[TDataItems]: + def get_batches(self, start_index: int) -> Iterable[TDataItems]: current_batch: TDataItems = [] # stream items @@ -168,7 +191,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: start_index -= 1 continue # skip internal columns - for column in self.skipped_columns: + for column in self._skipped_columns: item.pop(column, None) current_batch.append(item) if len(current_batch) == self._config.batch_size: diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index e67be049ab..a1e38a2c20 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional, Callable, Union import yaml -from dlt.common.logger import pretty_format_exception +from dlt.common.time import ensure_pendulum_datetime from dlt.common.schema.typing import ( TTableSchema, @@ -21,8 +21,9 @@ from dlt.common.utils import uniq_id from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.destinations.exceptions import MergeDispositionException -from dlt.destinations.job_impl import NewLoadJobImpl +from dlt.destinations.job_impl import FollowupJobImpl from dlt.destinations.sql_client import SqlClientBase +from dlt.common.destination.exceptions import DestinationTransientException class SqlJobParams(TypedDict, total=False): @@ -33,10 +34,19 @@ class SqlJobParams(TypedDict, total=False): DEFAULTS: SqlJobParams = {"replace": False} -class SqlBaseJob(NewLoadJobImpl): - """Sql base job for jobs that rely on the whole tablechain""" +class SqlJobCreationException(DestinationTransientException): + def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSchema]) -> None: + tables_str = yaml.dump( + table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False + ) + super().__init__( + f"Could not create SQLFollowupJob with exception {str(original_exception)}. Table" + f" chain: {tables_str}" + ) - failed_text: str = "" + +class SqlFollowupJob(FollowupJobImpl): + """Sql base job for jobs that rely on the whole tablechain""" @classmethod def from_table_chain( @@ -44,7 +54,7 @@ def from_table_chain( table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, - ) -> NewLoadJobImpl: + ) -> FollowupJobImpl: """Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader. The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). @@ -54,6 +64,7 @@ def from_table_chain( file_info = ParsedLoadJobFileName( top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "sql" ) + try: # Remove line breaks from multiline statements and write one SQL statement per line in output file # to support clients that need to execute one statement at a time (i.e. snowflake) @@ -61,15 +72,12 @@ def from_table_chain( " ".join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client, params) ] - job = cls(file_info.file_name(), "running") + job = cls(file_info.file_name()) job._save_text_file("\n".join(sql)) - except Exception: - # return failed job - tables_str = yaml.dump( - table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False - ) - job = cls(file_info.file_name(), "failed", pretty_format_exception()) - job._save_text_file("\n".join([cls.failed_text, tables_str])) + except Exception as e: + # raise exception with some context + raise SqlJobCreationException(e, table_chain) from e + return job @classmethod @@ -82,11 +90,9 @@ def generate_sql( pass -class SqlStagingCopyJob(SqlBaseJob): +class SqlStagingCopyFollowupJob(SqlFollowupJob): """Generates a list of sql statements that copy the data from staging dataset into destination dataset.""" - failed_text: str = "Tried to generate a staging copy sql job for the following tables:" - @classmethod def _generate_clone_sql( cls, @@ -141,14 +147,12 @@ def generate_sql( return cls._generate_insert_sql(table_chain, sql_client, params) -class SqlMergeJob(SqlBaseJob): +class SqlMergeFollowupJob(SqlFollowupJob): """ Generates a list of sql statements that merge the data from staging dataset into destination dataset. If no merge keys are discovered, falls back to append. """ - failed_text: str = "Tried to generate a merge sql job for the following tables:" - @classmethod def generate_sql( # type: ignore[return] cls, @@ -717,10 +721,18 @@ def gen_scd2_sql( format_datetime_literal = ( DestinationCapabilitiesContext.generic_capabilities().format_datetime_literal ) - boundary_ts = format_datetime_literal( - current_load_package()["state"]["created_at"], + + boundary_ts = ensure_pendulum_datetime( + root_table.get( # type: ignore[arg-type] + "x-boundary-timestamp", + current_load_package()["state"]["created_at"], + ) + ) + boundary_literal = format_datetime_literal( + boundary_ts, caps.timestamp_precision, ) + active_record_timestamp = get_active_record_timestamp(root_table) if active_record_timestamp is None: active_record_literal = "NULL" @@ -733,7 +745,7 @@ def gen_scd2_sql( # retire updated and deleted records sql.append(f""" - {cls.gen_update_table_prefix(root_table_name)} {to} = {boundary_ts} + {cls.gen_update_table_prefix(root_table_name)} {to} = {boundary_literal} WHERE {is_active_clause} AND {hash_} NOT IN (SELECT {hash_} FROM {staging_root_table_name}); """) @@ -743,22 +755,22 @@ def gen_scd2_sql( col_str = ", ".join([c for c in columns if c not in (from_, to)]) sql.append(f""" INSERT INTO {root_table_name} ({col_str}, {from_}, {to}) - SELECT {col_str}, {boundary_ts} AS {from_}, {active_record_literal} AS {to} + SELECT {col_str}, {boundary_literal} AS {from_}, {active_record_literal} AS {to} FROM {staging_root_table_name} AS s - WHERE {hash_} NOT IN (SELECT {hash_} FROM {root_table_name}); + WHERE {hash_} NOT IN (SELECT {hash_} FROM {root_table_name} WHERE {is_active_clause}); """) # insert list elements for new active records in child tables child_tables = table_chain[1:] if child_tables: - unique_column = escape_column_id( - cls._get_unique_col(table_chain, sql_client, root_table) - ) # TODO: - based on deterministic child hashes (OK) # - if row hash changes all is right # - if it does not we only capture new records, while we should replace existing with those in stage # - this write disposition is way more similar to regular merge (how root tables are handled is different, other tables handled same) for table in child_tables: + unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, table) + ) table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"]) sql.append(f""" INSERT INTO {table_name} diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index dce375afb0..67a6b3e83a 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -26,6 +26,7 @@ new_table, ) from dlt.common.typing import TDataItem +from dlt.common.time import ensure_pendulum_datetime from dlt.common.utils import clone_dict_nested from dlt.common.normalizers.json.relational import DataItemNormalizer from dlt.common.validation import validate_dict_ignoring_xkeys @@ -444,6 +445,8 @@ def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: mddict: TMergeDispositionDict = deepcopy(dict_["write_disposition"]) if mddict is not None: dict_["x-merge-strategy"] = mddict.get("strategy", DEFAULT_MERGE_STRATEGY) + if "boundary_timestamp" in mddict: + dict_["x-boundary-timestamp"] = mddict["boundary_timestamp"] # add columns for `scd2` merge strategy if dict_.get("x-merge-strategy") == "scd2": if mddict.get("validity_column_names") is None: @@ -465,11 +468,16 @@ def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: "x-valid-to": True, "x-active-record-timestamp": mddict.get("active_record_timestamp"), } + # unique constraint is dropped for C_DLT_ID when used to store + # SCD2 row hash (only applies to root table) hash_ = mddict.get("row_version_column_name", DataItemNormalizer.C_DLT_ID) dict_["columns"][hash_] = { "name": hash_, "nullable": False, "x-row-version": True, + # duplicate value in row hash column is possible in case + # of insert-delete-reinsert pattern + "unique": False, } @staticmethod @@ -507,3 +515,14 @@ def validate_write_disposition_hint(wd: TTableHintTemplate[TWriteDispositionConf f'`{wd["strategy"]}` is not a valid merge strategy. ' f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""" ) + + for ts in ("active_record_timestamp", "boundary_timestamp"): + if ts == "active_record_timestamp" and wd.get("active_record_timestamp") is None: + continue # None is allowed for active_record_timestamp + if ts in wd: + try: + ensure_pendulum_datetime(wd[ts]) # type: ignore[literal-required] + except Exception: + raise ValueError( + f'could not parse `{ts}` value "{wd[ts]}"' # type: ignore[literal-required] + ) diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index 266581c785..aa1c60901e 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -156,11 +156,13 @@ def _run_dbt_command( i = iter_stdout_with_result(self.venv, "python", "-c", script) while True: sys.stdout.write(next(i).strip()) + sys.stdout.write("\n") except StopIteration as si: # return result from generator return si.value # type: ignore except CalledProcessError as cpe: sys.stderr.write(cpe.stderr) + sys.stdout.write("\n") raise def run( diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index e85dffd2e9..14d0eb1b23 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -5,7 +5,12 @@ ) -class LoadClientJobFailed(DestinationTerminalException): +class LoadClientJobException(Exception): + load_id: str + job_id: str + + +class LoadClientJobFailed(DestinationTerminalException, LoadClientJobException): def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: self.load_id = load_id self.job_id = job_id @@ -16,15 +21,19 @@ def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: ) -class LoadClientJobRetry(DestinationTransientException): - def __init__(self, load_id: str, job_id: str, retry_count: int, max_retry_count: int) -> None: +class LoadClientJobRetry(DestinationTransientException, LoadClientJobException): + def __init__( + self, load_id: str, job_id: str, retry_count: int, max_retry_count: int, retry_message: str + ) -> None: self.load_id = load_id self.job_id = job_id self.retry_count = retry_count self.max_retry_count = max_retry_count + self.retry_message = retry_message super().__init__( f"Job for {job_id} had {retry_count} retries which a multiple of {max_retry_count}." " Exiting retry loop. You can still rerun the load package to retry this job." + f" Last failure message was {retry_message}" ) @@ -50,3 +59,18 @@ def __init__(self, table_name: str, write_disposition: str, file_name: str) -> N f"Loader does not support {write_disposition} in table {table_name} when loading file" f" {file_name}" ) + + +class FollowupJobCreationFailedException(DestinationTransientException): + def __init__(self, job_id: str) -> None: + self.job_id = job_id + super().__init__(f"Failed to create followup job for job with id {job_id}") + + +class TableChainFollowupJobCreationFailedException(DestinationTransientException): + def __init__(self, root_table_name: str) -> None: + self.root_table_name = root_table_name + super().__init__( + "Failed creating table chain followup jobs for table chain with root table" + f" {root_table_name}." + ) diff --git a/dlt/load/load.py b/dlt/load/load.py index 2290d40a1e..99a12d69ee 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -18,18 +18,18 @@ from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR from dlt.common.logger import pretty_format_exception -from dlt.common.exceptions import TerminalValueError from dlt.common.configuration.container import Container from dlt.common.schema import Schema from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, - FollowupJob, + HasFollowupJobs, JobClientBase, WithStagingDataset, Destination, + RunnableLoadJob, LoadJob, - NewLoadJob, + FollowupJob, TLoadJobState, DestinationClientConfiguration, SupportsStagingDestination, @@ -37,10 +37,10 @@ ) from dlt.common.destination.exceptions import ( DestinationTerminalException, - DestinationTransientException, ) +from dlt.common.runtime import signals -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.load.configuration import LoaderConfiguration from dlt.load.exceptions import ( @@ -48,12 +48,16 @@ LoadClientJobRetry, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, + LoadClientJobException, + FollowupJobCreationFailedException, + TableChainFollowupJobCreationFailedException, ) from dlt.load.utils import ( _extend_tables_with_table_chain, get_completed_table_chain, init_client, filter_new_jobs, + get_available_worker_slots, ) @@ -80,6 +84,9 @@ def __init__( self.pool = NullExecutor() self.load_storage: LoadStorage = self.create_storage(is_storage_owner) self._loaded_packages: List[LoadPackageInfo] = [] + self._run_loop_sleep_duration: float = ( + 1.0 # amount of time to sleep between querying completed jobs + ) super().__init__() def create_storage(self, is_storage_owner: bool) -> LoadStorage: @@ -108,10 +115,13 @@ def get_staging_destination_client(self, schema: Schema) -> JobClientBase: return self.staging_destination.client(schema, self.initial_staging_client_config) def is_staging_destination_job(self, file_path: str) -> bool: + file_type = os.path.splitext(file_path)[1][1:] + # for now we know that reference jobs always go do the main destination + if file_type == "reference": + return False return ( self.staging_destination is not None - and os.path.splitext(file_path)[1][1:] - in self.staging_destination.capabilities().supported_loader_file_formats + and file_type in self.staging_destination.capabilities().supported_loader_file_formats ) @contextlib.contextmanager @@ -125,94 +135,150 @@ def maybe_with_staging_dataset( else: yield - @staticmethod - @workermethod - def w_spool_job( - self: "Load", file_path: str, load_id: str, schema: Schema - ) -> Optional[LoadJob]: + def submit_job( + self, file_path: str, load_id: str, schema: Schema, restore: bool = False + ) -> LoadJob: job: LoadJob = None + + is_staging_destination_job = self.is_staging_destination_job(file_path) + job_client = self.get_destination_client(schema) + + # if we have a staging destination and the file is not a reference, send to staging + active_job_client = ( + self.get_staging_destination_client(schema) + if is_staging_destination_job + else job_client + ) + try: - is_staging_destination_job = self.is_staging_destination_job(file_path) - job_client = self.get_destination_client(schema) - - # if we have a staging destination and the file is not a reference, send to staging - with ( - self.get_staging_destination_client(schema) - if is_staging_destination_job - else job_client - ) as client: - job_info = ParsedLoadJobFileName.parse(file_path) - if job_info.file_format not in self.load_storage.supported_job_file_formats: - raise LoadClientUnsupportedFileFormats( - job_info.file_format, - self.destination.capabilities().supported_loader_file_formats, - file_path, - ) - logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - table = client.prepare_load_table(job_info.table_name) - if table["write_disposition"] not in ["append", "replace", "merge"]: - raise LoadClientUnsupportedWriteDisposition( - job_info.table_name, table["write_disposition"], file_path - ) + # check file format + job_info = ParsedLoadJobFileName.parse(file_path) + if job_info.file_format not in self.load_storage.supported_job_file_formats: + raise LoadClientUnsupportedFileFormats( + job_info.file_format, + self.destination.capabilities().supported_loader_file_formats, + file_path, + ) + logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - if is_staging_destination_job: - use_staging_dataset = isinstance( - job_client, SupportsStagingDestination - ) and job_client.should_load_data_to_staging_dataset_on_staging_destination( - table - ) - else: - use_staging_dataset = isinstance( - job_client, WithStagingDataset - ) and job_client.should_load_data_to_staging_dataset(table) - - with self.maybe_with_staging_dataset(client, use_staging_dataset): - job = client.start_file_load( - table, - self.load_storage.normalized_packages.storage.make_full_path(file_path), - load_id, - ) - except (DestinationTerminalException, TerminalValueError): - # if job irreversibly cannot be started, mark it as failed - logger.exception(f"Terminal problem when adding job {file_path}") - job = EmptyLoadJob.from_file_path(file_path, "failed", pretty_format_exception()) - except (DestinationTransientException, Exception): - # return no job so file stays in new jobs (root) folder - logger.exception(f"Temporary problem when adding job {file_path}") - job = EmptyLoadJob.from_file_path(file_path, "retry", pretty_format_exception()) - if job is None: - raise DestinationTerminalException( - f"Destination could not create a job for file {file_path}. Typically the file" - " extension could not be associated with job type and that indicates an error in" - " the code." + # check write disposition + load_table = active_job_client.prepare_load_table(job_info.table_name) + if load_table["write_disposition"] not in ["append", "replace", "merge"]: + raise LoadClientUnsupportedWriteDisposition( + job_info.table_name, load_table["write_disposition"], file_path + ) + + job = active_job_client.create_load_job( + load_table, + self.load_storage.normalized_packages.storage.make_full_path(file_path), + load_id, + restore=restore, ) - self.load_storage.normalized_packages.start_job(load_id, job.file_name()) + + if job is None: + raise DestinationTerminalException( + f"Destination could not create a job for file {file_path}. Typically the file" + " extension could not be associated with job type and that indicates an error" + " in the code." + ) + except DestinationTerminalException: + job = FinalizedLoadJobWithFollowupJobs.from_file_path( + file_path, "failed", pretty_format_exception() + ) + except Exception: + job = FinalizedLoadJobWithFollowupJobs.from_file_path( + file_path, "retry", pretty_format_exception() + ) + + # move to started jobs in case this is not a restored job + if not restore: + job._file_path = self.load_storage.normalized_packages.start_job( + load_id, job.file_name() + ) + + # only start a thread if this job is runnable + if isinstance(job, RunnableLoadJob): + # determine which dataset to use + if is_staging_destination_job: + use_staging_dataset = isinstance( + job_client, SupportsStagingDestination + ) and job_client.should_load_data_to_staging_dataset_on_staging_destination( + load_table + ) + else: + use_staging_dataset = isinstance( + job_client, WithStagingDataset + ) and job_client.should_load_data_to_staging_dataset(load_table) + + # set job vars + job.set_run_vars(load_id=load_id, schema=schema, load_table=load_table) + + # submit to pool + self.pool.submit(Load.w_run_job, *(id(self), job, is_staging_destination_job, use_staging_dataset, schema)) # type: ignore + + # sanity check: otherwise a job in an actionable state is expected + else: + assert job.state() in ("completed", "failed", "retry") + return job - def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJob]]: - # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs + @staticmethod + @workermethod + def w_run_job( + self: "Load", + job: RunnableLoadJob, + use_staging_client: bool, + use_staging_dataset: bool, + schema: Schema, + ) -> None: + """ + Start a load job in a separate thread + """ + active_job_client = ( + self.get_staging_destination_client(schema) + if use_staging_client + else self.get_destination_client(schema) + ) + with active_job_client as client: + with self.maybe_with_staging_dataset(client, use_staging_dataset): + job.run_managed(active_job_client) + + def start_new_jobs( + self, load_id: str, schema: Schema, running_jobs: Sequence[LoadJob] + ) -> Sequence[LoadJob]: + """ + will retrieve jobs from the new_jobs folder and start as many as there are slots available + """ + caps = self.destination.capabilities( + self.destination.configuration(self.initial_client_config) + ) + + # early exit if no slots available + available_slots = get_available_worker_slots(self.config, caps, running_jobs) + if available_slots <= 0: + return [] + + # get a list of jobs eligible to be started load_files = filter_new_jobs( self.load_storage.list_new_jobs(load_id), - self.destination.capabilities( - self.destination.configuration(self.initial_client_config) - ), + caps, self.config, + running_jobs, + available_slots, ) - file_count = len(load_files) - if file_count == 0: - logger.info(f"No new jobs found in {load_id}") - return 0, [] - logger.info(f"Will load {file_count}, creating jobs") - param_chunk = [(id(self), file, load_id, schema) for file in load_files] - # exceptions should not be raised, None as job is a temporary failure - # other jobs should not be affected - jobs = self.pool.map(Load.w_spool_job, *zip(*param_chunk)) - # remove None jobs and check the rest - return file_count, [job for job in jobs if job is not None] - - def retrieve_jobs( - self, client: JobClientBase, load_id: str, staging_client: JobClientBase = None - ) -> Tuple[int, List[LoadJob]]: + + logger.info(f"Will load additional {len(load_files)}, creating jobs") + started_jobs: List[LoadJob] = [] + for file in load_files: + job = self.submit_job(file, load_id, schema) + started_jobs.append(job) + + return started_jobs + + def resume_started_jobs(self, load_id: str, schema: Schema) -> List[LoadJob]: + """ + will check jobs in the started folder and resume them + """ jobs: List[LoadJob] = [] # list all files that were started but not yet completed @@ -220,23 +286,13 @@ def retrieve_jobs( logger.info(f"Found {len(started_jobs)} that are already started and should be continued") if len(started_jobs) == 0: - return 0, jobs + return jobs for file_path in started_jobs: - try: - logger.info(f"Will retrieve {file_path}") - client = staging_client if self.is_staging_destination_job(file_path) else client - job = client.restore_file_load(file_path) - except DestinationTerminalException: - logger.exception(f"Job retrieval for {file_path} failed, job will be terminated") - job = EmptyLoadJob.from_file_path(file_path, "failed", pretty_format_exception()) - # proceed to appending job, do not reraise - except (DestinationTransientException, Exception): - # raise on all temporary exceptions, typically network / server problems - raise + job = self.submit_job(file_path, load_id, schema, restore=True) jobs.append(job) - return len(jobs), jobs + return jobs def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: return [ @@ -246,9 +302,14 @@ def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: def create_followup_jobs( self, load_id: str, state: TLoadJobState, starting_job: LoadJob, schema: Schema - ) -> List[NewLoadJob]: - jobs: List[NewLoadJob] = [] - if isinstance(starting_job, FollowupJob): + ) -> None: + """ + for jobs marked as having followup jobs, find them all and store them to the new jobs folder + where they will be picked up for execution + """ + + jobs: List[FollowupJob] = [] + if isinstance(starting_job, HasFollowupJobs): # check for merge jobs only for jobs executing on the destination, the staging destination jobs must be excluded # NOTE: we may move that logic to the interface starting_job_file_name = starting_job.file_name() @@ -257,7 +318,7 @@ def create_followup_jobs( top_job_table = get_top_level_table( schema.tables, starting_job.job_file_info().table_name ) - # if all tables of chain completed, create follow up jobs + # if all tables of chain completed, create follow up jobs all_jobs_states = self.load_storage.normalized_packages.list_all_jobs_with_states( load_id ) @@ -265,60 +326,71 @@ def create_followup_jobs( schema, all_jobs_states, top_job_table, starting_job.job_file_info().job_id() ): table_chain_names = [table["name"] for table in table_chain] - # create job infos that contain full path to job table_chain_jobs = [ - self.load_storage.normalized_packages.job_to_job_info(load_id, *job_state) + # we mark all jobs as completed, as by the time the followup job runs the starting job will be in this + # folder too + self.load_storage.normalized_packages.job_to_job_info( + load_id, "completed_jobs", job_state[1] + ) for job_state in all_jobs_states if job_state[1].table_name in table_chain_names # job being completed is still in started_jobs and job_state[0] in ("completed_jobs", "started_jobs") ] - if follow_up_jobs := client.create_table_chain_completed_followup_jobs( - table_chain, table_chain_jobs - ): - jobs = jobs + follow_up_jobs - jobs = jobs + starting_job.create_followup_jobs(state) - return jobs + try: + if follow_up_jobs := client.create_table_chain_completed_followup_jobs( + table_chain, table_chain_jobs + ): + jobs = jobs + follow_up_jobs + except Exception as e: + raise TableChainFollowupJobCreationFailedException( + root_table_name=table_chain[0]["name"] + ) from e - def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> List[LoadJob]: + try: + jobs = jobs + starting_job.create_followup_jobs(state) + except Exception as e: + raise FollowupJobCreationFailedException(job_id=starting_job.job_id()) from e + + # import all followup jobs to the new jobs folder + for followup_job in jobs: + # save all created jobs + self.load_storage.normalized_packages.import_job( + load_id, followup_job.new_file_path(), job_state="new_jobs" + ) + logger.info( + f"Job {starting_job.job_id()} CREATED a new FOLLOWUP JOB" + f" {followup_job.new_file_path()} placed in new_jobs" + ) + + def complete_jobs( + self, load_id: str, jobs: Sequence[LoadJob], schema: Schema + ) -> Tuple[List[LoadJob], List[LoadJob], Optional[LoadClientJobException]]: """Run periodically in the main thread to collect job execution statuses. After detecting change of status, it commits the job state by moving it to the right folder May create one or more followup jobs that get scheduled as new jobs. New jobs are created only in terminal states (completed / failed) """ + # list of jobs still running remaining_jobs: List[LoadJob] = [] - - def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: - for followup_job in followup_jobs: - # running should be moved into "new jobs", other statuses into started - folder: TJobState = ( - "new_jobs" if followup_job.state() == "running" else "started_jobs" - ) - # save all created jobs - self.load_storage.normalized_packages.import_job( - load_id, followup_job.new_file_path(), job_state=folder - ) - logger.info( - f"Job {job.job_id()} CREATED a new FOLLOWUP JOB" - f" {followup_job.new_file_path()} placed in {folder}" - ) - # if followup job is not "running" place it in current queue to be finalized - if not followup_job.state() == "running": - remaining_jobs.append(followup_job) + # list of jobs in final state + finalized_jobs: List[LoadJob] = [] + # if an exception condition was met, return it to the main runner + pending_exception: Optional[LoadClientJobException] = None logger.info(f"Will complete {len(jobs)} for {load_id}") for ii in range(len(jobs)): job = jobs[ii] logger.debug(f"Checking state for job {job.job_id()}") state: TLoadJobState = job.state() - if state == "running": + if state in ("ready", "running"): # ask again logger.debug(f"job {job.job_id()} still running") remaining_jobs.append(job) elif state == "failed": # create followup jobs - _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + self.create_followup_jobs(load_id, state, job, schema) # try to get exception message from job failed_message = job.exception() @@ -329,6 +401,14 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: f"Job for {job.job_id()} failed terminally in load {load_id} with message" f" {failed_message}" ) + # schedule exception on job failure + if self.config.raise_on_failed_jobs: + pending_exception = LoadClientJobFailed( + load_id, + job.job_file_info().job_id(), + failed_message, + ) + finalized_jobs.append(job) elif state == "retry": # try to get exception message from job retry_message = job.exception() @@ -337,13 +417,27 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: logger.warning( f"Job for {job.job_id()} retried in load {load_id} with message {retry_message}" ) + # possibly schedule exception on too many retries + if self.config.raise_on_max_retries: + r_c = job.job_file_info().retry_count + 1 + if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: + pending_exception = LoadClientJobRetry( + load_id, + job.job_file_info().job_id(), + r_c, + self.config.raise_on_max_retries, + retry_message=retry_message, + ) elif state == "completed": # create followup jobs - _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + self.create_followup_jobs(load_id, state, job, schema) # move to completed folder after followup jobs are created # in case of exception when creating followup job, the loader will retry operation and try to complete again self.load_storage.normalized_packages.complete_job(load_id, job.file_name()) logger.info(f"Job for {job.job_id()} completed in load {load_id}") + finalized_jobs.append(job) + else: + raise Exception("Incorrect job state") if state in ["failed", "completed"]: self.collector.update("Jobs") @@ -352,7 +446,7 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: "Jobs", 1, message="WARNING: Some of the jobs failed!", label="Failed" ) - return remaining_jobs + return remaining_jobs, finalized_jobs, pending_exception def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) -> None: # do not commit load id for aborted packages @@ -377,6 +471,18 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) + def update_load_package_info(self, load_id: str) -> None: + # update counter we only care about the jobs that are scheduled to be loaded + package_jobs = self.load_storage.normalized_packages.get_load_package_jobs(load_id) + total_jobs = reduce(lambda p, c: p + len(c), package_jobs.values(), 0) + no_failed_jobs = len(package_jobs["failed_jobs"]) + no_completed_jobs = len(package_jobs["completed_jobs"]) + no_failed_jobs + self.collector.update("Jobs", no_completed_jobs, total_jobs) + if no_failed_jobs > 0: + self.collector.update( + "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" + ) + def load_single_package(self, load_id: str, schema: Schema) -> None: new_jobs = self.get_new_jobs_info(load_id) @@ -386,6 +492,8 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: dropped_tables = current_load_package()["state"].get("dropped_tables", []) truncated_tables = current_load_package()["state"].get("truncated_tables", []) + self.update_load_package_info(load_id) + # initialize analytical storage ie. create dataset required by passed schema with self.get_destination_client(schema) as job_client: if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None: @@ -424,74 +532,54 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: drop_tables=dropped_tables, truncate_tables=truncated_tables, ) - self.load_storage.commit_schema_update(load_id, applied_update) - # initialize staging destination and spool or retrieve unfinished jobs - if self.staging_destination: - with self.get_staging_destination_client(schema) as staging_client: - jobs_count, jobs = self.retrieve_jobs(job_client, load_id, staging_client) - else: - jobs_count, jobs = self.retrieve_jobs(job_client, load_id) - - if not jobs: - # jobs count is a total number of jobs including those that could not be initialized - jobs_count, jobs = self.spool_new_jobs(load_id, schema) - # if there are no existing or new jobs we complete the package - if jobs_count == 0: - self.complete_package(load_id, schema, False) - return - # update counter we only care about the jobs that are scheduled to be loaded - package_jobs = self.load_storage.normalized_packages.get_load_package_jobs(load_id) - total_jobs = reduce(lambda p, c: p + len(c), package_jobs.values(), 0) - no_failed_jobs = len(package_jobs["failed_jobs"]) - no_completed_jobs = len(package_jobs["completed_jobs"]) + no_failed_jobs - self.collector.update("Jobs", no_completed_jobs, total_jobs) - if no_failed_jobs > 0: - self.collector.update( - "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" - ) + # collect all unfinished jobs + running_jobs: List[LoadJob] = self.resume_started_jobs(load_id, schema) + # loop until all jobs are processed + pending_exception: Optional[LoadClientJobException] = None while True: try: - remaining_jobs = self.complete_jobs(load_id, jobs, schema) - if len(remaining_jobs) == 0: - # get package status - package_jobs = self.load_storage.normalized_packages.get_load_package_jobs( - load_id - ) - # possibly raise on failed jobs - if self.config.raise_on_failed_jobs: - if package_jobs["failed_jobs"]: - failed_job = package_jobs["failed_jobs"][0] - raise LoadClientJobFailed( - load_id, - failed_job.job_id(), - self.load_storage.normalized_packages.get_job_failed_message( - load_id, failed_job - ), - ) - # possibly raise on too many retries - if self.config.raise_on_max_retries: - for new_job in package_jobs["new_jobs"]: - r_c = new_job.retry_count - if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: - raise LoadClientJobRetry( - load_id, - new_job.job_id(), - r_c, - self.config.raise_on_max_retries, - ) + # we continuously spool new jobs and complete finished ones + running_jobs, finalized_jobs, new_pending_exception = self.complete_jobs( + load_id, running_jobs, schema + ) + pending_exception = pending_exception or new_pending_exception + + # do not spool new jobs if there was a signal or an exception was encountered + # we inform the users how many jobs remain when shutting down, but only if the count of running jobs + # has changed (as determined by finalized jobs) + if signals.signal_received(): + if finalized_jobs: + logger.info( + f"Signal received, draining running jobs. {len(running_jobs)} to go." + ) + elif pending_exception: + if finalized_jobs: + logger.info( + f"Exception for job {pending_exception.job_id} received, draining" + f" running jobs.{len(running_jobs)} to go." + ) + else: + running_jobs += self.start_new_jobs(load_id, schema, running_jobs) + + if len(running_jobs) == 0: + # if a pending exception was discovered during completion of jobs + # we can raise it now + if pending_exception: + raise pending_exception break - # process remaining jobs again - jobs = remaining_jobs # this will raise on signal - sleep(1) + sleep(self._run_loop_sleep_duration) except LoadClientJobFailed: # the package is completed and skipped self.complete_package(load_id, schema, True) raise + # no new jobs, load package done + self.complete_package(load_id, schema, False) + def run(self, pool: Optional[Executor]) -> TRunMetrics: # store pool self.pool = pool or NullExecutor() diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 67a813f5f2..9750f89d4b 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -12,10 +12,7 @@ from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import TTableSchema -from dlt.common.destination.reference import ( - JobClientBase, - WithStagingDataset, -) +from dlt.common.destination.reference import JobClientBase, WithStagingDataset, LoadJob from dlt.load.configuration import LoaderConfiguration from dlt.common.destination import DestinationCapabilitiesContext @@ -230,10 +227,30 @@ def _extend_tables_with_table_chain( return result +def get_available_worker_slots( + config: LoaderConfiguration, + capabilities: DestinationCapabilitiesContext, + running_jobs: Sequence[LoadJob], +) -> int: + """ + Returns the number of available worker slots + """ + parallelism_strategy = config.parallelism_strategy or capabilities.loader_parallelism_strategy + + # find real max workers value + max_workers = 1 if parallelism_strategy == "sequential" else config.workers + if mp := capabilities.max_parallel_load_jobs: + max_workers = min(max_workers, mp) + + return max(0, max_workers - len(running_jobs)) + + def filter_new_jobs( file_names: Sequence[str], capabilities: DestinationCapabilitiesContext, config: LoaderConfiguration, + running_jobs: Sequence[LoadJob], + available_slots: int, ) -> Sequence[str]: """Filters the list of new jobs to adhere to max_workers and parallellism strategy""" """NOTE: in the current setup we only filter based on settings for the final destination""" @@ -246,24 +263,27 @@ def filter_new_jobs( # config can overwrite destination settings, if nothing is set, code below defaults to parallel parallelism_strategy = config.parallelism_strategy or capabilities.loader_parallelism_strategy - # find real max workers value - max_workers = 1 if parallelism_strategy == "sequential" else config.workers - if mp := capabilities.max_parallel_load_jobs: - max_workers = min(max_workers, mp) - # regular sequential works on all jobs eligible_jobs = file_names # we must ensure there only is one job per table if parallelism_strategy == "table-sequential": - eligible_jobs = sorted( - eligible_jobs, key=lambda j: ParsedLoadJobFileName.parse(j).table_name - ) - eligible_jobs = [ - next(table_jobs) - for _, table_jobs in groupby( - eligible_jobs, lambda j: ParsedLoadJobFileName.parse(j).table_name - ) - ] + # TODO later: this whole code block is a bit inefficient for long lists of jobs + # better would be to keep a list of loadjobinfos in the loader which we can iterate + + # find table names of all currently running jobs + running_tables = {j._parsed_file_name.table_name for j in running_jobs} + new_jobs: List[str] = [] + + for job in eligible_jobs: + if (table_name := ParsedLoadJobFileName.parse(job).table_name) not in running_tables: + running_tables.add(table_name) + new_jobs.append(job) + # exit loop if we have enough + if len(new_jobs) >= available_slots: + break + + return new_jobs - return eligible_jobs[:max_workers] + else: + return eligible_jobs[:available_slots] diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 98154cd5cf..e80931605c 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -34,6 +34,7 @@ from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV +from dlt.normalize.schema import verify_normalized_schema # normalize worker wrapping function signature @@ -195,6 +196,7 @@ def spool_files( x_normalizer["seen-data"] = True # schema is updated, save it to schema volume if schema.is_modified: + verify_normalized_schema(schema) logger.info( f"Saving schema {schema.name} with version {schema.stored_version}:{schema.version}" ) diff --git a/dlt/normalize/schema.py b/dlt/normalize/schema.py new file mode 100644 index 0000000000..c01d184c92 --- /dev/null +++ b/dlt/normalize/schema.py @@ -0,0 +1,20 @@ +from dlt.common.schema import Schema +from dlt.common.schema.utils import find_incomplete_columns +from dlt.common.schema.exceptions import UnboundColumnException +from dlt.common import logger + + +def verify_normalized_schema(schema: Schema) -> None: + """Verify the schema is valid for next stage after normalization. + + 1. Log warning if any incomplete nullable columns are in any data tables + 2. Raise `UnboundColumnException` on incomplete non-nullable columns (e.g. missing merge/primary key) + """ + for table_name, column, nullable in find_incomplete_columns( + schema.data_tables(seen_data_only=True) + ): + exc = UnboundColumnException(schema.name, table_name, column) + if nullable: + logger.warning(str(exc)) + else: + raise exc diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 73ae064299..c05dabc30c 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -225,7 +225,7 @@ def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: if paginator is None: paginator = self.detect_paginator(response, data) - paginator.update_state(response) + paginator.update_state(response, data) paginator.update_request(request) # yield data with context diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index 4c8ce70bb2..872d4f34e8 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -1,6 +1,6 @@ import warnings from abc import ABC, abstractmethod -from typing import Optional, Dict, Any +from typing import Any, Dict, List, Optional from urllib.parse import urlparse, urljoin from requests import Response, Request @@ -39,7 +39,7 @@ def init_request(self, request: Request) -> None: # noqa: B027, optional overri pass @abstractmethod - def update_state(self, response: Response) -> None: + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: """Updates the paginator's state based on the response from the API. This method should extract necessary pagination details (like next page @@ -73,7 +73,7 @@ def __str__(self) -> str: class SinglePagePaginator(BasePaginator): """A paginator for single-page API responses.""" - def update_state(self, response: Response) -> None: + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: self._has_next_page = False def update_request(self, request: Request) -> None: @@ -96,6 +96,7 @@ def __init__( maximum_value: Optional[int] = None, total_path: Optional[jsonpath.TJsonPath] = None, error_message_items: str = "items", + stop_after_empty_page: Optional[bool] = True, ): """ Args: @@ -116,10 +117,15 @@ def __init__( If not provided, `maximum_value` must be specified. error_message_items (str): The name of the items in the error message. Defaults to 'items'. + stop_after_empty_page (bool): Whether pagination should stop when + a page contains no result items. Defaults to `True`. """ super().__init__() - if total_path is None and maximum_value is None: - raise ValueError("Either `total_path` or `maximum_value` must be provided.") + if total_path is None and maximum_value is None and not stop_after_empty_page: + raise ValueError( + "Either `total_path` or `maximum_value` or `stop_after_empty_page` must be" + " provided." + ) self.param_name = param_name self.current_value = initial_value self.value_step = value_step @@ -127,6 +133,7 @@ def __init__( self.maximum_value = maximum_value self.total_path = jsonpath.compile_path(total_path) if total_path else None self.error_message_items = error_message_items + self.stop_after_empty_page = stop_after_empty_page def init_request(self, request: Request) -> None: if request.params is None: @@ -134,26 +141,32 @@ def init_request(self, request: Request) -> None: request.params[self.param_name] = self.current_value - def update_state(self, response: Response) -> None: - total = None - if self.total_path: - response_json = response.json() - values = jsonpath.find_values(self.total_path, response_json) - total = values[0] if values else None - if total is None: - self._handle_missing_total(response_json) - - try: - total = int(total) - except ValueError: - self._handle_invalid_total(total) - - self.current_value += self.value_step - - if (total is not None and self.current_value >= total + self.base_index) or ( - self.maximum_value is not None and self.current_value >= self.maximum_value - ): + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: + if self._stop_after_this_page(data): self._has_next_page = False + else: + total = None + if self.total_path: + response_json = response.json() + values = jsonpath.find_values(self.total_path, response_json) + total = values[0] if values else None + if total is None: + self._handle_missing_total(response_json) + + try: + total = int(total) + except ValueError: + self._handle_invalid_total(total) + + self.current_value += self.value_step + + if (total is not None and self.current_value >= total + self.base_index) or ( + self.maximum_value is not None and self.current_value >= self.maximum_value + ): + self._has_next_page = False + + def _stop_after_this_page(self, data: Optional[List[Any]] = None) -> bool: + return self.stop_after_empty_page and not data def _handle_missing_total(self, response_json: Dict[str, Any]) -> None: raise ValueError( @@ -229,6 +242,7 @@ def __init__( page_param: str = "page", total_path: jsonpath.TJsonPath = "total", maximum_page: Optional[int] = None, + stop_after_empty_page: Optional[bool] = True, ): """ Args: @@ -246,9 +260,13 @@ def __init__( will stop once this page is reached or exceeded, even if more data is available. This allows you to limit the maximum number of pages for pagination. Defaults to None. + stop_after_empty_page (bool): Whether pagination should stop when + a page contains no result items. Defaults to `True`. """ - if total_path is None and maximum_page is None: - raise ValueError("Either `total_path` or `maximum_page` must be provided.") + if total_path is None and maximum_page is None and not stop_after_empty_page: + raise ValueError( + "Either `total_path` or `maximum_page` or `stop_after_empty_page` must be provided." + ) page = page if page is not None else base_page @@ -260,6 +278,7 @@ def __init__( value_step=1, maximum_value=maximum_page, error_message_items="pages", + stop_after_empty_page=stop_after_empty_page, ) def __str__(self) -> str: @@ -330,6 +349,7 @@ def __init__( limit_param: str = "limit", total_path: jsonpath.TJsonPath = "total", maximum_offset: Optional[int] = None, + stop_after_empty_page: Optional[bool] = True, ) -> None: """ Args: @@ -347,15 +367,21 @@ def __init__( pagination will stop once this offset is reached or exceeded, even if more data is available. This allows you to limit the maximum range for pagination. Defaults to None. + stop_after_empty_page (bool): Whether pagination should stop when + a page contains no result items. Defaults to `True`. """ - if total_path is None and maximum_offset is None: - raise ValueError("Either `total_path` or `maximum_offset` must be provided.") + if total_path is None and maximum_offset is None and not stop_after_empty_page: + raise ValueError( + "Either `total_path` or `maximum_offset` or `stop_after_empty_page` must be" + " provided." + ) super().__init__( param_name=offset_param, initial_value=offset, total_path=total_path, value_step=limit, maximum_value=maximum_offset, + stop_after_empty_page=stop_after_empty_page, ) self.limit_param = limit_param self.limit = limit @@ -484,7 +510,7 @@ def __init__(self, links_next_key: str = "next") -> None: super().__init__() self.links_next_key = links_next_key - def update_state(self, response: Response) -> None: + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: """Extracts the next page URL from the 'Link' header in the response.""" self._next_reference = response.links.get(self.links_next_key, {}).get("url") @@ -539,7 +565,7 @@ def __init__( super().__init__() self.next_url_path = jsonpath.compile_path(next_url_path) - def update_state(self, response: Response) -> None: + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: """Extracts the next page URL from the JSON response.""" values = jsonpath.find_values(self.next_url_path, response.json()) self._next_reference = values[0] if values else None @@ -618,7 +644,7 @@ def __init__( self.cursor_path = jsonpath.compile_path(cursor_path) self.cursor_param = cursor_param - def update_state(self, response: Response) -> None: + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: """Extracts the cursor value from the JSON response.""" values = jsonpath.find_values(self.cursor_path, response.json()) self._next_reference = values[0] if values else None diff --git a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py index ce4b2a12d0..48a16f15c0 100644 --- a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py +++ b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py @@ -8,7 +8,7 @@ In this example, you'll find a Python script that demonstrates how to load to BigQuery with the custom destination. We'll learn how to: -- Use [built-in credentials.](../general-usage/credentials/config_specs#gcp-credentials) +- Use [built-in credentials.](../general-usage/credentials/complex_types#gcp-credentials) - Use the [custom destination.](../dlt-ecosystem/destinations/destination.md) - Use pyarrow tables to create complex column types on BigQuery. - Use BigQuery `autodetect=True` for schema inference from parquet files. diff --git a/docs/examples/google_sheets/google_sheets.py b/docs/examples/google_sheets/google_sheets.py index 82e5d2e76d..fbc0686fb9 100644 --- a/docs/examples/google_sheets/google_sheets.py +++ b/docs/examples/google_sheets/google_sheets.py @@ -8,8 +8,8 @@ In this example, you'll find a Python script that demonstrates how to load Google Sheets data using the `dlt` library. We'll learn how to: -- use [built-in credentials](../general-usage/credentials/config_specs#gcp-credentials); -- use [union of credentials](../general-usage/credentials/config_specs#working-with-alternatives-of-credentials-union-types); +- use [built-in credentials](../general-usage/credentials/complex_types#gcp-credentials); +- use [union of credentials](../general-usage/credentials/complex_types#working-with-alternatives-of-credentials-union-types); - create [dynamically generated resources](../general-usage/source#create-resources-dynamically). :::tip diff --git a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md index 51d124251a..334e08c4a7 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md +++ b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md @@ -112,6 +112,18 @@ VMs available on GCP (cloud functions, Composer runners, Colab notebooks) have a location = "US" ``` +### Using Different `project_id` + +You can set the `project_id` in your configuration to be different from the one in your credentials, provided your account has access to it: +```toml +[destination.bigquery] +project_id = "project_id_destination" + +[destination.bigquery.credentials] +project_id = "project_id_credentials" +``` +In this scenario, `project_id_credentials` will be used for authentication, while `project_id_destination` will be used as the data destination. + ## Write Disposition All write dispositions are supported. diff --git a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md index 9ecd1ae6dc..19cef92f9d 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md @@ -63,7 +63,7 @@ You can configure the following file formats to load data to duckdb: :::note `duckdb` cannot COPY many parquet files to a single table from multiple threads. In this situation, `dlt` serializes the loads. Still, that may be faster than INSERT. ::: -* [jsonl](../file-formats/jsonl.md) **is supported but does not work if JSON fields are optional. The missing keys fail the COPY instead of being interpreted as NULL.** +* [jsonl](../file-formats/jsonl.md) :::tip `duckdb` has [timestamp types](https://duckdb.org/docs/sql/data_types/timestamp.html) with resolutions from milliseconds to nanoseconds. However diff --git a/docs/website/docs/dlt-ecosystem/destinations/filesystem.md b/docs/website/docs/dlt-ecosystem/destinations/filesystem.md index bbe21b7ea7..018b838363 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/filesystem.md +++ b/docs/website/docs/dlt-ecosystem/destinations/filesystem.md @@ -514,6 +514,12 @@ You need the `deltalake` package to use this format: pip install "dlt[deltalake]" ``` +You also need `pyarrow>=17.0.0`: + +```sh +pip install 'pyarrow>=17.0.0' +``` + Set the `table_format` argument to `delta` when defining your resource: ```py @@ -524,6 +530,23 @@ def my_delta_resource(): > `dlt` always uses `parquet` as `loader_file_format` when using the `delta` table format. Any setting of `loader_file_format` is disregarded. +#### Delta table partitioning +A Delta table can be partitioned ([Hive-style partitioning](https://delta.io/blog/pros-cons-hive-style-partionining/)) by specifying one or more `partition` column hints. This example partitions the Delta table by the `foo` column: + +```py +@dlt.resource( + table_format="delta", + columns={"foo": {"partition": True}} +) +def my_delta_resource(): + ... +``` + +:::caution +It is **not** possible to change partition columns after the Delta table has been created. Trying to do so causes an error stating that the partition columns don't match. +::: + + #### Storage options You can pass storage options by configuring `destination.filesystem.deltalake_storage_options`: @@ -536,7 +559,26 @@ deltalake_storage_options = '{"AWS_S3_LOCKING_PROVIDER": "dynamodb", DELTA_DYNAM You don't need to specify credentials here. `dlt` merges the required credentials with the options you provided, before passing it as `storage_options`. ->❗When using `s3`, you need to specify storage options to [configure](https://delta-io.github.io/delta-rs/usage/writing/writing-to-s3-with-locking-provider/) locking behavior. +>❗When using `s3`, you need to specify storage options to [configure](https://delta-io.github.io/delta-rs/usage/writing/writing-to-s3-with-locking-provider/) locking behavior. + +#### `get_delta_tables` helper +You can use the `get_delta_tables` helper function to get `deltalake` [DeltaTable](https://delta-io.github.io/delta-rs/api/delta_table/) objects for your Delta tables: + +```py +from dlt.common.libs.deltalake import get_delta_tables + +... + +# get dictionary of DeltaTable objects +delta_tables = get_delta_tables(pipeline) + +# execute operations on DeltaTable objects +delta_tables["my_delta_table"].optimize.compact() +delta_tables["another_delta_table"].optimize.z_order(["col_a", "col_b"]) +# delta_tables["my_delta_table"].vacuum() +# etc. + +``` ## Syncing of `dlt` state This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination). To this end, special folders and files that will be created at your destination which hold information about your pipeline state, schemas and completed loads. These folders DO NOT respect your diff --git a/docs/website/docs/dlt-ecosystem/destinations/motherduck.md b/docs/website/docs/dlt-ecosystem/destinations/motherduck.md index 9d8c8d260b..f75314bb44 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/motherduck.md +++ b/docs/website/docs/dlt-ecosystem/destinations/motherduck.md @@ -1,11 +1,10 @@ --- -title: 🧪 MotherDuck +title: MotherDuck description: MotherDuck `dlt` destination keywords: [MotherDuck, duckdb, destination, data warehouse] --- # MotherDuck -> 🧪 MotherDuck is still invitation-only and is being intensively tested. Please see the limitations/problems at the end. ## Install dlt with MotherDuck **To install the dlt library with MotherDuck dependencies:** @@ -50,11 +49,19 @@ Alternatively, you can use the connection string syntax. motherduck.credentials="md:///dlt_data_3?token=" ``` +:::tip +Motherduck now supports configurable **access tokens**. Please refer to the [documentation](https://motherduck.com/docs/key-tasks/authenticating-to-motherduck/#authentication-using-an-access-token) +::: + **4. Run the pipeline** ```sh python3 chess_pipeline.py ``` +### Motherduck connection identifier +We enable Motherduck to identify that the connection is created by `dlt`. Motherduck will use this identifier to better understand the usage patterns +associated with `dlt` integration. The connection identifier is `dltHub_dlt/DLT_VERSION(OS_NAME)`. + ## Write disposition All write dispositions are supported. @@ -64,22 +71,19 @@ By default, Parquet files and the `COPY` command are used to move files to the r The **INSERT** format is also supported and will execute large INSERT queries directly into the remote database. This method is significantly slower and may exceed the maximum query size, so it is not advised. ## dbt support -This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-duckdb](https://github.com/jwills/dbt-duckdb), which is a community-supported package. `dbt` version >= 1.5 is required (which is the current `dlt` default.) +This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-duckdb](https://github.com/jwills/dbt-duckdb), which is a community-supported package. `dbt` version >= 1.7 is required + +## Multi-statement transaction support +Motherduck supports multi-statement transactions. This change happened with `duckdb 0.10.2`. ## Syncing of `dlt` state This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination). -## Automated tests -Each destination must pass a few hundred automatic tests. MotherDuck is passing these tests (except for the transactions, of course). However, we have encountered issues with ATTACH timeouts when connecting, which makes running such a number of tests unstable. Tests on CI are disabled. +## Troubleshooting -## Troubleshooting / limitations - -### I see a lot of errors in the log like DEADLINE_EXCEEDED or Connection timed out -MotherDuck is very sensitive to the quality of the internet connection and the **number of workers used to load data**. Decrease the number of workers and ensure your internet connection is stable. We have not found any way to increase these timeouts yet. - -### MotherDuck does not support transactions. -Do not use `begin`, `commit`, and `rollback` on `dlt` **sql_client** or on the duckdb dbapi connection. It has no effect on DML statements (they are autocommit). It confuses the query engine for DDL (tables not found, etc.). -If your connection is of poor quality and you get a timeout when executing a DML query, it may happen that your transaction got executed. +### My database is attached in read only mode +ie. `Error: Invalid Input Error: Cannot execute statement of type "CREATE" on database "dlt_data" which is attached in read-only mode!` +We encountered this problem for databases created with `duckdb 0.9.x` and then migrated to `0.10.x`. After switch to `1.0.x` on Motherduck, all our databases had permission "read-only" visible in UI. We could not figure out how to change it so we dropped and recreated our databases. ### I see some exception with home_dir missing when opening `md:` connection. Some internal component (HTTPS) requires the **HOME** env variable to be present. Export such a variable to the command line. Here is what we do in our tests: @@ -88,17 +92,5 @@ os.environ["HOME"] = "/tmp" ``` before opening the connection. -### I see some watchdog timeouts. -We also see them. -```text -'ATTACH_DATABASE': keepalive watchdog timeout -``` -Our observation is that if you write a lot of data into the database, then close the connection and then open it again to write, there's a chance of such a timeout. A possible **WAL** file is being written to the remote duckdb database. - -### Invalid Input Error: Initialization function "motherduck_init" from file -Use `duckdb 0.8.1` or above. -### Motherduck connection identifier -We enable Motherduck to identify that the connection is created by `dlt`. Motherduck will use this identifier to better understand the usage patterns -associated with `dlt` integration. The connection identifier is `dltHub_dlt/DLT_VERSION(OS_NAME)`. diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md index 9475dad578..e1cd9ce88e 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md @@ -371,7 +371,7 @@ You can configure the pagination for the `posts` resource like this: { "path": "posts", "paginator": { - "type": "json_response", + "type": "json_link", "next_url_path": "pagination.next", } } @@ -380,7 +380,7 @@ You can configure the pagination for the `posts` resource like this: Alternatively, you can use the paginator instance directly: ```py -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator +from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator # ... @@ -402,8 +402,8 @@ These are the available paginators: | ------------ | -------------- | ----------- | | `json_link` | [JSONLinkPaginator](../../general-usage/http/rest-client.md#jsonresponsepaginator) | The link to the next page is in the body (JSON) of the response.
*Parameters:*
  • `next_url_path` (str) - the JSONPath to the next page URL
| | `header_link` | [HeaderLinkPaginator](../../general-usage/http/rest-client.md#headerlinkpaginator) | The links to the next page are in the response headers.
*Parameters:*
  • `link_header` (str) - the name of the header containing the links. Default is "next".
| -| `offset` | [OffsetPaginator](../../general-usage/http/rest-client.md#offsetpaginator) | The pagination is based on an offset parameter. With total items count either in the response body or explicitly provided.
*Parameters:*
  • `limit` (int) - the maximum number of items to retrieve in each request
  • `offset` (int) - the initial offset for the first request. Defaults to `0`
  • `offset_param` (str) - the name of the query parameter used to specify the offset. Defaults to "offset"
  • `limit_param` (str) - the name of the query parameter used to specify the limit. Defaults to "limit"
  • `total_path` (str) - a JSONPath expression for the total number of items. If not provided, pagination is controlled by `maximum_offset`
  • `maximum_offset` (int) - optional maximum offset value. Limits pagination even without total count
| -| `page_number` | [PageNumberPaginator](../../general-usage/http/rest-client.md#pagenumberpaginator) | The pagination is based on a page number parameter. With total pages count either in the response body or explicitly provided.
*Parameters:*
  • `base_page` (int) - the starting page number. Defaults to `0`
  • `page_param` (str) - the query parameter name for the page number. Defaults to "page"
  • `total_path` (str) - a JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page`
  • `maximum_page` (int) - optional maximum page number. Stops pagination once this page is reached
| +| `offset` | [OffsetPaginator](../../general-usage/http/rest-client.md#offsetpaginator) | The pagination is based on an offset parameter. With total items count either in the response body or explicitly provided.
*Parameters:*
  • `limit` (int) - the maximum number of items to retrieve in each request
  • `offset` (int) - the initial offset for the first request. Defaults to `0`
  • `offset_param` (str) - the name of the query parameter used to specify the offset. Defaults to "offset"
  • `limit_param` (str) - the name of the query parameter used to specify the limit. Defaults to "limit"
  • `total_path` (str) - a JSONPath expression for the total number of items. If not provided, pagination is controlled by `maximum_offset` and `stop_after_empty_page`
  • `maximum_offset` (int) - optional maximum offset value. Limits pagination even without total count
  • `stop_after_empty_page` (bool) - Whether pagination should stop when a page contains no result items. Defaults to `True`
| +| `page_number` | [PageNumberPaginator](../../general-usage/http/rest-client.md#pagenumberpaginator) | The pagination is based on a page number parameter. With total pages count either in the response body or explicitly provided.
*Parameters:*
  • `base_page` (int) - the starting page number. Defaults to `0`
  • `page_param` (str) - the query parameter name for the page number. Defaults to "page"
  • `total_path` (str) - a JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page` and `stop_after_empty_page`
  • `maximum_page` (int) - optional maximum page number. Stops pagination once this page is reached
  • `stop_after_empty_page` (bool) - Whether pagination should stop when a page contains no result items. Defaults to `True`
| | `cursor` | [JSONResponseCursorPaginator](../../general-usage/http/rest-client.md#jsonresponsecursorpaginator) | The pagination is based on a cursor parameter. The value of the cursor is in the response body (JSON).
*Parameters:*
  • `cursor_path` (str) - the JSONPath to the cursor value. Defaults to "cursors.next"
  • `cursor_param` (str) - the query parameter name for the cursor. Defaults to "after"
| | `single_page` | SinglePagePaginator | The response will be interpreted as a single-page response, ignoring possible pagination metadata. | | `auto` | `None` | Explicitly specify that the source should automatically detect the pagination method. | @@ -553,6 +553,19 @@ Available authentication types: For more complex authentication methods, you can implement a [custom authentication class](../../general-usage/http/rest-client.md#implementing-custom-authentication) and use it in the configuration. +You can use the dictionary configuration syntax also for custom authentication classes after registering them as follows: + +```py +rest_api.config_setup.register_auth("custom_auth", CustomAuth) + +{ + # ... + "auth": { + "type": "custom_auth", + "api_key": dlt.secrets["sources.my_source.my_api_key"], + } +} +``` ### Define resource relationships @@ -984,7 +997,7 @@ Some API may return 404 errors for resources that do not exist or have no data. If experiencing 401 (Unauthorized) errors, this could indicate: -- Incorrect authorization credentials. Verify credentials in the `secrets.toml`. Refer to [Secret and configs](../../general-usage/credentials/configuration#understanding-the-exceptions) for more information. +- Incorrect authorization credentials. Verify credentials in the `secrets.toml`. Refer to [Secret and configs](../../general-usage/credentials/setup#understanding-the-exceptions) for more information. - An incorrect authentication type. Consult the API documentation for the proper method. See the [authentication](#authentication) section for details. For some APIs, a [custom authentication method](../../general-usage/http/rest-client.md#custom-authentication) may be required. ### General guidelines diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md b/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md index f00e185480..85216f3206 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md @@ -96,7 +96,7 @@ For more information, read the guide on ```toml # put your secret values and credentials here. do not share this file and do not push it to github [sources.salesforce] - username = "please set me up!" # Salesforce user name + user_name = "please set me up!" # Salesforce user name password = "please set me up!" # Salesforce password security_token = "please set me up!" # Salesforce security token ``` diff --git a/docs/website/docs/general-usage/credentials/advanced.md b/docs/website/docs/general-usage/credentials/advanced.md new file mode 100644 index 0000000000..793f5c2a55 --- /dev/null +++ b/docs/website/docs/general-usage/credentials/advanced.md @@ -0,0 +1,203 @@ +--- +title: Advanced secrets and configs +description: Learn advanced hacks and tricks about credentials. +keywords: [credentials, secrets.toml, secrets, config, configuration, environment variables, provider] +--- + +`dlt` provides a lot of flexibility for managing credentials and configuration. In this section, you will learn how to correctly manage credentials in your custom sources and destinations, how the `dlt` injection mechanism works, and how to get access to configurations managed by `dlt`. + +## Injection mechanism + +`dlt` has a special treatment for functions decorated with `@dlt.source`, `@dlt.resource`, and `@dlt.destination`. When such a function is called, `dlt` takes the argument names in the signature and supplies (`injects`) the required values by looking for them in [various config providers](setup). + +### Injection rules + +1. The arguments that are passed explicitly are **never injected**. This makes the injection mechanism optional. For example, for the pipedrive source: + ```py + @dlt.source(name="pipedrive") + def pipedrive_source( + pipedrive_api_key: str = dlt.secrets.value, + since_timestamp: Optional[Union[pendulum.DateTime, str]] = "1970-01-01 00:00:00", + ) -> Iterator[DltResource]: + ... + + my_key = os.environ["MY_PIPEDRIVE_KEY"] + my_source = pipedrive_source(pipedrive_api_key=my_key) + ``` + `dlt` allows the user to specify the argument `pipedrive_api_key` explicitly if, for some reason, they do not want to use [out-of-the-box options](setup) for credentials management. + +1. Required arguments (without default values) **are never injected** and must be specified when calling. For example, for the source: + + ```py + @dlt.source + def slack_data(channels_list: List[str], api_key: str = dlt.secrets.value): + ... + ``` + The argument `channels_list` would not be injected and will output an error if it is not specified explicitly. + +1. Arguments with default values are injected if present in config providers. Otherwise, defaults from the function signature are used. For example, for the source: + + ```py + @dlt.source + def slack_source( + page_size: int = MAX_PAGE_SIZE, + access_token: str = dlt.secrets.value, + start_date: Optional[TAnyDateTime] = DEFAULT_START_DATE + ): + ... + ``` + `dlt` firstly searches for all three arguments: `page_size`, `access_token`, and `start_date` in config providers in a [specific order](setup). If it cannot find them, it will use the default values. + +1. Arguments with the special default value `dlt.secrets.value` and `dlt.config.value` **must be injected** + (or explicitly passed). If they are not found by the config providers, the code raises an + exception. The code in the functions always receives those arguments. + + Additionally, `dlt.secrets.value` tells `dlt` that the supplied value is a secret, and it will be injected only from secure config providers. + +### Add typing to your sources and resources + +We highly recommend adding types to your function signatures. +The effort is very low, and it gives `dlt` much more +information on what source/resource expects. + +Doing so provides several benefits: + +1. You'll never receive the invalid data types in your code. +1. `dlt` will automatically parse and coerce types for you, so you don't need to parse it yourself. +1. `dlt` can generate sample config and secret files for your source automatically. +1. You can request [built-in and custom credentials](complex_types) (i.e., connection strings, AWS / GCP / Azure credentials). +1. You can specify a set of possible types via `Union`, i.e., OAuth or API Key authorization. + +Let's consider the example: + +```py +@dlt.source +def google_sheets( + spreadsheet_id: str = dlt.config.value, + tab_names: List[str] = dlt.config.value, + credentials: GcpServiceAccountCredentials = dlt.secrets.value, + only_strings: bool = False +): + ... +``` + +Now, + +1. You are sure that you get a list of strings as `tab_names`. +1. You will get actual Google credentials (see [GCP Credential Configuration](complex_types#gcp-credentials)), and users can + pass them in many different forms: + + * `service.json` as a string or dictionary (in code and via config providers). + * connection string (used in SQL Alchemy) (in code and via config providers). + * if nothing is passed, the default credentials are used (i.e., those present on Cloud Function runner) + +## Toml files structure + +`dlt` arranges the sections of [toml files](setup/#secretstoml-and-configtoml) into a **default layout** that is expected by the [injection mechanism](#injection-mechanism). +This layout makes it easy to configure simple cases but also provides a room for more explicit sections and complex cases, i.e., having several sources with different credentials +or even hosting several pipelines in the same project sharing the same config and credentials. + +```text +pipeline_name + | + |-sources + |- + |- + |- {all source and resource options and secrets} + |- + |- {all source and resource options and secrets} + |- + |... + + |-extract + |- extract options for resources i.e., parallelism settings, maybe retries + |-destination + |- + |- {destination options} + |-credentials + |-{credentials options} + |-schema + |- + |-schema settings: not implemented but I'll let people set nesting level, name convention, normalizer, etc. here + |-load + |-normalize +``` + + +## Read configs and secrets manually + +`dlt` handles credentials and configuration automatically, but also offers flexibility for manual processing. +`dlt.secrets` and `dlt.config` provide dictionary-like access to configuration values and secrets, enabling any custom preprocessing if needed. +Additionally, you can store custom settings within the same configuration files. + +```py +# use `dlt.secrets` and `dlt.config` to explicitly take +# those values from providers from the explicit keys +data_source = google_sheets( + dlt.config["sheet_id"], + dlt.config["my_section.tabs"], + dlt.secrets["my_section.gcp_credentials"] +) + +data_source.run(destination="bigquery") +``` + +`dlt.config` and `dlt.secrets` behave like dictionaries from which you can request a value with any key name. `dlt` will look in all [config providers](setup) - env variables, TOML files, etc. to create these dictionaries. You can also use `dlt.config.get()` or `dlt.secrets.get()` to +request a value cast to a desired type. For example: + +```py +credentials = dlt.secrets.get("my_section.gcp_credentials", GcpServiceAccountCredentials) +``` +Creates a `GcpServiceAccountCredentials` instance out of values (typically a dictionary) under the `my_section.gcp_credentials` key. + +## Write configs and secrets in code + +`dlt.config` and `dlt.secrets` objects can also be used as setters. For example: +```py +dlt.config["sheet_id"] = "23029402349032049" +dlt.secrets["destination.postgres.credentials"] = BaseHook.get_connection('postgres_dsn').extra +``` + +Will mock the `toml` provider to desired values. + +## Example + +In the example below, the `google_sheets` source function is used to read selected tabs from Google Sheets. +It takes several arguments that specify the spreadsheet, the tab names, and the Google credentials to be used when extracting data. + +```py +@dlt.source +def google_sheets( + spreadsheet_id=dlt.config.value, + tab_names=dlt.config.value, + credentials=dlt.secrets.value, + only_strings=False +): + # Allow both a dictionary and a string passed as credentials + if isinstance(credentials, str): + credentials = json.loads(credentials) + # Allow both a list and a comma-delimited string to be passed as tabs + if isinstance(tab_names, str): + tab_names = tab_names.split(",") + sheets = build('sheets', 'v4', credentials=ServiceAccountCredentials.from_service_account_info(credentials)) + tabs = [] + for tab_name in tab_names: + data = get_sheet(sheets, spreadsheet_id, tab_name) + tabs.append(dlt.resource(data, name=tab_name)) + return tabs +``` +The `dlt.source` decorator makes all arguments in the `google_sheets` function signature configurable. +`dlt.secrets.value` and `dlt.config.value` are special argument defaults that tell `dlt` that this +argument is required and must be passed explicitly or must exist in the configuration. Additionally, +`dlt.secrets.value` tells `dlt` that an argument is a secret. + +In the example above: +- `spreadsheet_id` is a **required config** argument. +- `tab_names` is a **required config** argument. +- `credentials` is a **required secret** argument (Google Sheets credentials as a dictionary ({"private_key": ...})). +- `only_strings` is an **optional config** argument with a default value. It may be specified when calling the `google_sheets` function or included in the configuration settings. + +:::tip +`dlt.resource` behaves in the same way, so if you have a [standalone resource](../resource.md#declare-a-standalone-resource) (one that is not an inner function +of a **source**) +::: \ No newline at end of file diff --git a/docs/website/docs/general-usage/credentials/config_specs.md b/docs/website/docs/general-usage/credentials/complex_types.md similarity index 65% rename from docs/website/docs/general-usage/credentials/config_specs.md rename to docs/website/docs/general-usage/credentials/complex_types.md index 944dadb238..24915c1b2e 100644 --- a/docs/website/docs/general-usage/credentials/config_specs.md +++ b/docs/website/docs/general-usage/credentials/complex_types.md @@ -1,23 +1,23 @@ --- -title: Configuration specs -description: How to specify complex custom configurations +title: Complex credential types +description: Instructions for credentials like DB connection string. keywords: [credentials, secrets.toml, secrets, config, configuration, environment variables, specs] --- -Configuration Specs in `dlt` are Python dataclasses that define how complex configuration values, -particularly credentials, should be handled. -They specify the types, defaults, and parsing methods for these values. +## Overview -## Working with credentials (and other complex configuration values) +Often, credentials do not consist of just one `api_key`, but instead can be quite a complex structure. In this section, you'll learn how `dlt` supports different credential types and authentication options. -For example, a spec like `GcpServiceAccountCredentials` manages Google Cloud Platform -service account credentials, while `ConnectionStringCredentials` handles database connection strings. +:::tip +Learn about the authentication methods supported by the `dlt` RestAPI Client in detail in the [RESTClient section](../http/rest-client.md#authentication). +::: -### Example +`dlt` supports different credential types by providing various Python data classes called Configuration Specs. These classes define how complex configuration values, particularly credentials, should be handled. They specify the types, defaults, and parsing methods for these values. -As an example, let's use `ConnectionStringCredentials` which represents a database connection -string. +## Example with ConnectionStringCredentials + +`ConnectionStringCredentials` handles database connection strings: ```py from dlt.sources.credentials import ConnectionStringCredentials @@ -27,13 +27,11 @@ def query(sql: str, dsn: ConnectionStringCredentials = dlt.secrets.value): ... ``` -The source above executes the `sql` against database defined in `dsn`. `ConnectionStringCredentials` -makes sure you get the correct values with correct types and understands the relevant native form of -the credentials. +The source above executes the `sql` against the database defined in `dsn`. `ConnectionStringCredentials` ensures you get the correct values with the correct types and understands the relevant native form of the credentials. Below are examples of how you can set credentials in `secrets.toml` and `config.toml` files. -Example 1. Use the **dictionary** form. +### Dictionary form ```toml [dsn] @@ -43,14 +41,15 @@ username="loader" host="localhost" ``` -Example 2. Use the **native** form. +### Native form ```toml dsn="postgres://loader:loader@localhost:5432/dlt_data" ``` -Example 3. Use the **mixed** form: the password is missing in explicit dsn and will be taken from the -`secrets.toml`. +### Mixed form + +If all credentials, but the password provided explicitly in the code, `dlt` will look for the password in `secrets.toml`. ```toml dsn.password="loader" @@ -64,9 +63,9 @@ query("SELECT * FROM customers", "postgres://loader@localhost:5432/dlt_data") query("SELECT * FROM customers", {"database": "dlt_data", "username": "loader"}) ``` -## Built in credentials +## Built-in credentials -We have some ready-made credentials you can reuse: +`dlt` offers some ready-made credentials you can reuse: ```py from dlt.sources.credentials import ConnectionStringCredentials @@ -78,11 +77,7 @@ from dlt.sources.credentials import AzureCredentials ### ConnectionStringCredentials -The `ConnectionStringCredentials` class handles connection string -credentials for SQL database connections. -It includes attributes for the driver name, database name, username, password, host, port, -and additional query parameters. -This class provides methods for parsing and generating connection strings. +The `ConnectionStringCredentials` class handles connection string credentials for SQL database connections. It includes attributes for the driver name, database name, username, password, host, port, and additional query parameters. This class provides methods for parsing and generating connection strings. #### Usage ```py @@ -96,7 +91,7 @@ credentials.password = "my_password" # type: ignore credentials.host = "localhost" credentials.port = 5432 -# Convert credentials to connection string +# Convert credentials to a connection string connection_string = credentials.to_native_representation() # Parse a connection string and update credentials @@ -110,9 +105,7 @@ Above, you can find an example of how to use this spec with sources and TOML fil ### OAuth2Credentials -The `OAuth2Credentials` class handles OAuth 2.0 credentials, including client ID, -client secret, refresh token, and access token. -It also allows for the addition of scopes and provides methods for client authentication. +The `OAuth2Credentials` class handles OAuth 2.0 credentials, including client ID, client secret, refresh token, and access token. It also allows for the addition of scopes and provides methods for client authentication. Usage: ```py @@ -130,25 +123,26 @@ credentials.auth() credentials.add_scopes(["scope3", "scope4"]) ``` -`OAuth2Credentials` is a base class to implement actual OAuth, for example, -it is a base class for [GcpOAuthCredentials](#gcpoauthcredentials). +`OAuth2Credentials` is a base class to implement actual OAuth; for example, it is a base class for [GcpOAuthCredentials](#gcpoauthcredentials). ### GCP Credentials -- [GcpServiceAccountCredentials](#gcpserviceaccountcredentials). -- [GcpOAuthCredentials](#gcpoauthcredentials). +#### Examples +* [Google Analytics verified source](https://github.com/dlt-hub/verified-sources/blob/master/sources/google_analytics/__init__.py): the example of how to use GCP Credentials. +* [Google Analytics example](https://github.com/dlt-hub/verified-sources/blob/master/sources/google_analytics/setup_script_gcp_oauth.py): how you can get the refresh token using `dlt.secrets.value`. + +#### Types -[Google Analytics verified source](https://github.com/dlt-hub/verified-sources/blob/master/sources/google_analytics/__init__.py): -the example how to use GCP Credentials. +* [GcpServiceAccountCredentials](#gcpserviceaccountcredentials). +* [GcpOAuthCredentials](#gcpoauthcredentials). #### GcpServiceAccountCredentials -The `GcpServiceAccountCredentials` class manages GCP Service Account credentials. -This class provides methods to retrieve native credentials for Google clients. +The `GcpServiceAccountCredentials` class manages GCP Service Account credentials. This class provides methods to retrieve native credentials for Google clients. ##### Usage -- You may just pass the `service.json` as string or dictionary (in code and via config providers). +- You may just pass the `service.json` as a string or dictionary (in code and via config providers). - Or default credentials will be used. ```py @@ -179,7 +173,7 @@ def google_analytics( credentials_str = str(credentials) ... ``` -while `secrets.toml` looks as following: +while `secrets.toml` looks as follows: ```toml [sources.google_analytics.credentials] client_id = "client_id" # please set me up! @@ -195,11 +189,7 @@ property_id = "213025502" #### GcpOAuthCredentials -The `GcpOAuthCredentials` class is responsible for handling OAuth2 credentials for -desktop applications in Google Cloud Platform (GCP). -It can parse native values either as `GoogleOAuth2Credentials` or as -serialized OAuth client secrets JSON. -This class provides methods for authentication and obtaining access tokens. +The `GcpOAuthCredentials` class is responsible for handling OAuth2 credentials for desktop applications in Google Cloud Platform (GCP). It can parse native values either as `GoogleOAuth2Credentials` or as serialized OAuth client secrets JSON. This class provides methods for authentication and obtaining access tokens. ##### Usage ```py @@ -233,7 +223,7 @@ def google_analytics( credentials_str = str(credentials) ... ``` -while `secrets.toml` looks as following: +while `secrets.toml` looks as follows: ```toml [sources.google_analytics.credentials] client_id = "client_id" # please set me up! @@ -247,32 +237,22 @@ and `config.toml`: property_id = "213025502" ``` -In order for `auth()` method to succeed: +In order for the `auth()` method to succeed: + +- You must provide valid `client_id`, `client_secret`, `refresh_token`, and `project_id` to get a current **access token** and authenticate with OAuth. Keep in mind that the `refresh_token` must contain all the scopes that is required for your access. +- If the `refresh_token` is not provided, and you run the pipeline from a console or a notebook, `dlt` will use InstalledAppFlow to run the desktop authentication flow. -- You must provide valid `client_id` and `client_secret`, - `refresh_token` and `project_id` in order to get a current - **access token** and authenticate with OAuth. - Mind that the `refresh_token` must contain all the scopes that you require for your access. -- If `refresh_token` is not provided, and you run the pipeline from a console or a notebook, - `dlt` will use InstalledAppFlow to run the desktop authentication flow. -[Google Analytics example](https://github.com/dlt-hub/verified-sources/blob/master/sources/google_analytics/setup_script_gcp_oauth.py): how you can get the refresh token using `dlt.secrets.value`. #### Defaults -If configuration values are missing, `dlt` will use the default Google credentials (from `default()`) if available. -Read more about [Google defaults.](https://googleapis.dev/python/google-auth/latest/user-guide.html#application-default-credentials) +If configuration values are missing, `dlt` will use the default Google credentials (from `default()`) if available. Read more about [Google defaults.](https://googleapis.dev/python/google-auth/latest/user-guide.html#application-default-credentials) -- `dlt` will try to fetch the `project_id` from default credentials. - If the project id is missing, it will look for `project_id` in the secrets. - So it is normal practice to pass partial credentials (just `project_id`) and take the rest from defaults. +- `dlt` will try to fetch the `project_id` from default credentials. If the project id is missing, it will look for `project_id` in the secrets. So it is normal practice to pass partial credentials (just `project_id`) and take the rest from defaults. ### AwsCredentials -The `AwsCredentials` class is responsible for handling AWS credentials, -including access keys, session tokens, profile names, region names, and endpoint URLs. -It inherits the ability to manage default credentials and extends it with methods -for handling partial credentials and converting credentials to a botocore session. +The `AwsCredentials` class is responsible for handling AWS credentials, including access keys, session tokens, profile names, region names, and endpoint URLs. It inherits the ability to manage default credentials and extends it with methods for handling partial credentials and converting credentials to a botocore session. #### Usage ```py @@ -309,7 +289,7 @@ def aws_readers( print(aws_credentials.access_key) ... ``` -while `secrets.toml` looks as following: +while `secrets.toml` looks as follows: ```toml [sources.aws_readers.credentials] aws_access_key_id = "key_id" @@ -325,17 +305,13 @@ bucket_url = "bucket_url" #### Defaults If configuration is not provided, `dlt` uses the default AWS credentials (from `.aws/credentials`) as present on the machine: + - It works by creating an instance of botocore Session. -- If `profile_name` is specified, the credentials for that profile are used. - If not - the default profile is used. +- If `profile_name` is specified, the credentials for that profile are used. If not, the default profile is used. ### AzureCredentials -The `AzureCredentials` class is responsible for handling Azure Blob Storage credentials, -including account name, account key, Shared Access Signature (SAS) token, and SAS token permissions. -It inherits the ability to manage default credentials and extends it with methods for -handling partial credentials and converting credentials to a format suitable -for interacting with Azure Blob Storage using the adlfs library. +The `AzureCredentials` class is responsible for handling Azure Blob Storage credentials, including account name, account key, Shared Access Signature (SAS) token, and SAS token permissions. It inherits the ability to manage default credentials and extends it with methods for handling partial credentials and converting credentials to a format suitable for interacting with Azure Blob Storage using the adlfs library. #### Usage ```py @@ -363,7 +339,7 @@ def azure_readers( # to_native_credentials() is not yet implemented ... ``` -while `secrets.toml` looks as following: +while `secrets.toml` looks as follows: ```toml [sources.azure_readers.credentials] azure_storage_account_name = "account_name" @@ -374,60 +350,58 @@ and `config.toml`: [sources.azure_readers] bucket_url = "bucket_url" ``` + #### Defaults If configuration is not provided, `dlt` uses the default credentials using `DefaultAzureCredential`. ## Working with alternatives of credentials (Union types) -If your source/resource allows for many authentication methods, you can support those seamlessly for -your user. The user just passes the right credentials and `dlt` will inject the right type into your -decorated function. +If your source/resource allows for many authentication methods, you can support those seamlessly for your user. The user just passes the right credentials, and `dlt` will inject the right type into your decorated function. Example: ```py @dlt.source def zen_source(credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, str] = dlt.secrets.value, some_option: bool = False): - # depending on what the user provides in config, ZenApiKeyCredentials or ZenEmailCredentials will be injected in `credentials` argument - # both classes implement `auth` so you can always call it + # Depending on what the user provides in config, ZenApiKeyCredentials or ZenEmailCredentials will be injected in the `credentials` argument. Both classes implement `auth` so you can always call it. credentials.auth() return dlt.resource([credentials], name="credentials") -# pass native value +# Pass native value os.environ["CREDENTIALS"] = "email:mx:pwd" assert list(zen_source())[0].email == "mx" -# pass explicit native value +# Pass explicit native value assert list(zen_source("secret:🔑:secret"))[0].api_secret == "secret" - # pass explicit dict assert list(zen_source(credentials={"email": "emx", "password": "pass"}))[0].email == "emx" - ``` -> This applies not only to credentials but to all specs (see next chapter). +:::info +This applies not only to credentials but to [all specs](#writing-custom-specs). +::: -Read the [whole test](https://github.com/dlt-hub/dlt/blob/devel/tests/common/configuration/test_spec_union.py), it shows how to create unions +:::tip +Check out the [complete example](https://github.com/dlt-hub/dlt/blob/devel/tests/common/configuration/test_spec_union.py), to learn how to create unions of credentials that derive from the common class, so you can handle it seamlessly in your code. +::: ## Writing custom specs -**specs** let you take full control over the function arguments: +**Custom specifications** let you take full control over the function arguments. You can -- Which values should be injected, the types, default values. -- You can specify optional and final fields. +- Control which values should be injected, the types, default values. +- Specify optional and final fields. - Form hierarchical configurations (specs in specs). - Provide own handlers for `on_partial` (called before failing on missing config key) or `on_resolved`. - Provide own native value parsers. - Provide own default credentials logic. -- Adds all Python dataclass goodies to it. -- Adds all Python `dict` goodies to it (`specs` instances can be created from dicts and serialized +- Utilise Python dataclass functionality. +- Utilise Python `dict` functionality (`specs` instances can be created from dicts and serialized from dicts). -This is used a lot in the `dlt` core and may become useful for complicated sources. - -In fact, for each decorated function a spec is synthesized. In case of `google_sheets` following +In fact, `dlt` synthesizes a unique spec for each decorated function. For example, in the case of `google_sheets`, the following class is created: ```py @@ -435,7 +409,7 @@ from dlt.sources.config import configspec, with_config @configspec class GoogleSheetsConfiguration(BaseConfiguration): - tab_names: List[str] = None # manadatory + tab_names: List[str] = None # mandatory credentials: GcpServiceAccountCredentials = None # mandatory secret only_strings: Optional[bool] = False ``` diff --git a/docs/website/docs/general-usage/credentials/config_providers.md b/docs/website/docs/general-usage/credentials/config_providers.md deleted file mode 100644 index 3dbe88893b..0000000000 --- a/docs/website/docs/general-usage/credentials/config_providers.md +++ /dev/null @@ -1,194 +0,0 @@ ---- -title: Configuration providers -description: Where dlt looks for config/secrets and in which order. -keywords: [credentials, secrets.toml, secrets, config, configuration, environment - variables, provider] ---- - -Configuration Providers in the context of the `dlt` library -refer to different sources from which configuration values -and secrets can be retrieved for a data pipeline. -These providers form a hierarchy, with each having its own -priority in determining the values for function arguments. - -## The provider hierarchy - -If function signature has arguments that may be injected, `dlt` looks for the argument values in -providers. - -### Providers - -1. **Environment Variables**: At the top of the hierarchy are environment variables. - If a value for a specific argument is found in an environment variable, - dlt will use it and will not proceed to search in lower-priority providers. - -2. **Vaults (Airflow/Google/AWS/Azure)**: These are specialized providers that come - after environment variables. They can provide configuration values and secrets. - However, they typically focus on handling sensitive information. - -3. **`secrets.toml` and `config.toml` Files**: These files are used for storing both - configuration values and secrets. `secrets.toml` is dedicated to sensitive information, - while `config.toml` contains non-sensitive configuration data. - -4. Custom Providers added with `register_provider`: These are your own Provider implementation - you can use to connect to any backend. See [adding custom providers](#adding-custom-providers) for more information. - -5. **Default Argument Values**: These are the values specified in the function's signature. - They have the lowest priority in the provider hierarchy. - - -### Example - -```py -@dlt.source -def google_sheets( - spreadsheet_id=dlt.config.value, - tab_names=dlt.config.value, - credentials=dlt.secrets.value, - only_strings=False -): - ... -``` - -In case of `google_sheets()` it will look -for: `spreadsheet_id`, `tab_names`, `credentials` and `only_strings` - -Each provider has its own key naming convention, and dlt is able to translate between them. - -**The argument name is a key in the lookup**. - -At the top of the hierarchy are Environment Variables, then `secrets.toml` and -`config.toml` files. Providers like Airflow/Google/AWS/Azure Vaults will be inserted **after** the Environment -provider but **before** TOML providers. - -For example, if `spreadsheet_id` is found in environment variable `SPREADSHEET_ID`, `dlt` will not look in TOML files -and below. - -The values passed in the code **explicitly** are the **highest** in provider hierarchy. The **default values** -of the arguments have the **lowest** priority in the provider hierarchy. - -:::info -Explicit Args **>** ENV Variables **>** Vaults: Airflow etc. **>** `secrets.toml` **>** `config.toml` **>** Default Arg Values -::: - -Secrets are handled only by the providers supporting them. Some providers support only -secrets (to reduce the number of requests done by `dlt` when searching sections). - -1. `secrets.toml` and environment may hold both config and secret values. -1. `config.toml` may hold only config values, no secrets. -1. Various vaults providers hold only secrets, `dlt` skips them when looking for values that are not - secrets. - -:::info -Context-aware providers will activate in the right environments i.e. on Airflow or AWS/GCP VMachines. -::: - -### Adding Custom Providers - -You can use the `CustomLoaderDocProvider` classes to supply a custom dictionary obtained from any source to dlt for use -as a source of `config` and `secret` values. The code below demonstrates how to use a config stored in config.json. - -```py -import dlt - -from dlt.common.configuration.providers import CustomLoaderDocProvider - -# create a function that loads a dict -def load_config(): - with open("config.json", "rb") as f: - config_dict = json.load(f) - -# create the custom provider -provider = CustomLoaderDocProvider("my_json_provider",load_config) - -# register provider -dlt.config.register_provider(provider) -``` - -:::tip -Check our nice [example](../../examples/custom_config_provider) for a `yaml` based config provider that supports switchable profiles. -::: - -## Provider key formats - -### TOML vs. Environment Variables - -Providers may use different formats for the keys. `dlt` will translate the standard format where -sections and key names are separated by "." into the provider-specific formats. - -1. For TOML, names are case-sensitive and sections are separated with ".". -1. For Environment Variables, all names are capitalized and sections are separated with double - underscore "__". - - Example: To override a token in "secrets.toml": - - ```toml - [sources.github] - access_token = "GITHUB_API_TOKEN" - ``` - Use the following environment variable: - ```sh - export SOURCES__GITHUB__ACCESS_TOKEN="your_token_here" - ``` - - -1. When `dlt` evaluates the request `dlt.secrets["my_section.gcp_credentials"]` it must find the `private_key` for Google credentials. It looks for it as follows: - 1. It first searches for environment variable `MY_SECTION__GCP_CREDENTIALS__PRIVATE_KEY` and if not found, - 1. in `secrets.toml` file under `my_section.gcp_credentials.private_key`. - - This way, `dlt` prioritizes security by using environment variables before looking into configuration files. - - - -:::info -While using Google secrets provider please make sure your pipeline name -contains no whitespace or any other punctuation characters except "-" and "_". - -Per Google the secret name can contain - - 1. Uppercase and lowercase letters, - 2. Numerals, - 3. Hyphens, - 4. Underscores. -::: - -### Environment provider - -Looks for the values in the environment variables. - -### TOML provider - -The TOML provider in dlt utilizes two TOML files: - -- `secrets.toml `- This file is intended for storing sensitive information, often referred to as "secrets". -- `config.toml `- This file is used for storing configuration values. - -By default, the `.gitignore` file in the project prevents `secrets.toml` from being added to -version control and pushed. However, `config.toml` can be freely added to version control. - -:::info -**TOML provider always loads those files from `.dlt` folder** which is looked **relative to the -current Working Directory**. -::: - -Example: If your working directory is `my_dlt_project` and your project has the following structure: - -```text -my_dlt_project: - | - pipelines/ - |---- .dlt/secrets.toml - |---- google_sheets.py -``` - -and you run `python pipelines/google_sheets.py` then `dlt` will look for `secrets.toml` in -`my_dlt_project/.dlt/secrets.toml` and ignore the existing -`my_dlt_project/pipelines/.dlt/secrets.toml`. - -If you change your working directory to `pipelines` and run `python google_sheets.py` it will look for -`my_dlt_project/pipelines/.dlt/secrets.toml` as (probably) expected. - -:::caution -It's worth mentioning that the TOML provider also has the capability to read files from `~/.dlt/` -(located in the user's home directory) in addition to the local project-specific `.dlt` folder. -::: \ No newline at end of file diff --git a/docs/website/docs/general-usage/credentials/configuration.md b/docs/website/docs/general-usage/credentials/configuration.md deleted file mode 100644 index 8ed6add2c2..0000000000 --- a/docs/website/docs/general-usage/credentials/configuration.md +++ /dev/null @@ -1,485 +0,0 @@ ---- -title: Secrets and configs -description: What are secrets and configs and how sources and destinations read them. -keywords: [credentials, secrets.toml, secrets, config, configuration, environment - variables] ---- - -Use secret and config values to pass access credentials and configure or fine-tune your pipelines without the need to modify your code. -When done right you'll be able to run the same pipeline script during development and in production. - -**Configs**: - - - Configs refer to non-sensitive configuration data. These are settings, parameters, or options that define the behavior of a data pipeline. - - They can include things like file paths, database hosts and timeouts, API urls, performance settings, or any other settings that affect the pipeline's behavior. - -**Secrets**: - - - Secrets are sensitive information that should be kept confidential, such as passwords, API keys, private keys, and other confidential data. - - It's crucial to never hard-code secrets directly into the code, as it can pose a security risk. - -## Configure dlt sources and resources - -In the example below, the `google_sheets` source function is used to read selected tabs from Google Sheets. -It takes several arguments that specify the spreadsheet, the tab names and the Google credentials to be used when extracting data. - -```py -@dlt.source -def google_sheets( - spreadsheet_id=dlt.config.value, - tab_names=dlt.config.value, - credentials=dlt.secrets.value, - only_strings=False -): - # Allow both dictionary and a string passed as credentials - if isinstance(credentials, str): - credentials = json.loads(credentials) - # Allow both list and comma delimited string to be passed as tabs - if isinstance(tab_names, str): - tab_names = tab_names.split(",") - sheets = build('sheets', 'v4', credentials=ServiceAccountCredentials.from_service_account_info(credentials)) - tabs = [] - for tab_name in tab_names: - data = get_sheet(sheets, spreadsheet_id, tab_name) - tabs.append(dlt.resource(data, name=tab_name)) - return tabs -``` -`dlt.source` decorator makes all arguments in `google_sheets` function signature configurable. -`dlt.secrets.value` and `dlt.config.value` are special argument defaults that tell `dlt` that this -argument is required and must be passed explicitly or must exist in the configuration. Additionally -`dlt.secrets.value` tells `dlt` that an argument is a secret. - -In the example above: -- `spreadsheet_id`: is a **required config** argument. -- `tab_names`: is a **required config** argument. -- `credentials`: is a **required secret** argument (Google Sheets credentials as a dictionary ({"private_key": ...})). -- `only_strings`: is an **optional config** argument with a default value. It may be specified when calling the `google_sheets` function or included in the configuration settings. - -:::tip -`dlt.resource` behaves in the same way so if you have a [standalone resource](../resource.md#declare-a-standalone-resource) (one that is not an inner function -of a **source**) -::: - -### Allow `dlt` to pass the config and secrets automatically -You are free to call the function above as usual and pass all the arguments in the code. You'll hardcode google credentials and [we do not recommend that](#do-not-pass-hardcoded-secrets). - -Instead let `dlt` to do the work and leave it to [injection mechanism](#injection-mechanism) that looks for function arguments in the config files or environment variables and adds them to your explicit arguments during a function call. Below are two most typical examples: - -1. Pass spreadsheet id and tab names in the code, inject credentials from the secrets: - ```py - data_source = google_sheets("23029402349032049", ["tab1", "tab2"]) - ``` - `credentials` value will be injected by the `@source` decorator (e.g. from `secrets.toml`). - `spreadsheet_id` and `tab_names` take values from the call arguments. - -2. Inject all the arguments from config / secrets - ```py - data_source = google_sheets() - ``` - `credentials` value will be injected by the `@source` decorator (e.g. from **secrets.toml**). - - `spreadsheet_id` and `tab_names` will be also injected by the `@source` decorator (e.g. from **config.toml**). - - -Where do the configs and secrets come from? By default, `dlt` looks in two **config providers**: - -- [TOML files](config_providers#toml-provider): - - Configs are kept in **.dlt/config.toml**. `dlt` will match argument names with - entries in the file and inject the values: - - ```toml - spreadsheet_id="1HhWHjqouQnnCIZAFa2rL6vT91YRN8aIhts22SUUR580" - tab_names=["tab1", "tab2"] - ``` - Secrets in **.dlt/secrets.toml**. `dlt` will look for `credentials`, - ```toml - [credentials] - client_email = "" - private_key = "" - project_id = "" - ``` - Note that **credentials** will be evaluated as dictionary containing **client_email**, **private_key** and **project_id** as keys. It is standard TOML behavior. -- [Environment Variables](config_providers#environment-provider): - ```toml - CREDENTIALS="" - SPREADSHEET_ID="1HhWHjqouQnnCIZAFa2rL6vT91YRN8aIhts22SUUR580" - TAB_NAMES=["tab1", "tab2"] - ``` - We pass the JSON contents of `service.json` file to `CREDENTIALS` and we specify tab names as comma-delimited values. Environment variables are always in **upper case**. - -:::tip -There are many ways you can organize your configs and secrets. The example above is the simplest default **layout** that `dlt` supports. In more complicated cases (i.e. a single configuration is shared by many pipelines with different sources and destinations) you may use more [explicit layouts](#secret-and-config-values-layout-and-name-lookup). -::: - -:::caution -**[TOML provider](config_providers#toml-provider) always loads `secrets.toml` and `config.toml` files from `.dlt` folder** which is looked relative to the -**current [Working Directory](https://en.wikipedia.org/wiki/Working_directory)**. TOML provider also has the capability to read files from `~/.dlt/` -(located in the user's [Home Directory](https://en.wikipedia.org/wiki/Home_directory)). -::: - -### Do not hardcode secrets -You should never do that. Sooner or later your private key will leak. - -```py -# WRONG!: -# provide all values directly - wrong but possible. -# secret values should never be present in the code! -data_source = google_sheets( - "23029402349032049", - ["tab1", "tab2"], - credentials={"private_key": ""} -) -``` - -### Pass secrets in code from external providers -You can get the secret values from your own providers. Below we take **credentials** for our `google_sheets` source from Airflow base hook: - -```py -from airflow.hooks.base_hook import BaseHook - -# get it from airflow connections or other credential store -credentials = BaseHook.get_connection('gcp_credentials').extra -data_source = google_sheets(credentials=credentials) -``` - -## Configure a destination -We provide detailed guides for [built-in destinations] and [explain how to configure them in code](../destination.md#configure-a-destination) (including credentials) - - -## Add typing to your sources and resources - -We highly recommend adding types to your function signatures. -The effort is very low, and it gives `dlt` much more -information on what source/resource expects. - -Doing so provides several benefits: - -1. You'll never receive invalid data types in your code. -1. `dlt` will automatically parse and coerce types for you. In our example, you do not need to parse list of tabs or credentials dictionary yourself. -1. We can generate nice sample config and secret files for your source. -1. You can request [built-in and custom credentials](config_specs.md) (i.e. connection strings, AWS / GCP / Azure credentials). -1. You can specify a set of possible types via `Union` i.e. OAuth or API Key authorization. - -```py -@dlt.source -def google_sheets( - spreadsheet_id: str = dlt.config.value, - tab_names: List[str] = dlt.config.value, - credentials: GcpServiceAccountCredentials = dlt.secrets.value, - only_strings: bool = False -): - ... -``` - -Now: - -1. You are sure that you get a list of strings as `tab_names`. -1. You will get actual Google credentials (see [GCP Credential Configuration](config_specs#gcp-credentials)), and your users can - pass them in many different forms. - -In case of `GcpServiceAccountCredentials`: - -- You may just pass the `service.json` as string or dictionary (in code and via config providers). -- You may pass a connection string (used in SQL Alchemy) (in code and via config providers). -- If you do not pass any credentials, the default credentials are used (i.e. those present on Cloud Function runner) - -## Read configs and secrets yourself -`dlt.secrets` and `dlt.config` provide dictionary-like access to configuration values and secrets, respectively. - -```py -# use `dlt.secrets` and `dlt.config` to explicitly take -# those values from providers from the explicit keys -data_source = google_sheets( - dlt.config["sheet_id"], - dlt.config["my_section.tabs"], - dlt.secrets["my_section.gcp_credentials"] -) - -data_source.run(destination="bigquery") -``` -`dlt.config` and `dlt.secrets` behave like dictionaries from which you can request a value with any key name. `dlt` will look in all [config providers](#injection-mechanism) - TOML files, env variables etc. just like it does with the standard section layout. You can also use `dlt.config.get()` or `dlt.secrets.get()` to -request value cast to a desired type. For example: -```py -credentials = dlt.secrets.get("my_section.gcp_credentials", GcpServiceAccountCredentials) -``` -Creates `GcpServiceAccountCredentials` instance out of values (typically a dictionary) under **my_section.gcp_credentials** key. - -### Write configs and secrets in code -**dlt.config** and **dlt.secrets** can be also used as setters. For example: -```py -dlt.config["sheet_id"] = "23029402349032049" -dlt.secrets["destination.postgres.credentials"] = BaseHook.get_connection('postgres_dsn').extra -``` -Will mock the **toml** provider to desired values. - - -## Injection mechanism - -Config and secret values are added to the function arguments when a function decorated with `@dlt.source` or `@dlt.resource` is called. - -The signature of such function (i.e. `google_sheets` above) is **also a specification of the configuration**. -During runtime `dlt` takes the argument names in the signature and supplies (`inject`) the required values via various config providers. - -The injection rules are: - -1. If you call the decorated function, the arguments that are passed explicitly are **never injected**, - this makes the injection mechanism optional. - -1. Required arguments (without default values) **are never injected** and must be specified when calling. - -1. Arguments with default values are injected if present in config providers, otherwise defaults from function signature is used. - -1. Arguments with the special default value `dlt.secrets.value` and `dlt.config.value` **must be injected** - (or explicitly passed). If they are not found by the config providers, the code raises - exception. The code in the functions always receives those arguments. - -Additionally `dlt.secrets.value` tells `dlt` that supplied value is a secret, and it will be injected -only from secure config providers. - -## Secret and config values layout and name lookup - -`dlt` uses a layout of hierarchical sections to organize the config and secret values. This makes -configurations and secrets easy to manage, and disambiguate values with the same keys by placing -them in the different sections. - -:::note -If you know how TOML files are organized -> this is the same concept! -::: - -A lot of config values are dictionaries themselves (i.e. most of the credentials) and you want the -values corresponding to one component to be close together. - -You can have a separate credentials for your destinations and each of the sources your pipeline uses, -if you have many pipelines in a single project, you can group them in separate sections. - -Here is the simplest default layout for our `google_sheets` example. - -### Default layout without sections - -**secrets.toml** - -```toml -[credentials] -client_email = "" -private_key = "" -project_id = "" -``` - -**config.toml** - -```toml -tab_names=["tab1", "tab2"] -``` - -As you can see the details of GCP credentials are placed under `credentials` which is argument name -to source function. - -### Default layout with explicit sections -This makes sure that `google_sheets` source does not share any secrets and configs with any other source or destination. - -**secrets.toml** - -```toml -[sources.google_sheets.credentials] -client_email = "" -private_key = "" -project_id = "" -``` - -**config.toml** - -```toml -[sources.google_sheets] -tab_names=["tab1", "tab2"] -``` - -### Custom layout - -Use this if you want to read and pass the config/secrets yourself - -**secrets.toml** - -```toml -[my_section] - -[my_section.gcp_credentials] -client_email = "" -private_key = "" -``` - -**config.toml** - -```toml -[my_section] -tabs=["tab1", "tab2"] - -[my_section.gcp_credentials] -# I prefer to keep my project id in config file and private key in secrets -project_id = "" -``` - -### Default layout and default key lookup during injection - -`dlt` arranges the sections into **default layout** that is expected by injection mechanism. This layout -makes it easy to configure simple cases but also provides a room for more explicit sections and -complex cases i.e. having several sources with different credentials or even hosting several pipelines -in the same project sharing the same config and credentials. - -```text -pipeline_name - | - |-sources - |- - |- - |- {all source and resource options and secrets} - |- - |- {all source and resource options and secrets} - |- - |... - - |-extract - |- extract options for resources ie. parallelism settings, maybe retries - |-destination - |- - |- {destination options} - |-credentials - |-{credentials options} - |-schema - |- - |-schema settings: not implemented but I'll let people set nesting level, name convention, normalizer etc. here - |-load - |-normalize -``` - -Lookup rules: - -**Rule 1:** The lookup starts with the most specific possible path, and if value is not found there, -it removes the right-most section and tries again. - -Example: We use the `bigquery` destination and the `google_sheets` source. They both use google credentials and expect them to be configured under `credentials` key. - -1. If we create just a single `credentials` section like in [here](#default-layout-without-sections), destination and source will share the same credentials. - -2. If we define sections as below, we'll keep the credentials separate - -```toml -# google sheet credentials -[sources.credentials] -client_email = "" -private_key = "" -project_id = "" - -# bigquery credentials -[destination.credentials] -client_email = "" -private_key = "" -project_id = "" -``` - -Now when `dlt` looks for destination credentials, it will start with `destination.bigquery.credentials`, eliminate `bigquery` and stop at `destination.credentials`. - -When looking for `sources` credentials it will start with `sources.google_sheets.google_sheets.credentials`, eliminate `google_sheets` twice and stop at `sources.credentials` (we assume that `google_sheets` source was defined in `google_sheets` python module) - -Example: let's be even more explicit and use a full section path possible. - -```toml -# google sheet credentials -[sources.google_sheets.credentials] -client_email = "" -private_key = "" -project_id = "" - -# google analytics credentials -[sources.google_analytics.credentials] -client_email = "" -private_key = "" -project_id = "" - -# bigquery credentials -[destination.bigquery.credentials] -client_email = "" -private_key = "" -project_id = "" -``` - -Now we can separate credentials for different sources as well. - -**Rule 2:** You can use your pipeline name to have separate configurations for each pipeline in your -project. - -Pipeline created/obtained with `dlt.pipeline()` creates a global and optional namespace with the -value of `pipeline_name`. All config values will be looked with pipeline name first and then again -without it. - -Example: the pipeline is named `ML_sheets`. - -```toml -[ML_sheets.credentials] -client_email = "" -private_key = "" -project_id = "" -``` - -or maximum path: - -```toml -[ML_sheets.sources.google_sheets.credentials] -client_email = "" -private_key = "" -project_id = "" -``` - -### The `sources` section - -Config and secrets for decorated sources and resources are kept in -`sources..` section. **All sections are optional during lookup**. For example, -if source module is named `pipedrive` and the function decorated with `@dlt.source` is -`deals(api_key: str=...)` then `dlt` will look for API key in: - -1. `sources.pipedrive.deals.api_key` -1. `sources.pipedrive.api_key` -1. `sources.api_key` -1. `api_key` - -Step 2 in a search path allows all the sources/resources in a module to share the same set of -credentials. - -Also look at the [following test](https://github.com/dlt-hub/dlt/blob/devel/tests/extract/test_decorators.py#L303) `test_source_sections`. - -## Understanding the exceptions - -Now we can finally understand the `ConfigFieldMissingException`. - -Let's run `chess.py` example without providing the password: - -```sh -$ CREDENTIALS="postgres://loader@localhost:5432/dlt_data" python chess.py -... -dlt.common.configuration.exceptions.ConfigFieldMissingException: Following fields are missing: ['password'] in configuration with spec PostgresCredentials - for field "password" config providers and keys were tried in following order: - In Environment Variables key CHESS_GAMES__DESTINATION__POSTGRES__CREDENTIALS__PASSWORD was not found. - In Environment Variables key CHESS_GAMES__DESTINATION__CREDENTIALS__PASSWORD was not found. - In Environment Variables key CHESS_GAMES__CREDENTIALS__PASSWORD was not found. - In secrets.toml key chess_games.destination.postgres.credentials.password was not found. - In secrets.toml key chess_games.destination.credentials.password was not found. - In secrets.toml key chess_games.credentials.password was not found. - In Environment Variables key DESTINATION__POSTGRES__CREDENTIALS__PASSWORD was not found. - In Environment Variables key DESTINATION__CREDENTIALS__PASSWORD was not found. - In Environment Variables key CREDENTIALS__PASSWORD was not found. - In secrets.toml key destination.postgres.credentials.password was not found. - In secrets.toml key destination.credentials.password was not found. - In secrets.toml key credentials.password was not found. -Please refer to https://dlthub.com/docs/general-usage/credentials for more information -``` - -It tells you exactly which paths `dlt` looked at, via which config providers and in which order. - -In the example above: - -1. First it looked in a big section `chess_games` which is name of the pipeline. -1. In each case it starts with full paths and goes to minimum path `credentials.password`. -1. First it looks into `environ` then in `secrets.toml`. It displays the exact keys tried. -1. Note that `config.toml` was skipped! It may not contain any secrets. - -Read more about [Provider Hierarchy](./config_providers). \ No newline at end of file diff --git a/docs/website/docs/general-usage/credentials/index.md b/docs/website/docs/general-usage/credentials/index.md new file mode 100644 index 0000000000..c9cbe6707c --- /dev/null +++ b/docs/website/docs/general-usage/credentials/index.md @@ -0,0 +1,18 @@ +--- +title: Configuration and Secrets +description: How to configure dlt pipelines and set up credentials +keywords: [credentials, secrets.toml, secrets, config, configuration, environment variables] +--- +import DocCardList from '@theme/DocCardList'; + +`dlt` pipelines usually require configurations and credentials. These can be set up in [various ways](setup): + +1. Environment variables +2. Configuration files (`secrets.toml` and `config.toml`) +3. Key managers and Vaults + +`dlt` automatically extracts configuration settings and secrets based on flexible [naming conventions](setup/#naming-convention). It then [injects](advanced/#injection-mechanism) these values where needed in code. + +# Learn Details About + + \ No newline at end of file diff --git a/docs/website/docs/general-usage/credentials/setup.md b/docs/website/docs/general-usage/credentials/setup.md new file mode 100644 index 0000000000..4ab9149bc0 --- /dev/null +++ b/docs/website/docs/general-usage/credentials/setup.md @@ -0,0 +1,613 @@ +--- +title: How to set up credentials +description: Where and in which order dlt looks for config/secrets. +keywords: [credentials, secrets.toml, secrets, config, configuration, environment + variables, provider] +--- + +`dlt` automatically extracts configuration settings and secrets based on flexible [naming conventions](setup/#naming-convention). + +It then [injects](advanced/#injection-mechanism) these values where needed in functions decorated with `@dlt.source`, `@dlt.resource`, or `@dlt.destination`. + +:::note +* **Configuration** refers to non-sensitive settings that define a data pipeline's behavior. These include file paths, database hosts, timeouts, API URLs, and performance settings. +* **Secrets** are sensitive data like passwords, API keys, and private keys. They should never be hard-coded to avoid security risks. +::: + +## Available config providers + +There are multiple ways to define configurations and credentials for your pipelines. `dlt` looks for these definitions in the following order during pipeline execution: + +1. [Environment Variables](#environment-variables): If a value for a specific argument is found in an environment variable, dlt will use it and will not proceed to search in lower-priority providers. + +1. [Vaults](#vaults): Credentials specified in vaults like Google Secrets Manager, Azure Key Vault, AWS Secrets Manager. + +1. [secrets.toml and config.toml files](#secretstoml-and-configtoml): These files are used for storing both configuration values and secrets. `secrets.toml` is dedicated to sensitive information, while `config.toml` contains non-sensitive configuration data. + +1. [Custom Providers](#custom-providers) added with `register_provider`: This is a custom provider implementation you can design yourself. +A custom config provider is helpful if you want to use your own configuration file structure or perform advanced preprocessing of configs and secrets. + +1. [Default Argument Values](advanced#ingestion-mechanism): These are the values specified in the function's signature. + +:::tip +Please make sure your pipeline name contains no whitespace or any other punctuation characters except `"-"` and `"_"`. This way you will ensure your code is working with any configuration option. +::: + +## Naming convention + +`dlt` uses a specific naming hierarchy to search for the secrets and configs values. This makes configurations and secrets easy to manage. + +To keep the naming convention flexible, `dlt` looks for a lot of possible combinations of key names, starting from the most specific possible path. Then, if the value is not found, it removes the right-most section and tries again. + +* The most specific possible path for **sources** looks like: +```sh +.sources... +``` + +* The most specific possible path for **destinations** looks like: +```sh +.destination..credentials. +``` + +### Example + +For example, if the source module is named `pipedrive` and the source is defined as follows: + +```py +# pipedrive.py + +@dlt.source +def deals(api_key: str = dlt.secrets.value): + pass +``` + +`dlt` will search for the following names in this order: + +1. `sources.pipedrive.deals.api_key` +1. `sources.pipedrive.api_key` +1. `sources.api_key` +1. `api_key` + +:::tip +You can use your pipeline name to have separate configurations for each pipeline in your project. All config values will be looked with the pipeline name first and then again without it. + +```toml +[pipeline_name_1.sources.google_sheets.credentials] +client_email = "" +private_key = "" +project_id = "" + +[pipeline_name_2.sources.google_sheets.credentials] +client_email = "" +private_key = "" +project_id = "" +``` +::: + +### Credential types + +In most cases, credentials are just key-value pairs, but in some cases, the actual structure of [credentials](complex_types) could be quite complex and support several ways of setting it up. +For example, to connect to a `sql_database` source, you can either set up a connection string: + +```toml +[sources.sql_database] +credentials="snowflake://user:password@service-account/database?warehouse=warehouse_name&role=role" +``` +or set up all parameters of connection separately: + +```toml +[sources.sql_database.credentials] +drivername="snowflake" +username="user" +password="password" +database = "database" +host = "service-account" +warehouse = "warehouse_name" +role = "role" +``` + +`dlt` can work with both ways and convert one to another. To learn more about which credential types are supported, visit the [complex credential types](complex_types) page. + +## Environment variables + +`dlt` prioritizes security by looking in environment variables before looking into the .toml files. + +The format of lookup keys is slightly different from secrets files because for environment variables, all names are capitalized, and sections are separated with a double underscore `"__"`. For example, to specify the Facebook Ads access token through environment variables, you would need to set up: + +```sh +export SOURCES__FACEBOOK_ADS__ACCESS_TOKEN="" +``` + +Check out the [example](#examples) of setting up credentials through environment variables. + +:::tip +To organize development and securely manage environment variables for credentials storage, you can use the [python-dotenv](https://pypi.org/project/python-dotenv/) to automatically load variables from an `.env` file. +::: + +## Vaults + +Vault integration methods vary based on the vault type. Check out our example involving [Google Cloud Secrets Manager](../../walkthroughs/add_credentials.md#retrieving-credentials-from-google-cloud-secret-manager). +For other vault integrations, you are welcome to [contact sales](https://dlthub.com/contact-sales) to learn about our [building blocks for data platform teams](https://dlthub.com/product/data-platform-teams#secure). + +## secrets.toml and config.toml + +The TOML config provider in dlt utilizes two TOML files: + +`config.toml`: + +- Configs refer to non-sensitive configuration data. These are settings, parameters, or options that define the behavior of a data pipeline. +- They can include things like file paths, database hosts and timeouts, API URLs, performance settings, or any other settings that affect the pipeline's behavior. +- Accessible in code through `dlt.config.values` + +`secrets.toml`: + +- Secrets are sensitive information that should be kept confidential, such as passwords, API keys, private keys, and other confidential data. +- It's crucial to never hard-code secrets directly into the code, as it can pose a security risk. +- Accessible in code through `dlt.secrets.values` + +By default, the `.gitignore` file in the project prevents `secrets.toml` from being added to version control and pushed. However, `config.toml` can be freely added to version control. + +### Location + +The TOML provider always loads those files from the `.dlt` folder, located **relative** to the current working directory. + +For example, if your working directory is `my_dlt_project` and your project has the following structure: + +```text +my_dlt_project: + | + pipelines/ + |---- .dlt/secrets.toml + |---- google_sheets.py +``` + +and you run +```sh +python pipelines/google_sheets.py +``` +then `dlt` will look for secrets in `my_dlt_project/.dlt/secrets.toml` and ignore the existing `my_dlt_project/pipelines/.dlt/secrets.toml`. + +If you change your working directory to `pipelines` and run +```sh +python google_sheets.py +``` + +`dlt` will look for `my_dlt_project/pipelines/.dlt/secrets.toml` as (probably) expected. + +:::caution +The TOML provider also has the capability to read files from `~/.dlt/` (located in the user's home directory) in addition to the local project-specific `.dlt` folder. +::: + +### Structure + +`dlt` organizes sections in TOML files in a specific structure required by the [injection mechanism](advanced/#injection-mechanism). +Understanding this structure gives you more flexibility in setting credentials. For more details, see [Toml files structure](advanced/#toml-files-structure). + +## Custom Providers + +You can use the `CustomLoaderDocProvider` classes to supply a custom dictionary to `dlt` for use +as a supplier of `config` and `secret` values. The code below demonstrates how to use a config stored in `config.json`. + +```py +import dlt + +from dlt.common.configuration.providers import CustomLoaderDocProvider + +# create a function that loads a dict +def load_config(): + with open("config.json", "rb") as f: + config_dict = json.load(f) + +# create the custom provider +provider = CustomLoaderDocProvider("my_json_provider",load_config) + +# register provider +dlt.config.register_provider(provider) +``` + +:::tip +Check our an [example](../../examples/custom_config_provider) for a `yaml` based config provider that supports switchable profiles. +::: + +## Examples + +### Setup both configurations and secrets + +`dlt` recognizes two types of data: secrets and configurations. The main difference is that secrets contain sensitive information, +while configurations hold non-sensitive information and can be safely added to version control systems like git. +This means you have more flexibility with configurations. You can set up configurations directly in the code, +but it is strongly advised not to do this with secrets. + +:::caution +You can put all configurations and credentials in the `secrets.toml` if it's more convenient. +However, credentials cannot be placed in `configs.toml` because `dlt` doesn't look for them there. +::: + +Let's assume we have a [notion](../../dlt-ecosystem/verified-sources/notion) source and [filesystem](../../dlt-ecosystem/destinations/filesystem) destination: + + + + + +```toml +# we can set up a lot in config.toml +# config.toml +[runtime] +log_level="INFO" + +[destination.filesystem] +bucket_url = "s3://[your_bucket_name]" + +[normalize.data_writer] +disable_compression=true + +# but credentials should go to secrets.toml! +# secrets.toml +[source.notion] +api_key = "api_key" + +[destination.filesystem.credentials] +aws_access_key_id = "ABCDEFGHIJKLMNOPQRST" # copy the access key here +aws_secret_access_key = "1234567890_access_key" # copy the secret access key here +``` + + + + + +```sh +# ENV vars are set up the same way both for configs and secrets +export RUNTIME__LOG_LEVEL="INFO" +export DESTINATION__FILESYSTEM__BUCKET_URL="s3://[your_bucket_name]" +export NORMALIZE__DATA_WRITER__DISABLE_COMPRESSION="true" +export SOURCE__NOTION__API_KEY="api_key" +export DESTINATION__FILESYSTEM__CREDENTIALS__AWS_ACCESS_KEY_ID="api_key" +export DESTINATION__FILESYSTEM__CREDENTIALS__AWS_SECRET_ACCESS_KEY="api_key" +``` + + + + + +```py +import os +import dlt + +# you can freely set up configuration directly in the code + +# via env vars +os.environ["RUNTIME__LOG_LEVEL"] = "INFO" +os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://[your_bucket_name]" +os.environ["NORMALIZE__DATA_WRITER__DISABLE_COMPRESSION"] = "true" + +# or even directly to the dlt.config +dlt.config["runtime.log_level"] = "INFO" +dlt.config["destination.filesystem.bucket_url"] = "INFO" +dlt.config["normalize.data_writer.disable_compression"] = "true" + +# but please, do not set up the secrets in the code! +# what you can do is reassign env variables: +os.environ["SOURCE__NOTION__API_KEY"] = os.environ.get("NOTION_KEY") + +# or use a third-party credentials supplier +import botocore.session + +credentials = AwsCredentials() +session = botocore.session.get_session() +credentials.parse_native_representation(session) +dlt.secrets["destination.filesystem.credentials"] = credentials +``` + + + + + +### Google credentials for both source and destination + +Let's assume we use the `bigquery` destination and the `google_sheets` source. They both use Google credentials and expect them to be configured under the `credentials` key. + +1. If we create just a single `credentials` section like in [here](#default-layout-without-sections), the destination and source will share the same credentials. + + + + + +```toml +[credentials] +client_email = "" +private_key = "" +project_id = "" +``` + + + + + +```sh +export CREDENTIALS__CLIENT_EMAIL="" +export CREDENTIALS__PRIVATE_KEY="" +export CREDENTIALS__PROJECT_ID="" +``` + + + + + +```py +import os + +# do not set up the secrets directly in the code! +# what you can do is reassign env variables +os.environ["CREDENTIALS__CLIENT_EMAIL"] = os.environ.get("GOOGLE_CLIENT_EMAIL") +os.environ["CREDENTIALS__PRIVATE_KEY"] = os.environ.get("GOOGLE_PRIVATE_KEY") +os.environ["CREDENTIALS__PROJECT_ID"] = os.environ.get("GOOGLE_PROJECT_ID") +``` + + + + + +2. If we define sections as below, we'll keep the credentials separate + + + + + +```toml +# google sheet credentials +[sources.credentials] +client_email = "" +private_key = "" +project_id = "" + +# bigquery credentials +[destination.credentials] +client_email = "" +private_key = "" +project_id = "" +``` + + + + + +```sh +# google sheet credentials +export SOURCES__CREDENTIALS__CLIENT_EMAIL="" +export SOURCES__CREDENTIALS__PRIVATE_KEY="" +export SOURCES__CREDENTIALS__PROJECT_ID="" + +# bigquery credentials +export DESTINATION__CREDENTIALS__CLIENT_EMAIL="" +export DESTINATION__CREDENTIALS__PRIVATE_KEY="" +export DESTINATION__CREDENTIALS__PROJECT_ID="" +``` + + + + + +```py +import dlt +import os + +# do not set up the secrets directly in the code! +# what you can do is reassign env variables +os.environ["DESTINATION__CREDENTIALS__CLIENT_EMAIL"] = os.environ.get("BIGQUERY_CLIENT_EMAIL") +os.environ["DESTINATION__CREDENTIALS__PRIVATE_KEY"] = os.environ.get("BIGQUERY_PRIVATE_KEY") +os.environ["DESTINATION__CREDENTIALS__PROJECT_ID"] = os.environ.get("BIGQUERY_PROJECT_ID") + +# or set them to the dlt.secrets +dlt.secrets["sources.credentials.client_email"] = os.environ.get("SHEETS_CLIENT_EMAIL") +dlt.secrets["sources.credentials.private_key"] = os.environ.get("SHEETS_PRIVATE_KEY") +dlt.secrets["sources.credentials.project_id"] = os.environ.get("SHEETS_PROJECT_ID") +``` + + + + + +Now `dlt` looks for destination credentials in the following order: +```sh +destination.bigquery.credentials --> Not found +destination.credentials --> Found +``` + +When looking for the source credentials: +```sh +sources.google_sheets_module.google_sheets_function.credentials --> Not found +sources.google_sheets_function.credentials --> Not found +sources.credentials --> Found +``` + +### Credentials for several different sources and destinations + +Let's assume we have several different Google sources and destinations. We can use full paths to organize the `secrets.toml` file: + + + + + +```toml +# google sheet credentials +[sources.google_sheets.credentials] +client_email = "" +private_key = "" +project_id = "" + +# google analytics credentials +[sources.google_analytics.credentials] +client_email = "" +private_key = "" +project_id = "" + +# bigquery credentials +[destination.bigquery.credentials] +client_email = "" +private_key = "" +project_id = "" +``` + + + + + +```sh +# google sheet credentials +export SOURCES__GOOGLE_SHEETS__CREDENTIALS__CLIENT_EMAIL="" +export SOURCES__GOOGLE_SHEETS__CREDENTIALS__PRIVATE_KEY="" +export SOURCES__GOOGLE_SHEETS__CREDENTIALS__PROJECT_ID="" + +# google analytics credentials +export SOURCES__GOOGLE_ANALYTICS__CREDENTIALS__CLIENT_EMAIL="" +export SOURCES__GOOGLE_ANALYTICS__CREDENTIALS__PRIVATE_KEY="" +export SOURCES__GOOGLE_ANALYTICS__CREDENTIALS__PROJECT_ID="" + +# bigquery credentials +export DESTINATION__BIGQUERY__CREDENTIALS__CLIENT_EMAIL="" +export DESTINATION__BIGQUERY__CREDENTIALS__PRIVATE_KEY="" +export DESTINATION__BIGQUERY__CREDENTIALS__PROJECT_ID="" +``` + + + + + +```py +import os +import dlt + +# do not set up the secrets directly in the code! +# what you can do is reassign env variables +os.environ["SOURCES__GOOGLE_ANALYTICS__CREDENTIALS__CLIENT_EMAIL"] = os.environ.get("SHEETS_CLIENT_EMAIL") +os.environ["SOURCES__GOOGLE_ANALYTICS__CREDENTIALS__PRIVATE_KEY"] = os.environ.get("ANALYTICS_PRIVATE_KEY") +os.environ["SOURCES__GOOGLE_ANALYTICS__CREDENTIALS__PROJECT_ID"] = os.environ.get("ANALYTICS_PROJECT_ID") + +os.environ["DESTINATION__CREDENTIALS__CLIENT_EMAIL"] = os.environ.get("BIGQUERY_CLIENT_EMAIL") +os.environ["DESTINATION__CREDENTIALS__PRIVATE_KEY"] = os.environ.get("BIGQUERY_PRIVATE_KEY") +os.environ["DESTINATION__CREDENTIALS__PROJECT_ID"] = os.environ.get("BIGQUERY_PROJECT_ID") + +# or set them to the dlt.secrets +dlt.secrets["sources.credentials.client_email"] = os.environ.get("SHEETS_CLIENT_EMAIL") +dlt.secrets["sources.credentials.private_key"] = os.environ.get("SHEETS_PRIVATE_KEY") +dlt.secrets["sources.credentials.project_id"] = os.environ.get("SHEETS_PROJECT_ID") +``` + + + + + +### Credentials for several sources of the same type + +Let's assume we have several sources of the same type, how can we separate them in the `secrets.toml`? The recommended solution is to use different pipeline names for each source: + + + + + +```toml +[pipeline_name_1.sources.sql_database] +credentials="snowflake://user1:password1@service-account/database1?warehouse=warehouse_name&role=role1" + +[pipeline_name_2.sources.sql_database] +credentials="snowflake://user2:password2@service-account/database2?warehouse=warehouse_name&role=role2" +``` + + + + + +```sh +export PIPELINE_NAME_1_SOURCES__SQL_DATABASE__CREDENTIALS="snowflake://user1:password1@service-account/database1?warehouse=warehouse_name&role=role1" +export PIPELINE_NAME_2_SOURCES__SQL_DATABASE__CREDENTIALS="snowflake://user2:password2@service-account/database2?warehouse=warehouse_name&role=role2" +``` + + + + + +```py +import os +import dlt + +# do not set up the secrets directly in the code! +# what you can do is reassign env variables +os.environ["PIPELINE_NAME_1_SOURCES__SQL_DATABASE__CREDENTIALS"] = os.environ.get("SQL_CREDENTIAL_STRING_1") + +# or set them to the dlt.secrets +dlt.secrets["pipeline_name_2.sources.sql_database.credentials"] = os.environ.get("SQL_CREDENTIAL_STRING_2") +``` + + + + + +## Understanding the exceptions + +If `dlt` expects configuration of secrets value but cannot find it, it will output the `ConfigFieldMissingException`. + +Let's run the `chess.py` example without providing the password: + +```sh +$ CREDENTIALS="postgres://loader@localhost:5432/dlt_data" python chess.py +... +dlt.common.configuration.exceptions.ConfigFieldMissingException: Following fields are missing: ['password'] in configuration with spec PostgresCredentials + for field "password" config providers and keys were tried in the following order: + In Environment Variables key CHESS_GAMES__DESTINATION__POSTGRES__CREDENTIALS__PASSWORD was not found. + In Environment Variables key CHESS_GAMES__DESTINATION__CREDENTIALS__PASSWORD was not found. + In Environment Variables key CHESS_GAMES__CREDENTIALS__PASSWORD was not found. + In secrets.toml key chess_games.destination.postgres.credentials.password was not found. + In secrets.toml key chess_games.destination.credentials.password was not found. + In secrets.toml key chess_games.credentials.password was not found. + In Environment Variables key DESTINATION__POSTGRES__CREDENTIALS__PASSWORD was not found. + In Environment Variables key DESTINATION__CREDENTIALS__PASSWORD was not found. + In Environment Variables key CREDENTIALS__PASSWORD was not found. + In secrets.toml key destination.postgres.credentials.password was not found. + In secrets.toml key destination.credentials.password was not found. + In secrets.toml key credentials.password was not found. +Please refer to https://dlthub.com/docs/general-usage/credentials for more information +``` + +It tells you exactly which paths `dlt` looked at, via which config providers and in which order. + +In the example above: + +1. First, `dlt` looked in a big section `chess_games`, which is the name of the pipeline. +1. In each case, it starts with full paths and goes to the minimum path `credentials.password`. +1. First, it looks into environment variables, then in `secrets.toml`. It displays the exact keys tried. +1. Note that `config.toml` was skipped! It could not contain any secrets. \ No newline at end of file diff --git a/docs/website/docs/general-usage/destination.md b/docs/website/docs/general-usage/destination.md index b30403d349..d88a0b53f2 100644 --- a/docs/website/docs/general-usage/destination.md +++ b/docs/website/docs/general-usage/destination.md @@ -28,7 +28,7 @@ Above we use built in **filesystem** destination by providing a factory type `fi Above we import destination factory for **filesystem** and pass it to the pipeline. -All examples above will create the same destination factory with default parameters and pull required config and secret values from [configuration](credentials/configuration.md) - they are equivalent. +All examples above will create the same destination class with default parameters and pull required config and secret values from [configuration](credentials/index.md) - they are equivalent. ### Pass explicit parameters and a name to a destination @@ -41,9 +41,9 @@ If destination is not named, its shorthand type (the Python factory name) serves ## Configure a destination -We recommend to pass the credentials and other required parameters to configuration via TOML files, environment variables or other [config providers](credentials/config_providers.md). This allows you, for example, to easily switch to production destinations after deployment. +We recommend to pass the credentials and other required parameters to configuration via TOML files, environment variables or other [config providers](credentials/setup). This allows you, for example, to easily switch to production destinations after deployment. -We recommend to use the [default config section layout](credentials/configuration.md#default-layout-and-default-key-lookup-during-injection) as below: +We recommend to use the [default config section layout](credentials/setup#structure-of-secrets.toml-and-config.toml) as below: or via environment variables: @@ -65,14 +65,14 @@ You can pass credentials explicitly when creating destination factory instance. :::tip -You can create and pass partial credentials and `dlt` will fill the missing data. Below we pass postgres connection string but without password and expect that it will be present in environment variables (or any other [config provider](credentials/config_providers.md)) +You can create and pass partial credentials and `dlt` will fill the missing data. Below we pass postgres connection string but without password and expect that it will be present in environment variables (or any other [config provider](credentials/setup)) -Please read how to use [various built in credentials types](credentials/config_specs.md). +Please read how to use [various built in credentials types](credentials/complex_types). ::: ### Inspect destination capabilities diff --git a/docs/website/docs/general-usage/glossary.md b/docs/website/docs/general-usage/glossary.md index adbb30f108..5ae256b268 100644 --- a/docs/website/docs/general-usage/glossary.md +++ b/docs/website/docs/general-usage/glossary.md @@ -53,11 +53,11 @@ Describes the structure of normalized data (e.g. unpacked tables, column types, instructions on how the data should be processed and loaded (i.e. it tells `dlt` about the content of the data and how to load it into the destination). -## [Config](credentials/configuration) +## [Config](credentials/setup#secrets.toml-and-config.toml) A set of values that are passed to the pipeline at run time (e.g. to change its behavior locally vs. in production). -## [Credentials](credentials/config_specs) +## [Credentials](credentials/complex_types) A subset of configuration whose elements are kept secret and never shared in plain text. diff --git a/docs/website/docs/general-usage/http/rest-client.md b/docs/website/docs/general-usage/http/rest-client.md index ddd66a233b..40c83f8c5b 100644 --- a/docs/website/docs/general-usage/http/rest-client.md +++ b/docs/website/docs/general-usage/http/rest-client.md @@ -183,8 +183,9 @@ need to specify the paginator when the API uses a different relation type. - `offset`: The initial offset for the first request. Defaults to `0`. - `offset_param`: The name of the query parameter used to specify the offset. Defaults to `"offset"`. - `limit_param`: The name of the query parameter used to specify the limit. Defaults to `"limit"`. -- `total_path`: A JSONPath expression for the total number of items. If not provided, pagination is controlled by `maximum_offset`. +- `total_path`: A JSONPath expression for the total number of items. If not provided, pagination is controlled by `maximum_offset` and `stop_after_empty_page`. - `maximum_offset`: Optional maximum offset value. Limits pagination even without total count. +- `stop_after_empty_page`: Whether pagination should stop when a page contains no result items. Defaults to `True`. **Example:** @@ -198,7 +199,7 @@ E.g. `https://api.example.com/items?offset=0&limit=100`, `https://api.example.co } ``` -You can paginate through responses from this API using `OffsetPaginator`: +You can paginate through responses from this API using the `OffsetPaginator`: ```py client = RESTClient( @@ -210,20 +211,34 @@ client = RESTClient( ) ``` -In a different scenario where the API does not provide the total count, you can use `maximum_offset` to limit the pagination: +Pagination stops by default when a page contains no records. This is especially useful when the API does not provide the total item count. +Here, the `total_path` parameter is set to `None` because the API does not provide the total count. ```py client = RESTClient( base_url="https://api.example.com", paginator=OffsetPaginator( limit=100, - maximum_offset=1000, - total_path=None + total_path=None, ) ) ``` -Note, that in this case, the `total_path` parameter is set explicitly to `None` to indicate that the API does not provide the total count. +Additionally, you can limit pagination with `maximum_offset`, for example during development. If `maximum_offset` is reached before the first empty page then pagination stops: + +```py +client = RESTClient( + base_url="https://api.example.com", + paginator=OffsetPaginator( + limit=10, + maximum_offset=20, # limits response to 20 records + total_path=None, + ) +) +``` + +You can disable automatic stoppage of pagination by setting `stop_after_stop_after_empty_page = False`. In this case, you must provide either `total_path` or `maximum_offset` to guarantee that the paginator terminates. + #### PageNumberPaginator @@ -234,8 +249,9 @@ Note, that in this case, the `total_path` parameter is set explicitly to `None` - `base_page`: The index of the initial page from the API perspective. Normally, it's 0-based or 1-based (e.g., 1, 2, 3, ...) indexing for the pages. Defaults to 0. - `page`: The page number for the first request. If not provided, the initial value will be set to `base_page`. - `page_param`: The query parameter name for the page number. Defaults to `"page"`. -- `total_path`: A JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page`. +- `total_path`: A JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page` and `stop_after_empty_page`. - `maximum_page`: Optional maximum page number. Stops pagination once this page is reached. +- `stop_after_empty_page`: Whether pagination should stop when a page contains no result items. Defaults to `True`. **Example:** @@ -248,7 +264,7 @@ Assuming an API endpoint `https://api.example.com/items` paginates by page numbe } ``` -You can paginate through responses from this API using `PageNumberPaginator`: +You can paginate through responses from this API using the `PageNumberPaginator`: ```py client = RESTClient( @@ -259,19 +275,32 @@ client = RESTClient( ) ``` -If the API does not provide the total number of pages: +Pagination stops by default when a page contains no records. This is especially useful when the API does not provide the total item count. +Here, the `total_path` parameter is set to `None` because the API does not provide the total count. ```py client = RESTClient( base_url="https://api.example.com", paginator=PageNumberPaginator( - maximum_page=5, # Stops after fetching 5 pages total_path=None ) ) ``` -Note, that in the case above, the `total_path` parameter is set explicitly to `None` to indicate that the API does not provide the total count. +Additionally, you can limit pagination with `maximum_offset`, for example during development. If `maximum_page` is reached before the first empty page then pagination stops: + +```py +client = RESTClient( + base_url="https://api.example.com", + paginator=OffsetPaginator( + maximum_page=2, # limits response to 2 pages + total_path=None, + ) +) +``` + +You can disable automatic stoppage of pagination by setting `stop_after_stop_after_empty_page = False`. In this case, you must provide either `total_path` or `maximum_page` to guarantee that the paginator terminates. + #### JSONResponseCursorPaginator @@ -310,7 +339,7 @@ When working with APIs that use non-standard pagination schemes, or when you nee - `init_request(request: Request) -> None`: This method is called before making the first API call in the `RESTClient.paginate` method. You can use this method to set up the initial request query parameters, headers, etc. For example, you can set the initial page number or cursor value. -- `update_state(response: Response) -> None`: This method updates the paginator's state based on the response of the API call. Typically, you extract pagination details (like the next page reference) from the response and store them in the paginator instance. +- `update_state(response: Response, data: Optional[List[Any]]) -> None`: This method updates the paginator's state based on the response of the API call. Typically, you extract pagination details (like the next page reference) from the response and store them in the paginator instance. - `update_request(request: Request) -> None`: Before making the next API call in `RESTClient.paginate` method, `update_request` is used to modify the request with the necessary parameters to fetch the next page (based on the current state of the paginator). For example, you can add query parameters to the request, or modify the URL. @@ -319,6 +348,7 @@ When working with APIs that use non-standard pagination schemes, or when you nee Suppose an API uses query parameters for pagination, incrementing an page parameter for each subsequent page, without providing direct links to next pages in its responses. E.g. `https://api.example.com/posts?page=1`, `https://api.example.com/posts?page=2`, etc. Here's how you could implement a paginator for this scheme: ```py +from typing import Any, List, Optional from dlt.sources.helpers.rest_client.paginators import BasePaginator from dlt.sources.helpers.requests import Response, Request @@ -332,7 +362,7 @@ class QueryParamPaginator(BasePaginator): # This will set the initial page number (e.g. page=1) self.update_request(request) - def update_state(self, response: Response) -> None: + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: # Assuming the API returns an empty list when no more data is available if not response.json(): self._has_next_page = False @@ -370,6 +400,7 @@ def get_data(): Some APIs use POST requests for pagination, where the next page is fetched by sending a POST request with a cursor or other parameters in the request body. This is frequently used in "search" API endpoints or other endpoints with big payloads. Here's how you could implement a paginator for a case like this: ```py +from typing import Any, List, Optional from dlt.sources.helpers.rest_client.paginators import BasePaginator from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.requests import Response, Request @@ -379,7 +410,7 @@ class PostBodyPaginator(BasePaginator): super().__init__() self.cursor = None - def update_state(self, response: Response) -> None: + def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None: # Assuming the API returns an empty list when no more data is available if not response.json(): self._has_next_page = False diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index b130f7a4f5..68fc46e6dc 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -251,6 +251,19 @@ executed. You can achieve the same in the decorator `@dlt.source(root_key=True)` ### `scd2` strategy `dlt` can create [Slowly Changing Dimension Type 2](https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row) (SCD2) destination tables for dimension tables that change in the source. The resource is expected to provide a full extract of the source table each run. A row hash is stored in `_dlt_id` and used as surrogate key to identify source records that have been inserted, updated, or deleted. A `NULL` value is used by default to indicate an active record, but it's possible to use a configurable high timestamp (e.g. 9999-12-31 00:00:00.000000) instead. +:::note +The `unique` hint for `_dlt_id` in the root table is set to `false` when using `scd2`. This differs from [default behavior](./destination-tables.md#child-and-parent-tables). The reason is that the surrogate key stored in `_dlt_id` contains duplicates after an _insert-delete-reinsert_ pattern: +1. record with surrogate key X is inserted in a load at `t1` +2. record with surrogate key X is deleted in a later load at `t2` +3. record with surrogate key X is reinserted in an even later load at `t3` + +After this pattern, the `scd2` table in the destination has two records for surrogate key X: one for validity window `[t1, t2]`, and one for `[t3, NULL]`. A duplicate value exists in `_dlt_id` because both records have the same surrogate key. + +Note that: +- the composite key `(_dlt_id, _dlt_valid_from)` is unique +- `_dlt_id` remains unique for child tables—`scd2` does not affect this +::: + #### Example: `scd2` merge strategy ```py @dlt.resource( @@ -335,7 +348,23 @@ You can configure the literal used to indicate an active record with `active_rec write_disposition={ "disposition": "merge", "strategy": "scd2", - "active_record_timestamp": "9999-12-31", # e.g. datetime.datetime(9999, 12, 31) is also accepted + # accepts various types of date/datetime objects + "active_record_timestamp": "9999-12-31", + } +) +def dim_customer(): + ... +``` + +#### Example: configure boundary timestamp +You can configure the "boundary timestamp" used for record validity windows with `boundary_timestamp`. The provided date(time) value is used as "valid from" for new records and as "valid to" for retired records. The timestamp at which a load package is created is used if `boundary_timestamp` is omitted. +```py +@dlt.resource( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + # accepts various types of date/datetime objects + "boundary_timestamp": "2024-08-21T12:15:00+00:00", } ) def dim_customer(): diff --git a/docs/website/docs/general-usage/naming-convention.md b/docs/website/docs/general-usage/naming-convention.md index bf6e650b9c..11032a4457 100644 --- a/docs/website/docs/general-usage/naming-convention.md +++ b/docs/website/docs/general-usage/naming-convention.md @@ -5,35 +5,34 @@ keywords: [identifiers, snake case, case sensitive, case insensitive, naming] --- # Naming Convention -`dlt` creates table and column identifiers from the data. The data source ie. a stream of JSON documents may have identifiers (i.e. key names in a dictionary) with any Unicode characters, of any length and naming style. On the other hand, destinations require that you follow strict rules when you name tables, columns or collections. -A good example is [Redshift](../dlt-ecosystem/destinations/redshift.md#naming-convention) that accepts case-insensitive alphanumeric identifiers with maximum 127 characters. +`dlt` creates table and column identifiers from the data. The data source, i.e. a stream of JSON documents, may have identifiers (i.e. key names in a dictionary) with any Unicode characters, of any length, and naming style. On the other hand, destinations require that you follow strict rules when you name tables, columns, or collections. +A good example is [Redshift](../dlt-ecosystem/destinations/redshift.md#naming-convention) that accepts case-insensitive alphanumeric identifiers with a maximum of 127 characters. -`dlt` groups tables from a single [source](source.md) in a [schema](schema.md). Each schema defines **naming convention** that tells `dlt` how to translate identifiers to the -namespace that the destination understands. Naming conventions are in essence functions that map strings from the source identifier format into destination identifier format. For example our **snake_case** (default) naming convention will translate `DealFlow` source identifier into `deal_flow` destination identifier. +`dlt` groups tables from a single [source](source.md) in a [schema](schema.md). Each schema defines a **naming convention** that tells `dlt` how to translate identifiers to the +namespace that the destination understands. Naming conventions are, in essence, functions that map strings from the source identifier format into the destination identifier format. For example, our **snake_case** (default) naming convention will translate the `DealFlow` source identifier into the `deal_flow` destination identifier. You can pick which naming convention to use. `dlt` provides a few to [choose from](#available-naming-conventions). You can [easily add your own](#write-your-own-naming-convention) as well. :::tip -Standard behavior of `dlt` is to **use the same naming convention for all destinations** so users see always the same table and column names in their databases. +The standard behavior of `dlt` is to **use the same naming convention for all destinations** so users always see the same table and column names in their databases. ::: ### Use default naming convention (snake_case) -**snake_case** is case insensitive naming convention, converting source identifiers into lower case snake case identifiers with reduced alphabet. +**snake_case** is a case-insensitive naming convention, converting source identifiers into lower-case snake case identifiers with a reduced alphabet. -- Spaces around identifier are trimmed -- Keeps ascii alphanumerics and underscores, replaces all other characters with underscores (with the exceptions below) -- Replaces `+` and `*` with `x`, `-` with `_`, `@` with `a` and `|` with `l` -- Prepends `_` if name starts with number. -- Multiples of `_` are converted into single `_`. +- Spaces around identifiers are trimmed +- Keeps ASCII alphanumerics and underscores, replaces all other characters with underscores (with the exceptions below) +- Replaces `+` and `*` with `x`, `-` with `_`, `@` with `a`, and `|` with `l` +- Prepends `_` if the name starts with a number. +- Multiples of `_` are converted into a single `_`. - Replaces all trailing `_` with `x` -Uses __ as patent-child separator for tables and flattened column names. +Uses __ as a parent-child separator for tables and flattened column names. :::tip -If you do not like **snake_case** your next safe option is **sql_ci** which generates SQL-safe, lower-case, case-insensitive identifiers without any -other transformations. To permanently change the default naming convention on a given machine: +If you do not like **snake_case**, your next safe option is **sql_ci**, which generates SQL-safe, lowercase, case-insensitive identifiers without any other transformations. To permanently change the default naming convention on a given machine: 1. set an environment variable `SCHEMA__NAMING` to `sql_ci_v1` OR -2. add the following line to your global `config.toml` (the one in your home dir ie. `~/.dlt/config.toml`) +2. add the following line to your global `config.toml` (the one in your home dir, i.e. `~/.dlt/config.toml`) ```toml [schema] naming="sql_ci_v1" @@ -42,53 +41,50 @@ naming="sql_ci_v1" ## Source identifiers vs destination identifiers ### Pick the right identifier form when defining resources -`dlt` keeps source (not normalized) identifiers during data [extraction](../reference/explainers/how-dlt-works.md#extract) and translates them during [normalization](../reference/explainers/how-dlt-works.md#normalize). For you it means: +`dlt` keeps source (not normalized) identifiers during data [extraction](../reference/explainers/how-dlt-works.md#extract) and translates them during [normalization](../reference/explainers/how-dlt-works.md#normalize). For you, it means: 1. If you write a [transformer](resource.md#process-resources-with-dlttransformer) or a [mapping/filtering function](resource.md#filter-transform-and-pivot-data), you will see the original data, without any normalization. Use the source identifiers to access the dicts! -2. If you define a `primary_key` or `cursor` that participate in [cursor field incremental loading](incremental-loading.md#incremental-loading-with-a-cursor-field) use the source identifiers (`dlt` uses them to inspect source data, `Incremental` class is just a filtering function). -3. When defining any other hints ie. `columns` or `merge_key` you can pick source or destination identifiers. `dlt` normalizes all hints together with your data. -4. `Schema` object (ie. obtained from the pipeline or from `dlt` source via `discover_schema`) **always contains destination (normalized) identifiers**. +2. If you define a `primary_key` or `cursor` that participate in [cursor field incremental loading](incremental-loading.md#incremental-loading-with-a-cursor-field), use the source identifiers (`dlt` uses them to inspect source data, `Incremental` class is just a filtering function). +3. When defining any other hints, i.e. `columns` or `merge_key`, you can pick source or destination identifiers. `dlt` normalizes all hints together with your data. +4. The `Schema` object (i.e. obtained from the pipeline or from `dlt` source via `discover_schema`) **always contains destination (normalized) identifiers**. ### Understand the identifier normalization -Identifiers are translated from source to destination form in **normalize** step. Here's how `dlt` picks the naming convention: +Identifiers are translated from source to destination form in the **normalize** step. Here's how `dlt` picks the naming convention: * The default naming convention is **snake_case**. -* Each destination may define a preferred naming convention in [destination capabilities](destination.md#pass-additional-parameters-and-change-destination-capabilities). Some destinations (ie. Weaviate) need specialized naming convention and will override the default. +* Each destination may define a preferred naming convention in [destination capabilities](destination.md#pass-additional-parameters-and-change-destination-capabilities). Some destinations (i.e. Weaviate) need a specialized naming convention and will override the default. * You can [configure a naming convention explicitly](#set-and-adjust-naming-convention-explicitly). Such configuration overrides the destination settings. -* This naming convention is used when new schemas are created. It happens when pipeline is run for the first time. -* Schemas preserve naming convention when saved. Your running pipelines will maintain existing naming conventions if not requested otherwise. -* `dlt` applies final naming convention in `normalize` step. Jobs (files) in load package now have destination identifiers. Pipeline schema is duplicated, locked and saved in the load package and will be used by the destination. +* This naming convention is used when new schemas are created. It happens when the pipeline is run for the first time. +* Schemas preserve the naming convention when saved. Your running pipelines will maintain existing naming conventions if not requested otherwise. +* `dlt` applies the final naming convention in the `normalize` step. Jobs (files) in the load package now have destination identifiers. The pipeline schema is duplicated, locked, and saved in the load package and will be used by the destination. :::caution -If you change naming convention and `dlt` detects that a change in the destination identifiers for tables/collection/files that already exist and store data, -the normalize process will fail. This prevents an unwanted schema migration. New columns and tables will be created for identifiers that changed. +If you change the naming convention and `dlt` detects a change in the destination identifiers for tables/collections/files that already exist and store data, the normalize process will fail. This prevents an unwanted schema migration. New columns and tables will be created for identifiers that changed. ::: -### Case sensitive and insensitive destinations -Naming convention declare if the destination identifiers they produce are case sensitive or insensitive. This helps `dlt` to [generate case sensitive / insensitive identifiers for the destinations that support both](destination.md#control-how-dlt-creates-table-column-and-other-identifiers). For example: if you pick case insensitive naming like **snake_case** or **sql_ci_v1**, with Snowflake, `dlt` will generate all upper-case identifiers that Snowflake sees as case insensitive. If you pick case sensitive naming like **sql_cs_v1**, `dlt` will generate quoted case-sensitive identifiers that preserve identifier capitalization. +### Case-sensitive and insensitive destinations +Naming conventions declare if the destination identifiers they produce are case-sensitive or insensitive. This helps `dlt` to [generate case-sensitive / insensitive identifiers for the destinations that support both](destination.md#control-how-dlt-creates-table-column-and-other-identifiers). For example: if you pick a case-insensitive naming like **snake_case** or **sql_ci_v1**, with Snowflake, `dlt` will generate all uppercase identifiers that Snowflake sees as case-insensitive. If you pick a case-sensitive naming like **sql_cs_v1**, `dlt` will generate quoted case-sensitive identifiers that preserve identifier capitalization. -Note that many destinations are exclusively case insensitive, of which some preserve casing of identifiers (ie. **duckdb**) and some will case-fold identifiers when creating tables (ie. **Redshift**, **Athena** do lower case on the names). `dlt` is able to detect resulting identifier [collisions](#avoid-identifier-collisions) and stop the load process before data is mangled. +Note that many destinations are exclusively case-insensitive, of which some preserve the casing of identifiers (i.e. **duckdb**) and some will case-fold identifiers when creating tables (i.e. **Redshift**, **Athena** do lowercase on the names). `dlt` is able to detect resulting identifier [collisions](#avoid-identifier-collisions) and stop the load process before data is mangled. ### Identifier shortening -Identifier shortening happens during normalization. `dlt` takes the maximum length of the identifier from the destination capabilities and will trim the identifiers that are -too long. The default shortening behavior generates short deterministic hashes of the source identifiers and places them in the middle of the destination identifier. This -(with a high probability) avoids shortened identifier collisions. +Identifier shortening happens during normalization. `dlt` takes the maximum length of the identifier from the destination capabilities and will trim the identifiers that are too long. The default shortening behavior generates short deterministic hashes of the source identifiers and places them in the middle of the destination identifier. This (with a high probability) avoids shortened identifier collisions. ### 🚧 [WIP] Name convention changes are lossy -`dlt` does not store the source identifiers in the schema so when naming convention changes (or we increase the maximum identifier length), it is not able to generate a fully correct set of new identifiers. Instead it will re-normalize already normalized identifiers. We are currently working to store full identifier lineage - source identifiers will be stored and mapped to the destination in the schema. +`dlt` does not store the source identifiers in the schema so when the naming convention changes (or we increase the maximum identifier length), it is not able to generate a fully correct set of new identifiers. Instead, it will re-normalize already normalized identifiers. We are currently working to store the full identifier lineage - source identifiers will be stored and mapped to the destination in the schema. ## Pick your own naming convention ### Configure naming convention -You can use `config.toml`, environment variables or any other configuration provider to set the naming convention name. Configured naming convention **overrides all other settings** +You can use `config.toml`, environment variables, or any other configuration provider to set the naming convention name. Configured naming convention **overrides all other settings** - changes the naming convention stored in the already created schema - overrides the destination capabilities preference. ```toml [schema] naming="sql_ci_v1" ``` -Configuration above will request **sql_ci_v1** for all pipelines (schemas). An environment variable `SCHEMA__NAMING` set to `sql_ci_v1` has the same effect. +The configuration above will request **sql_ci_v1** for all pipelines (schemas). An environment variable `SCHEMA__NAMING` set to `sql_ci_v1` has the same effect. -You have an option to set naming convention per source: +You have an option to set the naming convention per source: ```toml [sources.zendesk] config="prop" @@ -97,35 +93,57 @@ naming="sql_cs_v1" [sources.zendesk.credentials] password="pass" ``` -Snippet above demonstrates how to apply certain naming for an example `zendesk` source. +The snippet above demonstrates how to apply certain naming for an example `zendesk` source. -You can use naming conventions that you created yourself or got from other users. In that case you should pass a full Python import path to the [module that contain the naming convention](#write-your-own-naming-convention): +You can use naming conventions that you created yourself or got from other users. In that case, you should pass a full Python import path to the [module that contains the naming convention](#write-your-own-naming-convention): ```toml [schema] naming="tests.common.cases.normalizers.sql_upper" ``` -`dlt` will import `tests.common.cases.normalizers.sql_upper` and use `NamingConvention` class found in it as the naming convention. +`dlt` will import `tests.common.cases.normalizers.sql_upper` and use the `NamingConvention` class found in it as the naming convention. ### Available naming conventions You can pick from a few built-in naming conventions. * `snake_case` - the default -* `duck_case` - case sensitive, allows all unicode characters like emoji 💥 -* `direct` - case sensitive, allows all unicode characters, does not contract underscores -* `sql_cs_v1` - case sensitive, generates sql-safe identifiers -* `sql_ci_v1` - case insensitive, generates sql-safe lower case identifiers +* `duck_case` - case-sensitive, allows all Unicode characters like emoji 💥 +* `direct` - case-sensitive, allows all Unicode characters, does not contract underscores +* `sql_cs_v1` - case-sensitive, generates SQL-safe identifiers +* `sql_ci_v1` - case-insensitive, generates SQL-safe lowercase identifiers -### Set and adjust naming convention explicitly -You can modify destination capabilities to + +### Ignore naming convention for `dataset_name` +You control the dataset naming normalization separately. Set `enable_dataset_name_normalization` to `false` to ignore the naming convention for `dataset_name`: + +```toml +[destination.snowflake] +enable_dataset_name_normalization=false +``` + +In that case, the `dataset_name` would be preserved the same as it was set in the pipeline: +```py +import dlt + +pipeline = dlt.pipeline(dataset_name="MyCamelCaseName") +``` + +The default value for the `enable_dataset_name_normalization` configuration option is `true`. +:::note +The same setting would be applied to [staging dataset](../dlt-ecosystem/staging#staging-dataset). Thus, if you set `enable_dataset_name_normalization` to `false`, the staging dataset name would also **not** be normalized. +::: + +:::caution +Depending on the destination, certain names may not be allowed. To ensure your dataset can be successfully created, use the default normalization option. +::: ## Avoid identifier collisions `dlt` detects various types of identifier collisions and ignores the others. -1. `dlt` detects collisions if case sensitive naming convention is used on case insensitive destination -2. `dlt` detects collisions if change of naming convention changes the identifiers of tables already created in the destination -3. `dlt` detects collisions when naming convention is applied to column names of arrow tables +1. `dlt` detects collisions if a case-sensitive naming convention is used on a case-insensitive destination +2. `dlt` detects collisions if a change of naming convention changes the identifiers of tables already created in the destination +3. `dlt` detects collisions when the naming convention is applied to column names of arrow tables -`dlt` will not detect collision when normalizing source data. If you have a dictionary, keys will be merged if they collide after being normalized. +`dlt` will not detect a collision when normalizing source data. If you have a dictionary, keys will be merged if they collide after being normalized. You can create a custom naming convention that does not generate collisions on data, see examples below. @@ -134,14 +152,14 @@ Custom naming conventions are classes that derive from `NamingConvention` that y 1. Each naming convention resides in a separate Python module (file) 2. The class is always named `NamingConvention` -In that case you can use a fully qualified module name in [schema configuration](#configure-naming-convention) or pass module [explicitly](#set-and-adjust-naming-convention-explicitly). +In that case, you can use a fully qualified module name in [schema configuration](#configure-naming-convention) or pass the module [explicitly](#set-and-adjust-naming-convention-explicitly). We include [two examples](../examples/custom_naming) of naming conventions that you may find useful: -1. A variant of `sql_ci` that generates identifier collisions with a low (user defined) probability by appending a deterministic tag to each name. -2. A variant of `sql_cs` that allows for LATIN (ie. umlaut) characters +1. A variant of `sql_ci` that generates identifier collisions with a low (user-defined) probability by appending a deterministic tag to each name. +2. A variant of `sql_cs` that allows for LATIN (i.e. umlaut) characters :::note -Note that a fully qualified name of your custom naming convention will be stored in the `Schema` and `dlt` will attempt to import it when schema is loaded from storage. +Note that a fully qualified name of your custom naming convention will be stored in the `Schema` and `dlt` will attempt to import it when the schema is loaded from storage. You should distribute your custom naming conventions with your pipeline code or via a pip package from which it can be imported. -::: +::: \ No newline at end of file diff --git a/docs/website/docs/general-usage/schema.md b/docs/website/docs/general-usage/schema.md index 0e3e3bba1f..df405de1af 100644 --- a/docs/website/docs/general-usage/schema.md +++ b/docs/website/docs/general-usage/schema.md @@ -352,8 +352,30 @@ load_info = pipeline.run(source_data) ``` This example iterates through MongoDB collections, applying the complex [data type](schema#data-types) to a specified column, and then processes the data with `pipeline.run`. -## Export and import schema files +## View and print the schema +To view and print the default schema in a clear YAML format use the command: + +```py +pipeline.default_schema.to_pretty_yaml() +``` +This can be used in a pipeline as: +```py +# Create a pipeline +pipeline = dlt.pipeline( + pipeline_name="chess_pipeline", + destination='duckdb', + dataset_name="games_data") + +# Run the pipeline +load_info = pipeline.run(source) + +# Print the default schema in a pretty YAML format +print(pipeline.default_schema.to_pretty_yaml()) +``` +This will display a structured YAML representation of your schema, showing details like tables, columns, data types, and metadata, including version, version_hash, and engine_version. + +## Export and import schema files Please follow the guide on [how to adjust a schema](../walkthroughs/adjust-a-schema.md) to export and import `yaml` schema files in your pipeline. diff --git a/docs/website/docs/running-in-production/running.md b/docs/website/docs/running-in-production/running.md index 377cf57f2c..3b5762612c 100644 --- a/docs/website/docs/running-in-production/running.md +++ b/docs/website/docs/running-in-production/running.md @@ -119,7 +119,7 @@ truncate_staging_dataset=true ## Using slack to send messages `dlt` provides basic support for sending slack messages. You can configure Slack incoming hook via -[secrets.toml or environment variables](../general-usage/credentials/config_providers). Please note that **Slack +[secrets.toml or environment variables](../general-usage/credentials/setup). Please note that **Slack incoming hook is considered a secret and will be immediately blocked when pushed to github repository**. In `secrets.toml`: diff --git a/docs/website/docs/walkthroughs/add-incremental-configuration.md b/docs/website/docs/walkthroughs/add-incremental-configuration.md new file mode 100644 index 0000000000..ab7142695f --- /dev/null +++ b/docs/website/docs/walkthroughs/add-incremental-configuration.md @@ -0,0 +1,310 @@ +--- +title: Add incremental configuration to SQL resources +description: Incremental SQL data loading strategies +keywords: [how to, load data incrementally from SQL] +slug: sql-incremental-configuration +--- + +# Add incremental configuration to SQL resources +Incremental loading is the act of loading only new or changed data and not old records that have already been loaded. +For example, a bank loading only the latest transactions or a company updating its database with new or modified user +information. In this article, we’ll discuss a few incremental loading strategies. + +:::important +Processing data incrementally, or in batches, enhances efficiency, reduces costs, lowers latency, improves scalability, + and optimizes resource utilization. +::: + +### Incremental loading strategies + +In this guide, we will discuss various incremental loading methods using `dlt`, specifically: + +| S.No. | Strategy | Description | +| --- | --- | --- | +| 1. | Full load (replace) | It completely overwrites the existing data with the new/updated dataset. | +| 2. | Append new records based on Incremental ID | Appends only new records to the table based on an incremental ID. | +| 3. | Append new records based on date ("created_at") | Appends only new records to the table based on a date field. | +| 4. | Merge (Update/Insert) records based on timestamp ("last_modified_at") and ID | Merges records based on a composite ID key and a timestamp field. Updates existing records and inserts new ones as necessary. | + +## Code examples + +### 1. Full load (replace) + +A full load strategy completely overwrites the existing data with the new dataset. This is useful when you want to +refresh the entire table with the latest data. + +:::note +This strategy technically does not load only new data but instead reloads all data: old and new. +::: + +Here’s a walkthrough: + +1. The initial table, named "contact", in the SQL source looks like this: + + | id | name | created_at | + | --- | --- | --- | + | 1 | Alice | 2024-07-01 | + | 2 | Bob | 2024-07-02 | + +2. The python code illustrates the process of loading data from an SQL source into BigQuery using the `dlt` pipeline. +Please note the `write_disposition = "replace”` used below. + + ```py + def load_full_table_resource() -> None: + """Load a full table, replacing existing data.""" + pipeline = dlt.pipeline( + pipeline_name="mysql_database", + destination='bigquery', + dataset_name="dlt_contacts" + ) + + # Load the full table "contact" + source = sql_database().with_resources("contact") + + # Run the pipeline + info = pipeline.run(source, write_disposition="replace") + + # Print the info + print(info) + + load_full_table_resource() + ``` + +3. After running the `dlt` pipeline, the data loaded into the BigQuery "contact" table looks like: + + | Row | id | name | created_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 1 | Alice | 2024-07-01 | 1721878309.021546 | tgyMM73iMz0cQg | + | 2 | 2 | Bob | 2024-07-02 | 1721878309.021546 | 88P0bD796pXo/Q | + +4. Next, the "contact" table in the SQL source is updated—two new rows are added, and the row with `id = 2` is removed. +The updated data source ("contact" table) now presents itself as follows: + + | id | name | created_at | + | --- | --- | --- | + | 1 | Alice | 2024-07-01 | + | 3 | Charlie | 2024-07-03 | + | 4 | Dave | 2024-07-04 | + +5. The "contact" table created in BigQuery after running the pipeline again: + + | Row | id | name | created_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 1 | Alice | 2024-07-01 | 1721878309.021546 | S5ye6fMhYECZA | + | 2 | 3 | Charlie | 2024-07-03 | 1721878309.021546 | eT0zheRx9ONWuQ | + | 3 | 4 | Dave | 2024-07-04 | 1721878309.021546 | gtflF8BdL2NO/Q | + +**What happened?** + +After running the pipeline, the original data in the "contact" table (Alice and Bob) is completely replaced with the new +updated table with data “Charlie” and “Dave” added and “Bob” removed. This strategy is useful for scenarios where the entire +dataset needs to be refreshed/replaced with the latest information. + +### 2. Append new records based on incremental ID + +This strategy appends only new records to the table based on an incremental ID. It is useful for scenarios where each new record has a unique, incrementing identifier. + +Here’s a walkthrough: + +1. The initial table, named "contact", in the SQL source looks like this: + + | id | name | created_at | + | --- | --- | --- | + | 1 | Alice | 2024-07-01 | + | 2 | Bob | 2024-07-02 | + +2. The python code demonstrates loading data from an SQL source into BigQuery using an incremental variable, `id`. +This variable tracks new or updated records in the `dlt` pipeline. Please note the `write_disposition = "append”` +used below. + + ```py + def load_incremental_id_table_resource() -> None: + """Load a table incrementally based on an ID.""" + pipeline = dlt.pipeline( + pipeline_name="mysql_database", + destination='bigquery', + dataset_name="dlt_contacts", + ) + + # Load table "contact" incrementally based on ID + source = sql_database().with_resources("contact") + source.contact.apply_hints(incremental=dlt.sources.incremental("id")) + + # Run the pipeline with append write disposition + info = pipeline.run(source, write_disposition="append") + + # Print the info + print(info) + ``` + +3. After running the `dlt` pipeline, the data loaded into BigQuery "contact" table looks like: + + | Row | id | name | created_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 1 | Alice | 2024-07-01 | 1721878309.021546 | YQfmAu8xysqWmA | + | 2 | 2 | Bob | 2024-07-02 | 1721878309.021546 | Vcb5KKah/RpmQw | + +4. Next, the "contact" table in the SQL source is updated—two new rows are added, and the row with `id = 2` is removed. +The updated data source now presents itself as follows: + + | id | name | created_at | + | --- | --- | --- | + | 1 | Alice | 2024-07-01 | + | 3 | Charlie | 2024-07-03 | + | 4 | Dave | 2024-07-04 | + +5. The "contact" table created in BigQuery after running the pipeline again: + + | Row | id | name | created_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 1 | Alice | 2024-07-01 | 1721878309.021546 | OW9ZyAzkXg4D4w | + | 2 | 2 | Bob | 2024-07-02 | 1721878309.021546 | skVYZ/ppQuztUg | + | 3 | 3 | Charlie | 2024-07-03 | 1721878309.021546 | y+T4Q2JDnR33jg | + | 4 | 4 | Dave | 2024-07-04 | 1721878309.021546 | MAXrGhNNADXAiQ | + +**What happened?** + +In this scenario, the pipeline appends new records (Charlie and Dave) to the existing data (Alice and Bob) without affecting +the pre-existing entries. This strategy is ideal when only new data needs to be added, preserving the historical data. + +### 3. Append new records based on timestamp ("created_at") + +This strategy appends only new records to the table based on a date/timestamp field. It is useful for scenarios where records +are created with a timestamp, and you want to load only those records created after a certain date. + +Here’s a walkthrough: + +1. The initial dataset, named "contact", in the SQL source looks like this: + + | id | name | created_at | + | --- | --- | --- | + | 1 | Alice | 2024-07-01 00:00:00 | + | 2 | Bob | 2024-07-02 00:00:00 | + +2. The python code illustrates the process of loading data from an SQL source into BigQuery using the `dlt` pipeline. Please +note the `write_disposition = "append"`, with `created_at` being used as the incremental parameter. + + ```py + def load_incremental_timestamp_table_resource() -> None: + """Load a table incrementally based on created_at timestamp.""" + pipeline = dlt.pipeline( + pipeline_name="mysql_databasecdc", + destination='bigquery', + dataset_name="dlt_contacts", + ) + + # Load table "contact", incrementally starting at a given timestamp + source = sql_database().with_resources("contact") + source.contact.apply_hints(incremental=dlt.sources.incremental( + "created_at", initial_value=datetime(2024, 4, 1, 0, 0, 0))) + + # Run the pipeline + info = pipeline.run(source, write_disposition="append") + + # Print the info + print(info) + + load_incremental_timestamp_table_resource() + ``` + +3. After running the `dlt` pipeline, the data loaded into BigQuery "contact" table looks like: + + | Row | id | name | created_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 1 | Alice | 2024-07-01 00:00:00 UTC | 1721878309.021546 | 5H8ca6C89umxHA | + | 2 | 2 | Bob | 2024-07-02 00:00:00 UTC | 1721878309.021546 | M61j4aOSqs4k2w | + +4. Next, the "contact" table in the SQL source is updated—two new rows are added, and the row with `id = 2` is removed. +The updated data source now presents itself as follows: + + | id | name | created_at | + | --- | --- | --- | + | 1 | Alice | 2024-07-01 00:00:00 | + | 3 | Charlie | 2024-07-03 00:00:00 | + | 4 | Dave | 2024-07-04 00:00:00 | + +5. The "contact" table created in BigQuery after running the pipeline again: + + | Row | id | name | created_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 1 | Alice | 2024-07-01 00:00:00 UTC | 1721878309.021546 | Petj6R+B/63sWA | + | 2 | 2 | Bob | 2024-07-02 00:00:00 UTC | 1721878309.021546 | 3Rr3VmY+av+Amw | + | 3 | 3 | Charlie | 2024-07-03 00:00:00 UTC | 1721878309.021546 | L/MnhG19xeMrvQ | + | 4 | 4 | Dave | 2024-07-04 00:00:00 UTC | 1721878309.021546 | W6ZdfvTzfRXlsA | + +**What happened?** + +The pipeline adds new records (Charlie and Dave) that have a `created_at` timestamp after the specified initial value while +retaining the existing data (Alice and Bob). This approach is useful for loading data incrementally based on when it was created. + +### 4. Merge (Update/Insert) records based on timestamp ("last_modified_at") and ID + +This strategy merges records based on a composite key of ID and a timestamp field. It updates existing records and inserts +new ones as necessary. + +Here’s a walkthrough: + +1. The initial dataset, named ‘contact’, in the SQL source looks like this: + + | id | name | last_modified_at | + | --- | --- | --- | + | 1 | Alice | 2024-07-01 00:00:00 | + | 2 | Bob | 2024-07-02 00:00:00 | + +2. The Python code illustrates the process of loading data from an SQL source into BigQuery using the `dlt` pipeline. Please +note the `write_disposition = "merge"`, with `last_modified_at` being used as the incremental parameter. + + ```py + def load_merge_table_resource() -> None: + """Merge (update/insert) records based on last_modified_at timestamp and ID.""" + pipeline = dlt.pipeline( + pipeline_name="mysql_database", + destination='bigquery', + dataset_name="dlt_contacts", + ) + + # Merge records, 'contact' table, based on ID and last_modified_at timestamp + source = sql_database().with_resources("contact") + source.contact.apply_hints(incremental=dlt.sources.incremental( + "last_modified_at", initial_value=datetime(2024, 4, 1, 0, 0, 0)), + primary_key="id") + + # Run the pipeline + info = pipeline.run(source, write_disposition="merge") + + # Print the info + print(info) + + load_merge_table_resource() + ``` + +3. After running the `dlt` pipeline, the data loaded into BigQuery ‘contact’ table looks like: + + | Row | id | name | last_modified_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 1 | Alice | 2024-07-01 00:00:00 UTC | 1721878309.021546 | ObbVlxcly3VknQ | + | 2 | 2 | Bob | 2024-07-02 00:00:00 UTC | 1721878309.021546 | Vrlkus/haaKlEg | + +4. Next, the "contact" table in the SQL source is updated— “Alice” is updated to “Alice Updated”, and a new row “Hank” is added: + + | id | name | last_modified_at | + | --- | --- | --- | + | 1 | Alice Updated | 2024-07-08 00:00:00 | + | 3 | Hank | 2024-07-08 00:00:00 | + +5. The "contact" table created in BigQuery after running the pipeline again: + + | Row | id | name | last_modified_at | _dlt_load_id | _dlt_id | + | --- | --- | --- | --- | --- | --- | + | 1 | 2 | Bob | 2024-07-02 00:00:00 UTC | 1721878309.021546 | Cm+AcDZLqXSDHQ | + | 2 | 1 | Alice Updated | 2024-07-08 00:00:00 UTC | 1721878309.021546 | OeMLIPw7rwFG7g | + | 3 | 3 | Hank | 2024-07-08 00:00:00 UTC | 1721878309.021546 | Ttp6AI2JxqffpA | + +**What happened?** + +The pipeline updates the record for Alice with the new data, including the updated `last_modified_at` timestamp, and adds a +new record for Hank. This method is beneficial when you need to ensure that records are both updated and inserted based on a +specific timestamp and ID. + +The examples provided explain how to use `dlt` to achieve different incremental loading scenarios, highlighting the changes +before and after running each pipeline. \ No newline at end of file diff --git a/docs/website/docs/walkthroughs/add_credentials.md b/docs/website/docs/walkthroughs/add_credentials.md index 5b4f241d56..bc0fb3b409 100644 --- a/docs/website/docs/walkthroughs/add_credentials.md +++ b/docs/website/docs/walkthroughs/add_credentials.md @@ -41,12 +41,12 @@ Read more about [credential configuration.](../general-usage/credentials) To add credentials to your deployment, - either use one of the `dlt deploy` commands; -- or follow the instructions to [pass credentials via code](../general-usage/credentials/configuration#pass-credentials-as-code) -or [environment](../general-usage/credentials/config_providers#environment-provider). +- or follow the instructions to [pass credentials via code](../general-usage/credentials/advanced#examples) +or [environment](../general-usage/credentials/setup#environment-variables). ### Reading credentials from environment variables -`dlt` supports reading credentials from environment. For example, our `.dlt/secrets.toml` might look like: +`dlt` supports reading credentials from the environment. For example, our `.dlt/secrets.toml` might look like: ```toml [sources.pipedrive] @@ -63,7 +63,7 @@ client_email = "client_email" # please set me up! If dlt tries to read this from environment variables, it will use a different naming convention. -For environment variables all names are capitalized and sections are separated with a double underscore "__". +For environment variables, all names are capitalized and sections are separated with a double underscore "__". For example, for the secrets mentioned above, we would need to set them in the environment: diff --git a/docs/website/docs/walkthroughs/create-a-pipeline.md b/docs/website/docs/walkthroughs/create-a-pipeline.md index cbbbd73fc3..d463921319 100644 --- a/docs/website/docs/walkthroughs/create-a-pipeline.md +++ b/docs/website/docs/walkthroughs/create-a-pipeline.md @@ -8,7 +8,7 @@ keywords: [how to, create a pipeline, rest client] This guide walks you through creating a pipeline that uses our [REST API Client](../general-usage/http/rest-client) to connect to [DuckDB](../dlt-ecosystem/destinations/duckdb). -:::tip +:::tip We're using DuckDB as a destination here, but you can adapt the steps to any [source](https://dlthub.com/docs/dlt-ecosystem/verified-sources/) and [destination](https://dlthub.com/docs/dlt-ecosystem/destinations/) by using the [command](../reference/command-line-interface#dlt-init) `dlt init ` and tweaking the pipeline accordingly. ::: @@ -63,7 +63,7 @@ api_secret_key = '' This token will be used by `github_api_source()` to authenticate requests. The **secret name** corresponds to the **argument name** in the source function. -Below `api_secret_key` [will get its value](../general-usage/credentials/configuration#allow-dlt-to-pass-the-config-and-secrets-automatically) +Below `api_secret_key` [will get its value](../general-usage/credentials/advanced) from `secrets.toml` when `github_api_source()` is called. ```py diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index 465212cae6..921c3c0dc4 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -136,6 +136,19 @@ const sidebars = { 'reference/explainers/how-dlt-works', 'general-usage/resource', 'general-usage/source', + { + type: 'category', + label: 'Configuration and secrets', + link: { + type: 'doc', + id: 'general-usage/credentials/index', + }, + items: [ + 'general-usage/credentials/setup', + 'general-usage/credentials/advanced', + 'general-usage/credentials/complex_types', + ] + }, 'general-usage/pipeline', 'general-usage/destination', 'general-usage/destination-tables', @@ -159,20 +172,6 @@ const sidebars = { 'general-usage/naming-convention', 'general-usage/schema-contracts', 'general-usage/schema-evolution', - { - type: 'category', - label: 'Configuration', - link: { - type: 'generated-index', - title: 'Configuration', - slug: 'general-usage/credentials', - }, - items: [ - 'general-usage/credentials/configuration', - 'general-usage/credentials/config_providers', - 'general-usage/credentials/config_specs', - ] - }, 'build-a-pipeline-tutorial', 'reference/performance', { @@ -207,6 +206,7 @@ const sidebars = { items: [ 'walkthroughs/create-a-pipeline', 'walkthroughs/add-a-verified-source', + 'walkthroughs/add-incremental-configuration', 'walkthroughs/add_credentials', 'walkthroughs/run-a-pipeline', 'walkthroughs/adjust-a-schema', diff --git a/pyproject.toml b/pyproject.toml index 45f6297b9c..f33bbbefcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.5.2" +version = "0.5.4a0" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] diff --git a/tests/cases.py b/tests/cases.py index aa2e8ed494..54a8126754 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -178,7 +178,7 @@ def table_update_and_row( Optionally exclude some data types from the schema and row. """ column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) - data_row = deepcopy(TABLE_ROW_ALL_DATA_TYPES) + data_row = deepcopy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) exclude_col_names = list(exclude_columns or []) if exclude_types: exclude_col_names.extend( @@ -203,7 +203,7 @@ def assert_all_data_types_row( # content must equal # print(db_row) schema = schema or TABLE_UPDATE_COLUMNS_SCHEMA - expected_row = expected_row or TABLE_ROW_ALL_DATA_TYPES + expected_row = expected_row or TABLE_ROW_ALL_DATA_TYPES_DATETIMES # Include only columns requested in schema if isinstance(db_row, dict): @@ -274,8 +274,8 @@ def assert_all_data_types_row( # then it must be json db_mapping["col9"] = json.loads(db_mapping["col9"]) - if "col10" in db_mapping: - db_mapping["col10"] = db_mapping["col10"].isoformat() + # if "col10" in db_mapping: + # db_mapping["col10"] = db_mapping["col10"].isoformat() if "col11" in db_mapping: db_mapping["col11"] = ensure_pendulum_time(db_mapping["col11"]).isoformat() diff --git a/tests/cli/test_pipeline_command.py b/tests/cli/test_pipeline_command.py index 1f8e2ff4f3..82d74299f8 100644 --- a/tests/cli/test_pipeline_command.py +++ b/tests/cli/test_pipeline_command.py @@ -196,6 +196,7 @@ def test_pipeline_command_failed_jobs(repo_dir: str, project_files: FileStorage) def test_pipeline_command_drop_partial_loads(repo_dir: str, project_files: FileStorage) -> None: init_command.init_command("chess", "dummy", False, repo_dir) + os.environ["EXCEPTION_PROB"] = "1.0" try: pipeline = dlt.attach(pipeline_name="chess_pipeline") @@ -203,14 +204,22 @@ def test_pipeline_command_drop_partial_loads(repo_dir: str, project_files: FileS except Exception as e: print(e) - # now run the pipeline - os.environ["EXCEPTION_PROB"] = "1.0" - os.environ["FAIL_IN_INIT"] = "False" - os.environ["TIMEOUT"] = "1.0" venv = Venv.restore_current() with pytest.raises(CalledProcessError) as cpe: print(venv.run_script("chess_pipeline.py")) - assert "Dummy job status raised exception" in cpe.value.stdout + assert "PipelineStepFailed" in cpe.value.stdout + + # complete job manually to make a partial load + pipeline = dlt.attach(pipeline_name="chess_pipeline") + load_storage = pipeline._get_load_storage() + load_id = load_storage.normalized_packages.list_packages()[0] + job = load_storage.normalized_packages.list_new_jobs(load_id)[0] + load_storage.normalized_packages.start_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) + load_storage.normalized_packages.complete_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) with io.StringIO() as buf, contextlib.redirect_stdout(buf): pipeline_command.pipeline_command("info", "chess_pipeline", None, 1) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index 45bc8d157e..8c4d5a439b 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -21,34 +21,69 @@ clear_destination_state, ) -from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.common.storages.utils import ( + start_loading_file, + assert_package_info, + load_storage, + start_loading_files, +) from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage def test_is_partially_loaded(load_storage: LoadStorage) -> None: - load_id, file_name = start_loading_file( - load_storage, [{"content": "a"}, {"content": "b"}], start_job=False + load_id, file_names = start_loading_files( + load_storage, [{"content": "a"}, {"content": "b"}], start_job=False, file_count=2 ) info = load_storage.get_load_package_info(load_id) # all jobs are new assert PackageStorage.is_package_partially_loaded(info) is False - # start job - load_storage.normalized_packages.start_job(load_id, file_name) + # start one job + load_storage.normalized_packages.start_job(load_id, file_names[0]) info = load_storage.get_load_package_info(load_id) - assert PackageStorage.is_package_partially_loaded(info) is True + assert PackageStorage.is_package_partially_loaded(info) is False # complete job - load_storage.normalized_packages.complete_job(load_id, file_name) + load_storage.normalized_packages.complete_job(load_id, file_names[0]) + info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is True + # start second job + load_storage.normalized_packages.start_job(load_id, file_names[1]) info = load_storage.get_load_package_info(load_id) assert PackageStorage.is_package_partially_loaded(info) is True + # finish second job, now not partial anymore + load_storage.normalized_packages.complete_job(load_id, file_names[1]) + info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is False + # must complete package load_storage.complete_load_package(load_id, False) info = load_storage.get_load_package_info(load_id) assert PackageStorage.is_package_partially_loaded(info) is False - # abort package + # abort package (will never be partially loaded) load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) load_storage.complete_load_package(load_id, True) info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is False + + # abort partially loaded will stay partially loaded + load_id, file_names = start_loading_files( + load_storage, [{"content": "a"}, {"content": "b"}], start_job=False, file_count=2 + ) + load_storage.normalized_packages.start_job(load_id, file_names[0]) + load_storage.normalized_packages.complete_job(load_id, file_names[0]) + load_storage.complete_load_package(load_id, True) + info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is True + + # failed jobs will also result in partial loads, if one job is completed + load_id, file_names = start_loading_files( + load_storage, [{"content": "a"}, {"content": "b"}], start_job=False, file_count=2 + ) + load_storage.normalized_packages.start_job(load_id, file_names[0]) + load_storage.normalized_packages.complete_job(load_id, file_names[0]) + load_storage.normalized_packages.start_job(load_id, file_names[1]) + load_storage.normalized_packages.fail_job(load_id, file_names[1], "much broken, so bad") + info = load_storage.get_load_package_info(load_id) assert PackageStorage.is_package_partially_loaded(info) is True diff --git a/tests/common/storages/test_load_storage.py b/tests/common/storages/test_load_storage.py index 49deaff23e..bdcec4ceb2 100644 --- a/tests/common/storages/test_load_storage.py +++ b/tests/common/storages/test_load_storage.py @@ -8,7 +8,12 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.load_package import create_load_id -from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.common.storages.utils import ( + start_loading_file, + assert_package_info, + load_storage, + start_loading_files, +) from tests.utils import write_version, autouse_test_storage diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index 1b5a68948b..baac3b7af5 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -157,25 +157,38 @@ def write_temp_job_file( return Path(file_name).name -def start_loading_file( - s: LoadStorage, content: Sequence[StrAny], start_job: bool = True -) -> Tuple[str, str]: +def start_loading_files( + s: LoadStorage, content: Sequence[StrAny], start_job: bool = True, file_count: int = 1 +) -> Tuple[str, List[str]]: load_id = uniq_id() s.new_packages.create_package(load_id) # write test file - item_storage = s.create_item_storage(DataWriter.writer_spec_from_file_format("jsonl", "object")) - file_name = write_temp_job_file( - item_storage, s.storage, load_id, "mock_table", None, uniq_id(), content - ) + file_names: List[str] = [] + for _ in range(0, file_count): + item_storage = s.create_item_storage( + DataWriter.writer_spec_from_file_format("jsonl", "object") + ) + file_name = write_temp_job_file( + item_storage, s.storage, load_id, "mock_table", None, uniq_id(), content + ) + file_names.append(file_name) # write schema and schema update s.new_packages.save_schema(load_id, Schema("mock")) s.new_packages.save_schema_updates(load_id, {}) s.commit_new_load_package(load_id) - assert_package_info(s, load_id, "normalized", "new_jobs") + assert_package_info(s, load_id, "normalized", "new_jobs", jobs_count=file_count) if start_job: - s.normalized_packages.start_job(load_id, file_name) - assert_package_info(s, load_id, "normalized", "started_jobs") - return load_id, file_name + for file_name in file_names: + s.normalized_packages.start_job(load_id, file_name) + assert_package_info(s, load_id, "normalized", "started_jobs") + return load_id, file_names + + +def start_loading_file( + s: LoadStorage, content: Sequence[StrAny], start_job: bool = True +) -> Tuple[str, str]: + load_id, file_names = start_loading_files(s, content, start_job) + return load_id, file_names[0] def assert_package_info( diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index e6e2ecad2c..f04820bf36 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -351,6 +351,22 @@ def test_normalize_dataset_name() -> None: == "set_barba_papa" ) + # test dataset_name_normalization false + assert ( + DestinationClientDwhConfiguration(enable_dataset_name_normalization=False) + ._bind_dataset_name(dataset_name="BarbaPapa__Ba", default_schema_name="default") + .normalize_dataset_name(Schema("default")) + == "BarbaPapa__Ba" + ) + + # test dataset_name_normalization default is true + assert ( + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="BarbaPapa__Ba", default_schema_name="default") + .normalize_dataset_name(Schema("default")) + == "barba_papa_ba" + ) + def test_normalize_staging_dataset_name() -> None: # default normalized staging dataset @@ -388,6 +404,24 @@ def test_normalize_staging_dataset_name() -> None: == "static_staging" ) + # test dataset_name_normalization false + assert ( + DestinationClientDwhConfiguration( + enable_dataset_name_normalization=False, staging_dataset_name_layout="%s__Staging" + ) + ._bind_dataset_name(dataset_name="BarbaPapa__Ba", default_schema_name="default") + .normalize_staging_dataset_name(Schema("default")) + == "BarbaPapa__Ba__Staging" + ) + + # test dataset_name_normalization default is true + assert ( + DestinationClientDwhConfiguration(staging_dataset_name_layout="%s__Staging") + ._bind_dataset_name(dataset_name="BarbaPapa__Ba", default_schema_name="default") + .normalize_staging_dataset_name(Schema("default")) + == "barba_papa_ba_staging" + ) + def test_normalize_dataset_name_none_default_schema() -> None: # if default schema is None, suffix is not added diff --git a/tests/common/test_json.py b/tests/common/test_json.py index b7d25589a7..97435b43a8 100644 --- a/tests/common/test_json.py +++ b/tests/common/test_json.py @@ -217,6 +217,15 @@ def test_json_pendulum(json_impl: SupportsJson) -> None: assert s_r == pendulum.parse(dt_str_z) +# @pytest.mark.parametrize("json_impl", _JSON_IMPL) +# def test_json_timedelta(json_impl: SupportsJson) -> None: +# from datetime import timedelta +# start_date = pendulum.parse("2005-04-02T20:37:37.358236Z") +# delta = pendulum.interval(start_date, pendulum.now()) +# assert isinstance(delta, timedelta) +# print(str(delta.as_timedelta())) + + @pytest.mark.parametrize("json_impl", _JSON_IMPL) def test_json_named_tuple(json_impl: SupportsJson) -> None: assert ( diff --git a/tests/libs/test_deltalake.py b/tests/libs/test_deltalake.py index a162ff427b..3e2d7cc3f6 100644 --- a/tests/libs/test_deltalake.py +++ b/tests/libs/test_deltalake.py @@ -2,11 +2,11 @@ from typing import Iterator, Tuple, Union, cast import pytest -from deltalake import DeltaTable import dlt from dlt.common.libs.pyarrow import pyarrow as pa from dlt.common.libs.deltalake import ( + DeltaTable, write_delta_table, _deltalake_storage_options, ) diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index a74ab11860..c92f18e159 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -19,7 +19,7 @@ from dlt.common.schema.utils import new_table from dlt.common.storages import FileStorage from dlt.common.utils import digest128, uniq_id, custom_environ - +from dlt.common.destination.reference import RunnableLoadJob from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException @@ -32,6 +32,7 @@ prepare_table, yield_client_with_storage, cm_yield_client_with_storage, + cm_yield_client, ) # mark all tests as essential, do not remove @@ -53,6 +54,18 @@ def auto_delete_storage() -> None: delete_test_storage() +@pytest.fixture +def bigquery_project_id() -> Iterator[str]: + project_id = "different_project_id" + project_id_key = "DESTINATION__BIGQUERY__PROJECT_ID" + saved_project_id = os.environ.get(project_id_key) + os.environ[project_id_key] = project_id + yield project_id + del os.environ[project_id_key] + if saved_project_id: + os.environ[project_id_key] = saved_project_id + + def test_service_credentials_with_default(environment: Any) -> None: gcpc = GcpServiceAccountCredentials() # resolve will miss values and try to find default credentials on the machine @@ -247,6 +260,21 @@ def test_bigquery_configuration() -> None: ) +def test_bigquery_different_project_id(bigquery_project_id) -> None: + """Test scenario when bigquery project_id different from gcp credentials project_id.""" + config = resolve_configuration( + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), + ) + assert config.project_id == bigquery_project_id + with cm_yield_client( + "bigquery", + dataset_name="dataset", + default_config_values={"project_id": bigquery_project_id}, + ) as client: + assert bigquery_project_id in client.sql_client.catalog_name() + + def test_bigquery_autodetect_configuration(client: BigQueryClient) -> None: # no schema autodetect assert client._should_autodetect_schema("event_slot") is False @@ -268,30 +296,7 @@ def test_bigquery_autodetect_configuration(client: BigQueryClient) -> None: assert client._should_autodetect_schema("event_slot__values") is True -def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: - # non existing job - with pytest.raises(LoadJobNotExistsException): - client.restore_file_load(f"{uniq_id()}.") - - # bad name - with pytest.raises(LoadJobTerminalException): - client.restore_file_load("!!&*aaa") - - user_table_name = prepare_table(client) - - # start a job with non-existing file - with pytest.raises(FileNotFoundError): - client.start_file_load( - client.schema.get_table(user_table_name), - f"{uniq_id()}.", - uniq_id(), - ) - - # start a job with invalid name - dest_path = file_storage.save("!!aaaa", b"data") - with pytest.raises(LoadJobTerminalException): - client.start_file_load(client.schema.get_table(user_table_name), dest_path, uniq_id()) - +def test_bigquery_job_resuming(client: BigQueryClient, file_storage: FileStorage) -> None: user_table_name = prepare_table(client) load_json = { "_dlt_id": uniq_id(), @@ -300,14 +305,23 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) "timestamp": str(pendulum.now()), } job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) + assert job._created_job # type: ignore # start a job from the same file. it should be a fallback to retrieve a job silently - r_job = client.start_file_load( - client.schema.get_table(user_table_name), - file_storage.make_full_path(job.file_name()), - uniq_id(), + r_job = cast( + RunnableLoadJob, + client.create_load_job( + client.schema.get_table(user_table_name), + file_storage.make_full_path(job.file_name()), + uniq_id(), + ), ) + + # job will be automatically found and resumed + r_job.set_run_vars(uniq_id(), client.schema, client.schema.tables[user_table_name]) + r_job.run_managed(client) assert r_job.state() == "completed" + assert r_job._resumed_job # type: ignore @pytest.mark.parametrize("location", ["US", "EU"]) @@ -325,7 +339,7 @@ def test_bigquery_location(location: str, file_storage: FileStorage, client) -> job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) # start a job from the same file. it should be a fallback to retrieve a job silently - client.start_file_load( + client.create_load_job( client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name()), uniq_id(), diff --git a/tests/load/cases/loading/event_loop_interrupted.1234.0.jsonl b/tests/load/cases/loading/event_loop_interrupted.1234.0.jsonl new file mode 100644 index 0000000000..8baec57d5c --- /dev/null +++ b/tests/load/cases/loading/event_loop_interrupted.1234.0.jsonl @@ -0,0 +1 @@ +small file that is never read \ No newline at end of file diff --git a/tests/load/cases/loading/event_user.1234.0.jsonl b/tests/load/cases/loading/event_user.1234.0.jsonl new file mode 100644 index 0000000000..8baec57d5c --- /dev/null +++ b/tests/load/cases/loading/event_user.1234.0.jsonl @@ -0,0 +1 @@ +small file that is never read \ No newline at end of file diff --git a/tests/load/filesystem/utils.py b/tests/load/filesystem/utils.py index ce15997ed6..bb4153da5c 100644 --- a/tests/load/filesystem/utils.py +++ b/tests/load/filesystem/utils.py @@ -14,11 +14,11 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.destination.reference import LoadJob +from dlt.common.destination.reference import RunnableLoadJob from dlt.common.pendulum import timedelta, __utcnow from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.load import Load from tests.load.utils import prepare_load_package @@ -34,7 +34,7 @@ def setup_loader(dataset_name: str) -> Load: @contextmanager def perform_load( dataset_name: str, cases: Sequence[str], write_disposition: str = "append" -) -> Iterator[Tuple[FilesystemClient, List[LoadJob], str, str]]: +) -> Iterator[Tuple[FilesystemClient, List[RunnableLoadJob], str, str]]: load = setup_loader(dataset_name) load_id, schema = prepare_load_package(load.load_storage, cases, write_disposition) client: FilesystemClient = load.get_destination_client(schema) # type: ignore[assignment] @@ -54,13 +54,13 @@ def perform_load( try: jobs = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) # job execution failed - if isinstance(job, EmptyLoadJob): + if isinstance(job, FinalizedLoadJobWithFollowupJobs): raise RuntimeError(job.exception()) jobs.append(job) - yield client, jobs, root_path, load_id + yield client, jobs, root_path, load_id # type: ignore finally: try: client.drop_storage() diff --git a/tests/load/pipeline/test_duckdb.py b/tests/load/pipeline/test_duckdb.py index 3dcfffe348..2fa44d77c5 100644 --- a/tests/load/pipeline/test_duckdb.py +++ b/tests/load/pipeline/test_duckdb.py @@ -1,14 +1,23 @@ +from typing import ClassVar, Optional import pytest import os +from datetime import datetime # noqa: I251 +import dlt +from dlt.common import json +from dlt.common.libs.pydantic import DltConfig from dlt.common.schema.exceptions import SchemaIdentifierNormalizationCollision -from dlt.common.time import ensure_pendulum_datetime -from dlt.destinations.exceptions import DatabaseTerminalException +from dlt.common.time import ensure_pendulum_datetime, pendulum + +from dlt.destinations import duckdb from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import TABLE_UPDATE_ALL_INT_PRECISIONS, TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS from tests.load.utils import destinations_configs, DestinationTestConfiguration -from tests.pipeline.utils import airtable_emojis, load_table_counts +from tests.pipeline.utils import airtable_emojis, assert_data_table_counts, load_table_counts + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential @pytest.mark.parametrize( @@ -18,7 +27,6 @@ ) def test_duck_case_names(destination_config: DestinationTestConfiguration) -> None: # we want to have nice tables - # dlt.config["schema.naming"] = "duck_case" os.environ["SCHEMA__NAMING"] = "duck_case" pipeline = destination_config.setup_pipeline("test_duck_case_names") # create tables and columns with emojis and other special characters @@ -125,3 +133,132 @@ def test_duck_precision_types(destination_config: DestinationTestConfiguration) table_row.pop("_dlt_id") table_row.pop("_dlt_load_id") assert table_row == row[0] + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_new_nested_prop_parquet(destination_config: DestinationTestConfiguration) -> None: + from pydantic import BaseModel + + class EventDetail(BaseModel): + detail_id: str + is_complete: bool + + class EventV1(BaseModel): + dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True} + + ver: int + id: str # noqa + details: EventDetail + + duck_factory = duckdb("_storage/test_duck.db") + + pipeline = destination_config.setup_pipeline( + "test_new_nested_prop_parquet", dataset_name="test_dataset" + ) + pipeline.destination = duck_factory # type: ignore + + event = {"ver": 1, "id": "id1", "details": {"detail_id": "detail_1", "is_complete": False}} + + info = pipeline.run( + [event], + table_name="events", + columns=EventV1, + loader_file_format="parquet", + schema_contract="evolve", + ) + info.raise_on_failed_jobs() + print(pipeline.default_schema.to_pretty_yaml()) + + # we will use a different pipeline with a separate schema but writing to the same dataset and to the same table + # the table schema is identical to the previous one with a single field ("time") added + # this will create a different order of columns than in the destination database ("time" will map to "_dlt_id") + # duckdb copies columns by column index so that will fail + + class EventDetailV2(BaseModel): + detail_id: str + is_complete: bool + time: Optional[datetime] + + class EventV2(BaseModel): + dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True} + + ver: int + id: str # noqa + details: EventDetailV2 + + event["details"]["time"] = pendulum.now() # type: ignore + + pipeline = destination_config.setup_pipeline( + "test_new_nested_prop_parquet_2", dataset_name="test_dataset" + ) + pipeline.destination = duck_factory # type: ignore + info = pipeline.run( + [event], + table_name="events", + columns=EventV2, + loader_file_format="parquet", + schema_contract="evolve", + ) + info.raise_on_failed_jobs() + print(pipeline.default_schema.to_pretty_yaml()) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_jsonl_reader(destination_config: DestinationTestConfiguration) -> None: + pipeline = destination_config.setup_pipeline("test_jsonl_reader") + + data = [{"a": 1, "b": 2}, {"a": 1}] + info = pipeline.run(data, table_name="data", loader_file_format="jsonl") + info.raise_on_failed_jobs() + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_provoke_parallel_parquet_same_table( + destination_config: DestinationTestConfiguration, +) -> None: + @dlt.resource(name="events", file_format="parquet") + def _get_shuffled_events(repeat: int = 1): + for _ in range(repeat): + with open( + "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + ) as f: + issues = json.load(f) + yield issues + + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "200" + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "200" + + pipeline = destination_config.setup_pipeline("test_provoke_parallel_parquet_same_table") + + info = pipeline.run(_get_shuffled_events(50)) + info.raise_on_failed_jobs() + assert_data_table_counts( + pipeline, + expected_counts={ + "events": 5000, + "events__payload__pull_request__base__repo__topics": 14500, + "events__payload__commits": 3850, + "events__payload__pull_request__requested_reviewers": 1200, + "events__payload__pull_request__labels": 1300, + "events__payload__issue__labels": 150, + "events__payload__issue__assignees": 50, + }, + ) + metrics = pipeline.last_trace.last_normalize_info.metrics[ + pipeline.last_trace.last_normalize_info.loads_ids[0] + ][0] + event_files = [m for m in metrics["job_metrics"].keys() if m.startswith("events.")] + assert len(event_files) == 5000 // 200 + assert all(m.endswith("parquet") for m in event_files) diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 7ad571f2aa..759f443546 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -19,6 +19,7 @@ from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.load.exceptions import LoadClientJobRetry from tests.cases import arrow_table_all_data_types, table_update_and_row, assert_all_data_types_row from tests.common.utils import load_json_case @@ -29,6 +30,7 @@ DestinationTestConfiguration, MEMORY_BUCKET, FILE_BUCKET, + AZ_BUCKET, ) from tests.pipeline.utils import load_table_counts, assert_load_info, load_tables_to_dicts @@ -242,7 +244,11 @@ def foo(): with pytest.raises(PipelineStepFailed) as pip_ex: pipeline.run(foo()) - assert isinstance(pip_ex.value.__context__, DependencyVersionException) + assert isinstance(pip_ex.value.__context__, LoadClientJobRetry) + assert ( + "`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination" + in pip_ex.value.__context__.retry_message + ) @pytest.mark.essential @@ -264,7 +270,7 @@ def test_delta_table_core( Tests `append` and `replace` write dispositions (`merge` is tested elsewhere). """ - from tests.pipeline.utils import _get_delta_table + from dlt.common.libs.deltalake import get_delta_tables # create resource that yields rows with all data types column_schemas, row = table_update_and_row() @@ -303,12 +309,56 @@ def data_types(): # should do logical replace, increasing the table version info = pipeline.run(data_types(), write_disposition="replace") assert_load_info(info) - client = cast(FilesystemClient, pipeline.destination_client()) - assert _get_delta_table(client, "data_types").version() == 2 + assert get_delta_tables(pipeline, "data_types")["data_types"].version() == 2 rows = load_tables_to_dicts(pipeline, "data_types", exclude_system_cols=True)["data_types"] assert len(rows) == 10 +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) +def test_delta_table_does_not_contain_job_files( + destination_config: DestinationTestConfiguration, +) -> None: + """Asserts Parquet job files do not end up in Delta table.""" + + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + @dlt.resource(table_format="delta") + def delta_table(): + yield [{"foo": 1}] + + # create Delta table + info = pipeline.run(delta_table()) + assert_load_info(info) + + # get Parquet jobs + completed_jobs = info.load_packages[0].jobs["completed_jobs"] + parquet_jobs = [ + job + for job in completed_jobs + if job.job_file_info.table_name == "delta_table" and job.file_path.endswith(".parquet") + ] + assert len(parquet_jobs) == 1 + + # get Parquet files in Delta table folder + with pipeline.destination_client() as client: + assert isinstance(client, FilesystemClient) + table_dir = client.get_table_dir("delta_table") + parquet_files = [f for f in client.fs_client.ls(table_dir) if f.endswith(".parquet")] + assert len(parquet_files) == 1 + + # Parquet file should not be the job file + file_id = parquet_jobs[0].job_file_info.file_id + assert file_id not in parquet_files[0] + + @pytest.mark.parametrize( "destination_config", destinations_configs( @@ -326,7 +376,7 @@ def test_delta_table_multiple_files( Files should be loaded into the Delta table in a single commit. """ - from tests.pipeline.utils import _get_delta_table + from dlt.common.libs.deltalake import get_delta_tables os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "2" # force multiple files @@ -350,8 +400,7 @@ def delta_table(): assert len(delta_table_parquet_jobs) == 5 # 10 records, max 2 per file # all 10 records should have been loaded into a Delta table in a single commit - client = cast(FilesystemClient, pipeline.destination_client()) - assert _get_delta_table(client, "delta_table").version() == 0 + assert get_delta_tables(pipeline, "delta_table")["delta_table"].version() == 0 rows = load_tables_to_dicts(pipeline, "delta_table", exclude_system_cols=True)["delta_table"] assert len(rows) == 10 @@ -442,6 +491,91 @@ def complex_table(): ), ids=lambda x: x.name, ) +def test_delta_table_partitioning( + destination_config: DestinationTestConfiguration, +) -> None: + """Tests partitioning for `delta` table format.""" + + from dlt.common.libs.deltalake import get_delta_tables + from tests.pipeline.utils import users_materialize_table_schema + + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + # zero partition columns + @dlt.resource(table_format="delta") + def zero_part(): + yield {"foo": 1, "bar": 1} + + info = pipeline.run(zero_part()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "zero_part")["zero_part"] + assert dt.metadata().partition_columns == [] + assert load_table_counts(pipeline, "zero_part")["zero_part"] == 1 + + # one partition column + @dlt.resource(table_format="delta", columns={"c1": {"partition": True}}) + def one_part(): + yield [ + {"c1": "foo", "c2": 1}, + {"c1": "foo", "c2": 2}, + {"c1": "bar", "c2": 3}, + {"c1": "baz", "c2": 4}, + ] + + info = pipeline.run(one_part()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "one_part")["one_part"] + assert dt.metadata().partition_columns == ["c1"] + assert load_table_counts(pipeline, "one_part")["one_part"] == 4 + + # two partition columns + @dlt.resource( + table_format="delta", columns={"c1": {"partition": True}, "c2": {"partition": True}} + ) + def two_part(): + yield [ + {"c1": "foo", "c2": 1, "c3": True}, + {"c1": "foo", "c2": 2, "c3": True}, + {"c1": "bar", "c2": 1, "c3": True}, + {"c1": "baz", "c2": 1, "c3": True}, + ] + + info = pipeline.run(two_part()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "two_part")["two_part"] + assert dt.metadata().partition_columns == ["c1", "c2"] + assert load_table_counts(pipeline, "two_part")["two_part"] == 4 + + # test partitioning with empty source + users_materialize_table_schema.apply_hints( + table_format="delta", + columns={"id": {"partition": True}}, + ) + info = pipeline.run(users_materialize_table_schema()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "users")["users"] + assert dt.metadata().partition_columns == ["id"] + assert load_table_counts(pipeline, "users")["users"] == 0 + + # changing partitioning after initial table creation is not supported + zero_part.apply_hints(columns={"foo": {"partition": True}}) + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(zero_part()) + assert isinstance(pip_ex.value.__context__, LoadClientJobRetry) + assert "partitioning" in pip_ex.value.__context__.retry_message + dt = get_delta_tables(pipeline, "zero_part")["zero_part"] + assert dt.metadata().partition_columns == [] + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET, AZ_BUCKET), + ), + ids=lambda x: x.name, +) def test_delta_table_empty_source( destination_config: DestinationTestConfiguration, ) -> None: @@ -450,8 +584,8 @@ def test_delta_table_empty_source( Tests both empty Arrow table and `dlt.mark.materialize_table_schema()`. """ from dlt.common.libs.pyarrow import pyarrow as pa - from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_data - from tests.pipeline.utils import _get_delta_table, users_materialize_table_schema + from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_data, get_delta_tables + from tests.pipeline.utils import users_materialize_table_schema @dlt.resource(table_format="delta") def delta_table(data): @@ -476,8 +610,7 @@ def delta_table(data): # this should create empty Delta table with same schema as Arrow table info = pipeline.run(delta_table(empty_arrow_table)) assert_load_info(info) - client = cast(FilesystemClient, pipeline.destination_client()) - dt = _get_delta_table(client, "delta_table") + dt = get_delta_tables(pipeline, "delta_table")["delta_table"] assert dt.version() == 0 dt_arrow_table = dt.to_pyarrow_table() assert dt_arrow_table.shape == (0, empty_arrow_table.num_columns) @@ -489,7 +622,7 @@ def delta_table(data): # this should load records into Delta table info = pipeline.run(delta_table(arrow_table)) assert_load_info(info) - dt = _get_delta_table(client, "delta_table") + dt = get_delta_tables(pipeline, "delta_table")["delta_table"] assert dt.version() == 1 dt_arrow_table = dt.to_pyarrow_table() assert dt_arrow_table.shape == (2, empty_arrow_table.num_columns) @@ -505,7 +638,7 @@ def delta_table(data): info = pipeline.run(delta_table(empty_arrow_table_2)) assert_load_info(info) - dt = _get_delta_table(client, "delta_table") + dt = get_delta_tables(pipeline, "delta_table")["delta_table"] assert dt.version() == 1 # still 1, no new commit was done dt_arrow_table = dt.to_pyarrow_table() assert dt_arrow_table.shape == (2, empty_arrow_table.num_columns) # shape did not change @@ -517,7 +650,7 @@ def delta_table(data): users_materialize_table_schema.apply_hints(table_format="delta") info = pipeline.run(users_materialize_table_schema()) assert_load_info(info) - dt = _get_delta_table(client, "users") + dt = get_delta_tables(pipeline, "users")["users"] assert dt.version() == 0 dt_arrow_table = dt.to_pyarrow_table() assert dt_arrow_table.num_rows == 0 @@ -601,6 +734,70 @@ def github_events(): assert len(completed_jobs) == 2 * 20 + 1 +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET, AZ_BUCKET), + ), + ids=lambda x: x.name, +) +def test_delta_table_get_delta_tables_helper( + destination_config: DestinationTestConfiguration, +) -> None: + """Tests `get_delta_tables` helper function.""" + from dlt.common.libs.deltalake import DeltaTable, get_delta_tables + + @dlt.resource(table_format="delta") + def foo_delta(): + yield [{"foo": 1}, {"foo": 2}] + + @dlt.resource(table_format="delta") + def bar_delta(): + yield [{"bar": 1}] + + @dlt.resource + def baz_not_delta(): + yield [{"baz": 1}] + + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + info = pipeline.run(foo_delta()) + assert_load_info(info) + delta_tables = get_delta_tables(pipeline) + assert delta_tables.keys() == {"foo_delta"} + assert isinstance(delta_tables["foo_delta"], DeltaTable) + assert delta_tables["foo_delta"].to_pyarrow_table().num_rows == 2 + + info = pipeline.run([foo_delta(), bar_delta(), baz_not_delta()]) + assert_load_info(info) + delta_tables = get_delta_tables(pipeline) + assert delta_tables.keys() == {"foo_delta", "bar_delta"} + assert delta_tables["bar_delta"].to_pyarrow_table().num_rows == 1 + assert get_delta_tables(pipeline, "foo_delta").keys() == {"foo_delta"} + assert get_delta_tables(pipeline, "bar_delta").keys() == {"bar_delta"} + assert get_delta_tables(pipeline, "foo_delta", "bar_delta").keys() == {"foo_delta", "bar_delta"} + + # test with child table + @dlt.resource(table_format="delta") + def parent_delta(): + yield [{"foo": 1, "child": [1, 2, 3]}] + + info = pipeline.run(parent_delta()) + assert_load_info(info) + delta_tables = get_delta_tables(pipeline) + assert "parent_delta__child" in delta_tables.keys() + assert delta_tables["parent_delta__child"].to_pyarrow_table().num_rows == 3 + + # test invalid input + with pytest.raises(ValueError): + get_delta_tables(pipeline, "baz_not_delta") + + with pytest.raises(ValueError): + get_delta_tables(pipeline, "non_existing_table") + + TEST_LAYOUTS = ( "{schema_name}/{table_name}/{load_id}.{file_id}.{ext}", "{schema_name}.{table_name}.{load_id}.{file_id}.{ext}", diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 63188d4f5e..b2197dd273 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -11,7 +11,11 @@ from dlt.common.configuration.container import Container from dlt.common.pipeline import StateInjectableContext from dlt.common.schema.utils import has_table_seen_data -from dlt.common.schema.exceptions import SchemaCorruptedException +from dlt.common.schema.exceptions import ( + SchemaCorruptedException, + UnboundColumnException, + CannotCoerceNullException, +) from dlt.common.schema.typing import TLoaderMergeStrategy from dlt.common.typing import StrAny from dlt.common.utils import digest128 @@ -20,6 +24,7 @@ from dlt.extract import DltResource from dlt.sources.helpers.transform import skip_first, take_first from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.normalize.exceptions import NormalizeJobFailed from tests.pipeline.utils import ( assert_load_info, @@ -445,44 +450,6 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - assert github_1_counts["issues"] == 100 - 45 + 10 -@pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name -) -def test_merge_keys_non_existing_columns(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("github_3", dev_mode=True) - github_data = github() - # set keys names that do not exist in the data - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) - # skip first 45 rows - github_data.load_issues.add_filter(skip_first(45)) - info = p.run(github_data, loader_file_format=destination_config.file_format) - assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) - assert github_1_counts["issues"] == 100 - 45 - assert ( - p.default_schema.tables["issues"]["columns"]["m_a1"].items() - > {"merge_key": True, "nullable": False}.items() - ) - - # for non merge destinations we just check that the run passes - if not destination_config.supports_merge: - return - - # all the keys are invalid so the merge falls back to append - github_data = github() - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) - github_data.load_issues.add_filter(take_first(1)) - info = p.run(github_data, loader_file_format=destination_config.file_format) - assert_load_info(info) - github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) - assert github_2_counts["issues"] == 100 - 45 + 1 - with p._sql_job_client(p.default_schema) as job_c: - _, storage_cols = job_c.get_storage_table("issues") - storage_cols = normalize_storage_table_cols("issues", storage_cols, p.default_schema) - assert "url" in storage_cols - assert "m_a1" not in storage_cols # unbound columns were not created - - @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, file_format="parquet"), @@ -1242,3 +1209,51 @@ def r(): with pytest.raises(PipelineStepFailed) as pip_ex: p.run(r()) assert isinstance(pip_ex.value.__context__, SchemaCorruptedException) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_missing_merge_key_column(destination_config: DestinationTestConfiguration) -> None: + """Merge key is not present in data, error is raised""" + + @dlt.resource(merge_key="not_a_column", write_disposition={"disposition": "merge"}) + def merging_test_table(): + yield {"foo": "bar"} + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + with pytest.raises(PipelineStepFailed) as pip_ex: + p.run(merging_test_table()) + + ex = pip_ex.value + assert ex.step == "normalize" + assert isinstance(ex.__context__, UnboundColumnException) + + assert "not_a_column" in str(ex) + assert "merge key" in str(ex) + assert "merging_test_table" in str(ex) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_merge_key_null_values(destination_config: DestinationTestConfiguration) -> None: + """Merge key is present in data, but some rows have null values""" + + @dlt.resource(merge_key="id", write_disposition={"disposition": "merge"}) + def r(): + yield [{"id": 1}, {"id": None}, {"id": 2}] + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + with pytest.raises(PipelineStepFailed) as pip_ex: + p.run(r()) + + ex = pip_ex.value + assert ex.step == "normalize" + + assert isinstance(ex.__context__, NormalizeJobFailed) + assert isinstance(ex.__context__.__context__, CannotCoerceNullException) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index ffee515b90..81c9292570 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -29,6 +29,7 @@ PipelineStepFailed, ) +from tests.cases import TABLE_ROW_ALL_DATA_TYPES_DATETIMES from tests.utils import TEST_STORAGE_ROOT, data_to_item_format from tests.pipeline.utils import ( assert_data_table_counts, @@ -39,7 +40,6 @@ select_data, ) from tests.load.utils import ( - TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row, delete_dataset, @@ -844,7 +844,7 @@ def some_data(): def other_data(): yield [1, 2, 3, 4, 5] - data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES) + data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) # parquet on bigquery and clickhouse does not support JSON but we still want to run the test diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 12bc69abe0..d49ce2904f 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -58,7 +58,7 @@ def norm_table_counts(counts: Dict[str, int], *child_tables: str) -> Dict[str, i offset = 1000 # keep merge key with unknown column to test replace SQL generator - @dlt.resource(name="items", write_disposition="replace", primary_key="id", merge_key="NA") + @dlt.resource(name="items", write_disposition="replace", primary_key="id") def load_items(): # will produce 3 jobs for the main table with 40 items each # 6 jobs for the sub_items diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index 8b41c354b2..065da5ce94 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -3,6 +3,7 @@ import pytest from typing import List, Dict, Any, Optional from datetime import date, datetime, timezone # noqa: I251 +from contextlib import nullcontext as does_not_raise import dlt from dlt.common.typing import TAnyDateTime @@ -45,28 +46,30 @@ def get_load_package_created_at(pipeline: dlt.Pipeline, load_info: LoadInfo) -> return reduce_pendulum_datetime_precision(created_at, caps.timestamp_precision) +def strip_timezone(ts: TAnyDateTime) -> pendulum.DateTime: + """Converts timezone of datetime object to UTC and removes timezone awareness.""" + return ensure_pendulum_datetime(ts).astimezone(tz=timezone.utc).replace(tzinfo=None) + + def get_table( - pipeline: dlt.Pipeline, table_name: str, sort_column: str, include_root_id: bool = True + pipeline: dlt.Pipeline, table_name: str, sort_column: str = None, include_root_id: bool = True ) -> List[Dict[str, Any]]: """Returns destination table contents as list of dictionaries.""" - def strip_timezone(ts: datetime) -> datetime: - """Converts timezone of datetime object to UTC and removes timezone awareness.""" - return ensure_pendulum_datetime(ts).astimezone(tz=timezone.utc).replace(tzinfo=None) + table = [ + { + k: strip_timezone(v) if isinstance(v, datetime) else v + for k, v in r.items() + if not k.startswith("_dlt") + or k in DEFAULT_VALIDITY_COLUMN_NAMES + or (k == "_dlt_root_id" if include_root_id else False) + } + for r in load_tables_to_dicts(pipeline, table_name)[table_name] + ] - return sorted( - [ - { - k: strip_timezone(v) if isinstance(v, datetime) else v - for k, v in r.items() - if not k.startswith("_dlt") - or k in DEFAULT_VALIDITY_COLUMN_NAMES - or (k == "_dlt_root_id" if include_root_id else False) - } - for r in load_tables_to_dicts(pipeline, table_name)[table_name] - ], - key=lambda d: d[sort_column], - ) + if sort_column is None: + return table + return sorted(table, key=lambda d: d[sort_column]) @pytest.mark.essential @@ -139,8 +142,8 @@ def r(data): assert table["columns"][from_]["x-valid-from"] # type: ignore[typeddict-item] assert table["columns"][to]["x-valid-to"] # type: ignore[typeddict-item] assert table["columns"]["_dlt_id"]["x-row-version"] # type: ignore[typeddict-item] - # _dlt_id is still unique - assert table["columns"]["_dlt_id"]["unique"] + # root table _dlt_id is not unique with `scd2` merge strategy + assert not table["columns"]["_dlt_id"]["unique"] # assert load results ts_1 = get_load_package_created_at(p, info) @@ -288,7 +291,7 @@ def r(data): {from_: ts_2, to: None, "nk": 1, "c1": "foo_updated"}, # new ] assert_records_as_set( - get_table(p, "dim_test__c2", cname), + get_table(p, "dim_test__c2"), [ {"_dlt_root_id": get_row_hash(l1_1), cname: 1}, {"_dlt_root_id": get_row_hash(l2_1), cname: 1}, # new @@ -310,7 +313,7 @@ def r(data): ts_3 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test", "c1"), + get_table(p, "dim_test"), [ {from_: ts_1, to: None, "nk": 2, "c1": "bar"}, {from_: ts_1, to: ts_2, "nk": 1, "c1": "foo"}, @@ -326,7 +329,7 @@ def r(data): {"_dlt_root_id": get_row_hash(l3_1), cname: 2}, # new {"_dlt_root_id": get_row_hash(l1_2), cname: 3}, ] - assert_records_as_set(get_table(p, "dim_test__c2", cname), exp_3) + assert_records_as_set(get_table(p, "dim_test__c2"), exp_3) # load 4 — delete a record dim_snap = [ @@ -336,7 +339,7 @@ def r(data): ts_4 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test", "c1"), + get_table(p, "dim_test"), [ {from_: ts_1, to: ts_4, "nk": 2, "c1": "bar"}, # updated {from_: ts_1, to: ts_2, "nk": 1, "c1": "foo"}, @@ -345,7 +348,7 @@ def r(data): ], ) assert_records_as_set( - get_table(p, "dim_test__c2", cname), exp_3 + get_table(p, "dim_test__c2"), exp_3 ) # deletes should not alter child tables # load 5 — insert a record @@ -357,7 +360,7 @@ def r(data): ts_5 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test", "c1"), + get_table(p, "dim_test"), [ {from_: ts_1, to: ts_4, "nk": 2, "c1": "bar"}, {from_: ts_5, to: None, "nk": 3, "c1": "baz"}, # new @@ -367,7 +370,7 @@ def r(data): ], ) assert_records_as_set( - get_table(p, "dim_test__c2", cname), + get_table(p, "dim_test__c2"), [ {"_dlt_root_id": get_row_hash(l1_1), cname: 1}, {"_dlt_root_id": get_row_hash(l2_1), cname: 1}, @@ -403,7 +406,7 @@ def r(data): info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test__c2__cc1", "value"), + get_table(p, "dim_test__c2__cc1"), [ {"_dlt_root_id": get_row_hash(l1_1), "value": 1}, {"_dlt_root_id": get_row_hash(l1_2), "value": 1}, @@ -419,7 +422,7 @@ def r(data): info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) assert_load_info(info) assert_records_as_set( - (get_table(p, "dim_test__c2__cc1", "value")), + (get_table(p, "dim_test__c2__cc1")), [ {"_dlt_root_id": get_row_hash(l1_1), "value": 1}, {"_dlt_root_id": get_row_hash(l1_2), "value": 1}, @@ -443,7 +446,7 @@ def r(data): {"_dlt_root_id": get_row_hash(l1_2), "value": 2}, {"_dlt_root_id": get_row_hash(l3_1), "value": 2}, # new ] - assert_records_as_set(get_table(p, "dim_test__c2__cc1", "value"), exp_3) + assert_records_as_set(get_table(p, "dim_test__c2__cc1"), exp_3) # load 4 — delete a record dim_snap = [ @@ -451,7 +454,7 @@ def r(data): ] info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) assert_load_info(info) - assert_records_as_set(get_table(p, "dim_test__c2__cc1", "value"), exp_3) + assert_records_as_set(get_table(p, "dim_test__c2__cc1"), exp_3) # load 5 — insert a record dim_snap = [ @@ -461,7 +464,7 @@ def r(data): info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test__c2__cc1", "value"), + get_table(p, "dim_test__c2__cc1"), [ {"_dlt_root_id": get_row_hash(l1_1), "value": 1}, {"_dlt_root_id": get_row_hash(l1_2), "value": 1}, @@ -474,6 +477,67 @@ def r(data): ) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, +) +def test_record_reinsert(destination_config: DestinationTestConfiguration) -> None: + p = destination_config.setup_pipeline("abstract", dev_mode=True) + + @dlt.resource( + table_name="dim_test", write_disposition={"disposition": "merge", "strategy": "scd2"} + ) + def r(data): + yield data + + # load 1 — initial load + dim_snap = [ + r1 := {"nk": 1, "c1": "foo", "c2": "foo", "child": [1]}, + r2 := {"nk": 2, "c1": "bar", "c2": "bar", "child": [2, 3]}, + ] + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 2 + assert load_table_counts(p, "dim_test__child")["dim_test__child"] == 3 + ts_1 = get_load_package_created_at(p, info) + + # load 2 — delete natural key 1 + dim_snap = [r2] + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 2 + assert load_table_counts(p, "dim_test__child")["dim_test__child"] == 3 + ts_2 = get_load_package_created_at(p, info) + + # load 3 — reinsert natural key 1 + dim_snap = [r1, r2] + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 3 + assert load_table_counts(p, "dim_test__child")["dim_test__child"] == 3 # no new record + ts_3 = get_load_package_created_at(p, info) + + # assert parent records + from_, to = DEFAULT_VALIDITY_COLUMN_NAMES + r1_no_child = {k: v for k, v in r1.items() if k != "child"} + r2_no_child = {k: v for k, v in r2.items() if k != "child"} + expected = [ + {**{from_: ts_1, to: ts_2}, **r1_no_child}, + {**{from_: ts_3, to: None}, **r1_no_child}, + {**{from_: ts_1, to: None}, **r2_no_child}, + ] + assert_records_as_set(get_table(p, "dim_test"), expected) + + # assert child records + expected = [ + {"_dlt_root_id": get_row_hash(r1), "value": 1}, # links to two records in parent + {"_dlt_root_id": get_row_hash(r2), "value": 2}, + {"_dlt_root_id": get_row_hash(r2), "value": 3}, + ] + assert_records_as_set(get_table(p, "dim_test__child"), expected) + + @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, subset=["duckdb"]), @@ -520,6 +584,7 @@ def r(data): "9999-12-31T00:00:00", "9999-12-31T00:00:00+00:00", "9999-12-31T00:00:00+01:00", + "i_am_not_a_timestamp", ], ) def test_active_record_timestamp( @@ -528,22 +593,126 @@ def test_active_record_timestamp( ) -> None: p = destination_config.setup_pipeline("abstract", dev_mode=True) + context = does_not_raise() + if active_record_timestamp == "i_am_not_a_timestamp": + context = pytest.raises(ValueError) # type: ignore[assignment] + + with context: + + @dlt.resource( + table_name="dim_test", + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "active_record_timestamp": active_record_timestamp, + }, + ) + def r(): + yield {"foo": "bar"} + + p.run(r()) + actual_active_record_timestamp = ensure_pendulum_datetime( + load_tables_to_dicts(p, "dim_test")["dim_test"][0]["_dlt_valid_to"] + ) + assert actual_active_record_timestamp == ensure_pendulum_datetime(active_record_timestamp) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_boundary_timestamp( + destination_config: DestinationTestConfiguration, +) -> None: + p = destination_config.setup_pipeline("abstract", dev_mode=True) + + ts1 = "2024-08-21T12:15:00+00:00" + ts2 = "2024-08-22" + ts3 = date(2024, 8, 20) # earlier than ts1 and ts2 + ts4 = "i_am_not_a_timestamp" + @dlt.resource( table_name="dim_test", write_disposition={ "disposition": "merge", "strategy": "scd2", - "active_record_timestamp": active_record_timestamp, + "boundary_timestamp": ts1, }, ) - def r(): - yield {"foo": "bar"} + def r(data): + yield data + + # load 1 — initial load + dim_snap = [ + l1_1 := {"nk": 1, "foo": "foo"}, + l1_2 := {"nk": 2, "foo": "foo"}, + ] + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 2 + from_, to = DEFAULT_VALIDITY_COLUMN_NAMES + expected = [ + {**{from_: strip_timezone(ts1), to: None}, **l1_1}, + {**{from_: strip_timezone(ts1), to: None}, **l1_2}, + ] + assert get_table(p, "dim_test", "nk") == expected + + # load 2 — different source records, different boundary timestamp + r.apply_hints( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "boundary_timestamp": ts2, + } + ) + dim_snap = [ + l2_1 := {"nk": 1, "foo": "bar"}, # natural key 1 updated + # l1_2, # natural key 2 no longer present + l2_3 := {"nk": 3, "foo": "foo"}, # new natural key + ] + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 4 + expected = [ + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_1}, # retired + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_2}, # retired + {**{from_: strip_timezone(ts2), to: None}, **l2_1}, # new + {**{from_: strip_timezone(ts2), to: None}, **l2_3}, # new + ] + assert_records_as_set(get_table(p, "dim_test"), expected) - p.run(r()) - actual_active_record_timestamp = ensure_pendulum_datetime( - load_tables_to_dicts(p, "dim_test")["dim_test"][0]["_dlt_valid_to"] + # load 3 — earlier boundary timestamp + # we naively apply any valid timestamp + # may lead to "valid from" > "valid to", as in this test case + r.apply_hints( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "boundary_timestamp": ts3, + } ) - assert actual_active_record_timestamp == ensure_pendulum_datetime(active_record_timestamp) + dim_snap = [l2_1] # natural key 3 no longer present + info = p.run(r(dim_snap)) + assert_load_info(info) + assert load_table_counts(p, "dim_test")["dim_test"] == 4 + expected = [ + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_1}, # unchanged + {**{from_: strip_timezone(ts1), to: strip_timezone(ts2)}, **l1_2}, # unchanged + {**{from_: strip_timezone(ts2), to: None}, **l2_1}, # unchanged + {**{from_: strip_timezone(ts2), to: strip_timezone(ts3)}, **l2_3}, # retired + ] + assert_records_as_set(get_table(p, "dim_test"), expected) + + # invalid boundary timestamp should raise error + with pytest.raises(ValueError): + r.apply_hints( + write_disposition={ + "disposition": "merge", + "strategy": "scd2", + "boundary_timestamp": ts4, + } + ) @pytest.mark.parametrize( @@ -633,6 +802,8 @@ def r(data): table = p.default_schema.get_table("dim_test") assert table["columns"]["row_hash"]["x-row-version"] # type: ignore[typeddict-item] assert "x-row-version" not in table["columns"]["_dlt_id"] + # _dlt_id unique constraint should not be dropped when users bring their own hash + assert table["columns"]["_dlt_id"]["unique"] # load 2 — update and delete a record dim_snap = [ diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index bb923df673..41287fcd2d 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -90,9 +90,10 @@ def test_text_too_long(client: RedshiftClient, file_storage: FileStorage) -> Non # print(len(max_len_str_b)) row_id = uniq_id() insert_values = f"('{row_id}', '{uniq_id()}', '{max_len_str}' , '{str(pendulum.now())}');" - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is psycopg2.errors.StringDataRightTruncation + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception.dbapi_exception) is psycopg2.errors.StringDataRightTruncation # type: ignore def test_wei_value(client: RedshiftClient, file_storage: FileStorage) -> None: @@ -107,9 +108,10 @@ def test_wei_value(client: RedshiftClient, file_storage: FileStorage) -> None: f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', {10**38});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is psycopg2.errors.InternalError_ + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception.dbapi_exception) is psycopg2.errors.InternalError_ # type: ignore def test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None: diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index be917672f1..b55f4ceece 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -1,6 +1,6 @@ import os from concurrent.futures import ThreadPoolExecutor -from time import sleep +from time import sleep, time from unittest import mock import pytest from unittest.mock import patch @@ -10,7 +10,7 @@ from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName from dlt.common.storages.load_package import LoadJobInfo, TJobState from dlt.common.storages.load_storage import JobFileFormatUnsupported -from dlt.common.destination.reference import LoadJob, TDestination +from dlt.common.destination.reference import RunnableLoadJob, TDestination from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, get_child_tables, @@ -18,13 +18,17 @@ ) from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration -from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations import dummy, filesystem from dlt.destinations.impl.dummy import dummy as dummy_impl from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration from dlt.load import Load -from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry +from dlt.load.exceptions import ( + LoadClientJobFailed, + LoadClientJobRetry, + TableChainFollowupJobCreationFailedException, + FollowupJobCreationFailedException, +) from dlt.load.utils import get_completed_table_chain, init_client, _extend_tables_with_table_chain from tests.utils import ( @@ -42,6 +46,8 @@ "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl", ] +SMALL_FILES = ["event_user.1234.0.jsonl", "event_loop_interrupted.1234.0.jsonl"] + REMOTE_FILESYSTEM = os.path.abspath(os.path.join(TEST_STORAGE_ROOT, "_remote_filesystem")) @@ -61,20 +67,21 @@ def test_spool_job_started() -> None: load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 - jobs: List[LoadJob] = [] + jobs: List[RunnableLoadJob] = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) + assert job.state() == "completed" assert type(job) is dummy_impl.LoadDummyJob - assert job.state() == "running" + # jobs runs, but is not moved yet (loader will do this) assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( load_id, PackageStorage.STARTED_JOBS_FOLDER, job.file_name() ) ) jobs.append(job) - # still running - remaining_jobs = load.complete_jobs(load_id, jobs, schema) - assert len(remaining_jobs) == 2 + remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) + assert len(remaining_jobs) == 0 + assert len(finalized_jobs) == 2 def test_unsupported_writer_type() -> None: @@ -87,6 +94,7 @@ def test_unsupported_writer_type() -> None: def test_unsupported_write_disposition() -> None: + # tests terminal error on retrieving job load = setup_loader() load_id, schema = prepare_load_package(load.load_storage, [NORMALIZED_FILES[0]]) # mock unsupported disposition @@ -96,13 +104,36 @@ def test_unsupported_write_disposition() -> None: with ThreadPoolExecutor() as pool: load.run(pool) # job with unsupported write disp. is failed - failed_job = load.load_storage.normalized_packages.list_failed_jobs(load_id)[0] - failed_message = load.load_storage.normalized_packages.get_job_failed_message( + failed_job = load.load_storage.loaded_packages.list_failed_jobs(load_id)[0] + failed_message = load.load_storage.loaded_packages.get_job_failed_message( load_id, ParsedLoadJobFileName.parse(failed_job) ) assert "LoadClientUnsupportedWriteDisposition" in failed_message +def test_big_loadpackages() -> None: + """ + This test guards against changes in the load that exponentially makes the loads slower + """ + + load = setup_loader() + # make the loop faster by basically not sleeping + load._run_loop_sleep_duration = 0.001 + load_id, schema = prepare_load_package(load.load_storage, SMALL_FILES, jobs_per_case=500) + start_time = time() + with ThreadPoolExecutor(max_workers=20) as pool: + load.run(pool) + duration = float(time() - start_time) + + # sanity check + assert duration > 3 + # we want 1000 empty processed jobs to need less than 15 seconds total (locally it runs in 5) + assert duration < 15 + + # we should have 1000 jobs processed + assert len(dummy_impl.JOBS) == 1000 + + def test_get_new_jobs_info() -> None: load = setup_loader() load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) @@ -158,10 +189,10 @@ def test_spool_job_failed() -> None: load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) - jobs: List[LoadJob] = [] + jobs: List[RunnableLoadJob] = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) - assert type(job) is EmptyLoadJob + job = load.submit_job(f, load_id, schema) + assert type(job) is dummy_impl.LoadDummyJob assert job.state() == "failed" assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( @@ -170,8 +201,9 @@ def test_spool_job_failed() -> None: ) jobs.append(job) # complete files - remaining_jobs = load.complete_jobs(load_id, jobs, schema) + remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 + assert len(finalized_jobs) == 2 for job in jobs: assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( @@ -196,11 +228,10 @@ def test_spool_job_failed() -> None: assert len(package_info.jobs["failed_jobs"]) == 2 -def test_spool_job_failed_exception_init() -> None: +def test_spool_job_failed_terminally_exception_init() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" - os.environ["FAIL_IN_INIT"] = "true" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=True)) + load = setup_loader(client_config=DummyClientConfiguration(fail_terminally_in_init=True)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: with pytest.raises(LoadClientJobFailed) as py_ex: @@ -215,11 +246,30 @@ def test_spool_job_failed_exception_init() -> None: complete_load.assert_not_called() +def test_spool_job_failed_transiently_exception_init() -> None: + # this config fails job on start + os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" + load = setup_loader(client_config=DummyClientConfiguration(fail_transiently_in_init=True)) + load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) + with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: + with pytest.raises(LoadClientJobRetry) as py_ex: + run_all(load) + assert py_ex.value.load_id == load_id + package_info = load.load_storage.get_load_package_info(load_id) + assert package_info.state == "normalized" + # both failed - we wait till the current loop is completed and then raise + assert len(package_info.jobs["failed_jobs"]) == 0 + assert len(package_info.jobs["started_jobs"]) == 0 + assert len(package_info.jobs["new_jobs"]) == 2 + + # load id was never committed + complete_load.assert_not_called() + + def test_spool_job_failed_exception_complete() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" - os.environ["FAIL_IN_INIT"] = "false" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=False)) + load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with pytest.raises(LoadClientJobFailed) as py_ex: run_all(load) @@ -237,7 +287,7 @@ def test_spool_job_retry_new() -> None: load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) assert job.state() == "retry" @@ -248,8 +298,7 @@ def test_spool_job_retry_spool_new() -> None: # call higher level function that returns jobs and counts with ThreadPoolExecutor() as pool: load.pool = pool - jobs_count, jobs = load.spool_new_jobs(load_id, schema) - assert jobs_count == 2 + jobs = load.start_new_jobs(load_id, schema, []) assert len(jobs) == 2 @@ -259,24 +308,26 @@ def test_spool_job_retry_started() -> None: # dummy_impl.CLIENT_CONFIG = DummyClientConfiguration load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) - jobs: List[LoadJob] = [] + jobs: List[RunnableLoadJob] = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) assert type(job) is dummy_impl.LoadDummyJob - assert job.state() == "running" + assert job.state() == "completed" + # mock job state to make it retry + job.config.retry_prob = 1.0 + job._state = "retry" assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( load_id, PackageStorage.STARTED_JOBS_FOLDER, job.file_name() ) ) - # mock job config to make it retry - job.config.retry_prob = 1.0 jobs.append(job) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 0 - # should retry, that moves jobs into new folder - remaining_jobs = load.complete_jobs(load_id, jobs, schema) + # should retry, that moves jobs into new folder, jobs are not counted as finalized + remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 + assert len(finalized_jobs) == 0 # clear retry flag dummy_impl.JOBS = {} files = load.load_storage.normalized_packages.list_new_jobs(load_id) @@ -285,9 +336,11 @@ def test_spool_job_retry_started() -> None: for fn in load.load_storage.normalized_packages.list_new_jobs(load_id): # we failed when already running the job so retry count will increase assert ParsedLoadJobFileName.parse(fn).retry_count == 1 + + # this time it will pass for f in files: - job = Load.w_spool_job(load, f, load_id, schema) - assert job.state() == "running" + job = load.submit_job(f, load_id, schema) + assert job.state() == "completed" def test_try_retrieve_job() -> None: @@ -301,22 +354,21 @@ def test_try_retrieve_job() -> None: ) # dummy client may retrieve jobs that it created itself, jobs in started folder are unknown # and returned as terminal - with load.destination.client(schema, load.initial_client_config) as c: - job_count, jobs = load.retrieve_jobs(c, load_id) - assert job_count == 2 - for j in jobs: - assert j.state() == "failed" + jobs = load.resume_started_jobs(load_id, schema) + assert len(jobs) == 2 + for j in jobs: + assert j.state() == "failed" # new load package load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) load.pool = ThreadPoolExecutor() - jobs_count, jobs = load.spool_new_jobs(load_id, schema) - assert jobs_count == 2 + jobs = load.start_new_jobs(load_id, schema, []) # type: ignore + assert len(jobs) == 2 # now jobs are known - with load.destination.client(schema, load.initial_client_config) as c: - job_count, jobs = load.retrieve_jobs(c, load_id) - assert job_count == 2 - for j in jobs: - assert j.state() == "running" + jobs = load.resume_started_jobs(load_id, schema) + assert len(jobs) == 2 + for j in jobs: + assert j.state() == "completed" + assert len(dummy_impl.RETRIED_JOBS) == 2 def test_completed_loop() -> None: @@ -328,7 +380,6 @@ def test_completed_loop() -> None: def test_completed_loop_followup_jobs() -> None: # TODO: until we fix how we create capabilities we must set env - os.environ["CREATE_FOLLOWUP_JOBS"] = "true" load = setup_loader( client_config=DummyClientConfiguration(completed_prob=1.0, create_followup_jobs=True) ) @@ -338,6 +389,95 @@ def test_completed_loop_followup_jobs() -> None: assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_FOLLOWUP_JOBS) * 2 +def test_failing_followup_jobs() -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, create_followup_jobs=True, fail_followup_job_creation=True + ) + ) + with pytest.raises(FollowupJobCreationFailedException) as exc: + assert_complete_job(load) + # follow up job errors on main thread + assert "Failed to create followup job" in str(exc) + + # followup job fails, we have both jobs in started folder + load_id = list(dummy_impl.JOBS.values())[1]._load_id + started_files = load.load_storage.normalized_packages.list_started_jobs(load_id) + assert len(started_files) == 2 + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.RETRIED_JOBS) == 0 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + # now we can retry the same load, it will restart the two jobs and successfully create the followup jobs + load.initial_client_config.fail_followup_job_creation = False # type: ignore + assert_complete_job(load, load_id=load_id) + assert len(dummy_impl.JOBS) == 2 * 2 + assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_FOLLOWUP_JOBS) * 2 + assert len(dummy_impl.RETRIED_JOBS) == 2 + + +def test_failing_table_chain_followup_jobs() -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, + create_followup_table_chain_reference_jobs=True, + fail_table_chain_followup_job_creation=True, + ) + ) + with pytest.raises(TableChainFollowupJobCreationFailedException) as exc: + assert_complete_job(load) + # follow up job errors on main thread + assert "Failed creating table chain followup jobs for table chain with root table" in str(exc) + + # table chain followup job fails, we have both jobs in started folder + load_id = list(dummy_impl.JOBS.values())[1]._load_id + started_files = load.load_storage.normalized_packages.list_started_jobs(load_id) + assert len(started_files) == 2 + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.RETRIED_JOBS) == 0 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + # now we can retry the same load, it will restart the two jobs and successfully create the table chain followup jobs + load.initial_client_config.fail_table_chain_followup_job_creation = False # type: ignore + assert_complete_job(load, load_id=load_id) + assert len(dummy_impl.JOBS) == 2 * 2 + assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS) * 2 + assert len(dummy_impl.RETRIED_JOBS) == 2 + + +def test_failing_sql_table_chain_job() -> None: + """ + Make sure we get a useful exception from a failing sql job + """ + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, create_followup_table_chain_sql_jobs=True + ), + ) + with pytest.raises(Exception) as exc: + assert_complete_job(load) + + # sql jobs always fail because this is not an sql client, we just make sure the exception is there + assert "Failed creating table chain followup jobs for table chain with root table" in str(exc) + + +def test_successful_table_chain_jobs() -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, create_followup_table_chain_reference_jobs=True + ), + ) + # we create 10 jobs per case (for two cases) + # and expect two table chain jobs at the end + assert_complete_job(load, jobs_per_case=10) + assert len(dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS) == 2 + assert len(dummy_impl.JOBS) == 22 + + # check that we have 10 references per followup job + for _, job in dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS.items(): + assert len(job._remote_paths) == 10 # type: ignore + + def test_failed_loop() -> None: # ask to delete completed load = setup_loader( @@ -345,21 +485,18 @@ def test_failed_loop() -> None: ) # actually not deleted because one of the jobs failed assert_complete_job(load, should_delete_completed=False) - # no jobs because fail on init - assert len(dummy_impl.JOBS) == 0 + # two failed jobs + assert len(dummy_impl.JOBS) == 2 + assert list(dummy_impl.JOBS.values())[0].state() == "failed" + assert list(dummy_impl.JOBS.values())[1].state() == "failed" assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 def test_failed_loop_followup_jobs() -> None: - # TODO: until we fix how we create capabilities we must set env - os.environ["CREATE_FOLLOWUP_JOBS"] = "true" - os.environ["FAIL_IN_INIT"] = "false" # ask to delete completed load = setup_loader( delete_completed_jobs=True, - client_config=DummyClientConfiguration( - fail_prob=1.0, fail_in_init=False, create_followup_jobs=True - ), + client_config=DummyClientConfiguration(fail_prob=1.0, create_followup_jobs=True), ) # actually not deleted because one of the jobs failed assert_complete_job(load, should_delete_completed=False) @@ -381,36 +518,36 @@ def test_retry_on_new_loop() -> None: load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) with ThreadPoolExecutor() as pool: # 1st retry - load.run(pool) + with pytest.raises(LoadClientJobRetry): + load.run(pool) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 # 2nd retry - load.run(pool) + with pytest.raises(LoadClientJobRetry): + load.run(pool) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 - # jobs will be completed + # package will be completed load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) load.run(pool) - files = load.load_storage.normalized_packages.list_new_jobs(load_id) - assert len(files) == 0 - # complete package - load.run(pool) assert not load.load_storage.normalized_packages.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) ) + sleep(1) # parse the completed job names completed_path = load.load_storage.loaded_packages.get_package_path(load_id) for fn in load.load_storage.loaded_packages.storage.list_folder_files( os.path.join(completed_path, PackageStorage.COMPLETED_JOBS_FOLDER) ): - # we update a retry count in each case - assert ParsedLoadJobFileName.parse(fn).retry_count == 2 + # we update a retry count in each case (5 times for each loop run) + assert ParsedLoadJobFileName.parse(fn).retry_count == 10 def test_retry_exceptions() -> None: load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) prepare_load_package(load.load_storage, NORMALIZED_FILES) + with ThreadPoolExecutor() as pool: # 1st retry with pytest.raises(LoadClientJobRetry) as py_ex: @@ -418,7 +555,6 @@ def test_retry_exceptions() -> None: load.run(pool) # configured to retry 5 times before exception assert py_ex.value.max_retry_count == py_ex.value.retry_count == 5 - # we can do it again with pytest.raises(LoadClientJobRetry) as py_ex: while True: @@ -730,8 +866,13 @@ def test_terminal_exceptions() -> None: raise AssertionError() -def assert_complete_job(load: Load, should_delete_completed: bool = False) -> None: - load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) +def assert_complete_job( + load: Load, should_delete_completed: bool = False, load_id: str = None, jobs_per_case: int = 1 +) -> None: + if not load_id: + load_id, _ = prepare_load_package( + load.load_storage, NORMALIZED_FILES, jobs_per_case=jobs_per_case + ) # will complete all jobs timestamp = "2024-04-05T09:16:59.942779Z" mocked_timestamp = {"state": {"created_at": timestamp}} @@ -744,22 +885,7 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No ) as complete_load: with ThreadPoolExecutor() as pool: load.run(pool) - # did process schema update - assert load.load_storage.storage.has_file( - os.path.join( - load.load_storage.get_normalized_package_path(load_id), - PackageStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME, - ) - ) - # will finalize the whole package - load.run(pool) - # may have followup jobs or staging destination - if ( - load.initial_client_config.create_followup_jobs # type:ignore[attr-defined] - or load.staging_destination - ): - # run the followup jobs - load.run(pool) + # moved to loaded assert not load.load_storage.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) @@ -767,6 +893,15 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No completed_path = load.load_storage.loaded_packages.get_job_state_folder_path( load_id, "completed_jobs" ) + + # should have migrated the schema + assert load.load_storage.storage.has_file( + os.path.join( + load.load_storage.get_loaded_package_path(load_id), + PackageStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME, + ) + ) + if should_delete_completed: # package was deleted assert not load.load_storage.loaded_packages.storage.has_folder(completed_path) @@ -794,14 +929,21 @@ def setup_loader( # reset jobs for a test dummy_impl.JOBS = {} dummy_impl.CREATED_FOLLOWUP_JOBS = {} - client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl") + dummy_impl.RETRIED_JOBS = {} + dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS = {} + + client_config = client_config or DummyClientConfiguration( + loader_file_format="jsonl", completed_prob=1 + ) destination: TDestination = dummy(**client_config) # type: ignore[assignment] # setup staging_system_config = None staging = None if filesystem_staging: # do not accept jsonl to not conflict with filesystem destination - client_config = client_config or DummyClientConfiguration(loader_file_format="reference") # type: ignore[arg-type] + client_config = client_config or DummyClientConfiguration( + loader_file_format="reference", completed_prob=1 + ) staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( dataset_name="dummy" ) diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 38155a8b09..a957c871bb 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -55,12 +55,18 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - f" '{str(pendulum.now())}'" + post ) - expect_load_file(client, file_storage, insert_sql + insert_values + ";", user_table_name) + expect_load_file( + client, + file_storage, + insert_sql + insert_values + ";", + user_table_name, + file_format="insert_values", + ) rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == 1 # insert 100 more rows query = insert_sql + (insert_values + sep) * 99 + insert_values + ";" - expect_load_file(client, file_storage, query, user_table_name) + expect_load_file(client, file_storage, query, user_table_name, file_format="insert_values") rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == 101 # insert null value (single-record insert has same syntax for both writer types) @@ -69,7 +75,13 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', NULL);" ) - expect_load_file(client, file_storage, insert_sql_nc + insert_values_nc, user_table_name) + expect_load_file( + client, + file_storage, + insert_sql_nc + insert_values_nc, + user_table_name, + file_format="insert_values", + ) rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == 102 @@ -94,7 +106,7 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage if dtype == "redshift": # redshift does not know or psycopg does not recognize those correctly TNotNullViolation = psycopg2.errors.InternalError_ - elif dtype == "duckdb": + elif dtype in ("duckdb", "motherduck"): import duckdb TUndefinedColumn = duckdb.BinderException @@ -114,24 +126,42 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', NULL);" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is TUndefinedColumn + job = expect_load_file( + client, + file_storage, + insert_sql + insert_values, + user_table_name, + "failed", + file_format="insert_values", + ) + assert type(job._exception.dbapi_exception) is TUndefinedColumn # type: ignore # insert null value insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', NULL);" - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is TNotNullViolation + job = expect_load_file( + client, + file_storage, + insert_sql + insert_values, + user_table_name, + "failed", + file_format="insert_values", + ) + assert type(job._exception.dbapi_exception) is TNotNullViolation # type: ignore # insert wrong type insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = ( f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" {client.capabilities.escape_literal(True)});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is TDatatypeMismatch + job = expect_load_file( + client, + file_storage, + insert_sql + insert_values, + user_table_name, + "failed", + file_format="insert_values", + ) + assert type(job._exception.dbapi_exception) is TDatatypeMismatch # type: ignore # numeric overflow on bigint insert_sql = ( "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, metadata__rasa_x_id)\nVALUES\n" @@ -141,9 +171,15 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', {2**64//2});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) == TNumericValueOutOfRange + job = expect_load_file( + client, + file_storage, + insert_sql + insert_values, + user_table_name, + "failed", + file_format="insert_values", + ) + assert type(job._exception) == DatabaseTerminalException # type: ignore # numeric overflow on NUMERIC insert_sql = ( "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," @@ -158,16 +194,30 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', {below_limit});" ) - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, + file_storage, + insert_sql + insert_values, + user_table_name, + file_format="insert_values", + ) # this will raise insert_values = ( f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', {above_limit});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + job = expect_load_file( + client, + file_storage, + insert_sql + insert_values, + user_table_name, + "failed", + file_format="insert_values", + ) + assert type(job._exception) == DatabaseTerminalException # type: ignore + assert ( - type(exv.value.dbapi_exception) == psycopg2.errors.InternalError_ + type(job._exception.dbapi_exception) == psycopg2.errors.InternalError_ # type: ignore if dtype == "redshift" else TNumericValueOutOfRange ) @@ -193,7 +243,9 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - # this guarantees that we execute inserts line by line with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) - expect_load_file(client, file_storage, insert_sql, user_table_name) + expect_load_file( + client, file_storage, insert_sql, user_table_name, file_format="insert_values" + ) # print(mocked_fragments.mock_calls) # split in 10 lines assert mocked_fragments.call_count == 10 @@ -217,7 +269,9 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - client.sql_client.capabilities.max_query_length = query_length with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) - expect_load_file(client, file_storage, insert_sql, user_table_name) + expect_load_file( + client, file_storage, insert_sql, user_table_name, file_format="insert_values" + ) # split in 2 on ',' assert mocked_fragments.call_count == 2 @@ -226,7 +280,9 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - client.sql_client.capabilities.max_query_length = query_length with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) - expect_load_file(client, file_storage, insert_sql, user_table_name) + expect_load_file( + client, file_storage, insert_sql, user_table_name, file_format="insert_values" + ) # split in 2 on separator ("," or " UNION ALL") assert mocked_fragments.call_count == 2 @@ -239,7 +295,9 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - client.sql_client.capabilities.max_query_length = query_length with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) - expect_load_file(client, file_storage, insert_sql, user_table_name) + expect_load_file( + client, file_storage, insert_sql, user_table_name, file_format="insert_values" + ) # split in 2 on ',' assert mocked_fragments.call_count == 1 @@ -256,7 +314,7 @@ def assert_load_with_max_query( insert_sql = prepare_insert_statement( insert_lines, client.capabilities.insert_values_writer_type ) - expect_load_file(client, file_storage, insert_sql, user_table_name) + expect_load_file(client, file_storage, insert_sql, user_table_name, file_format="insert_values") canonical_name = client.sql_client.make_qualified_table_name(user_table_name) rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == insert_lines diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 614eb17da1..06b70a49da 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -35,7 +35,6 @@ from tests.load.utils import ( TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, - TABLE_ROW_ALL_DATA_TYPES, expect_load_file, load_table, yield_client_with_storage, @@ -489,7 +488,7 @@ def test_data_writer_load(naming: str, client: SqlJobClientBase, file_storage: F # write only first row with io.BytesIO() as f: write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"]) - query = f.getvalue().decode() + query = f.getvalue() expect_load_file(client, file_storage, query, table_name) db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0] # content must equal @@ -497,7 +496,7 @@ def test_data_writer_load(naming: str, client: SqlJobClientBase, file_storage: F # write second row that contains two nulls with io.BytesIO() as f: write_dataset(client, f, [rows[1]], client.schema.get_table(table_name)["columns"]) - query = f.getvalue().decode() + query = f.getvalue() expect_load_file(client, file_storage, query, table_name) f_int_name = client.schema.naming.normalize_identifier("f_int") f_int_name_quoted = client.sql_client.escape_column_name(f_int_name) @@ -522,7 +521,7 @@ def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileS row["f_str"] = inj_str with io.BytesIO() as f: write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"]) - query = f.getvalue().decode() + query = f.getvalue() expect_load_file(client, file_storage, query, table_name) db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0] assert list(db_row) == list(row.values()) @@ -540,7 +539,7 @@ def test_data_writer_string_escape_edge( canonical_name = client.sql_client.make_qualified_table_name(table_name) with io.BytesIO() as f: write_dataset(client, f, rows, client.schema.get_table(table_name)["columns"]) - query = f.getvalue().decode() + query = f.getvalue() expect_load_file(client, file_storage, query, table_name) for i in range(1, len(rows) + 1): db_row = client.sql_client.execute_sql(f"SELECT str FROM {canonical_name} WHERE idx = {i}") @@ -562,11 +561,7 @@ def test_load_with_all_types( if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") table_name = "event_test_table" + uniq_id() - column_schemas, data_row = table_update_and_row( - exclude_types=( - ["time"] if client.config.destination_type in ["databricks", "clickhouse"] else None - ), - ) + column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type) # we should have identical content with all disposition types partial = client.schema.update_table( @@ -595,9 +590,11 @@ def test_load_with_all_types( ): canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row + print(data_row) with io.BytesIO() as f: write_dataset(client, f, [data_row], column_schemas) - query = f.getvalue().decode() + query = f.getvalue() + print(client.schema.to_pretty_yaml()) expect_load_file(client, file_storage, query, table_name) db_row = list(client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0]) assert len(db_row) == len(data_row) @@ -636,13 +633,14 @@ def test_write_dispositions( os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy table_name = "event_test_table" + uniq_id() + column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type) client.schema.update_table( - new_table(table_name, write_disposition=write_disposition, columns=TABLE_UPDATE) + new_table(table_name, write_disposition=write_disposition, columns=column_schemas.values()) ) child_table = client.schema.naming.make_path(table_name, "child") # add child table without write disposition so it will be inferred from the parent client.schema.update_table( - new_table(child_table, columns=TABLE_UPDATE, parent_table_name=table_name) + new_table(child_table, columns=column_schemas.values(), parent_table_name=table_name) ) client.schema._bump_version() client.update_stored_schema() @@ -663,11 +661,10 @@ def test_write_dispositions( for t in [table_name, child_table]: # write row, use col1 (INT) as row number - table_row = deepcopy(TABLE_ROW_ALL_DATA_TYPES) - table_row["col1"] = idx + data_row["col1"] = idx with io.BytesIO() as f: - write_dataset(client, f, [table_row], TABLE_UPDATE_COLUMNS_SCHEMA) - query = f.getvalue().decode() + write_dataset(client, f, [data_row], column_schemas) + query = f.getvalue() if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined] # load to staging dataset on merge with client.with_staging_dataset(): # type: ignore[attr-defined] @@ -707,7 +704,7 @@ def test_write_dispositions( @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) -def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> None: +def test_get_resumed_job(client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") user_table_name = prepare_table(client) @@ -715,19 +712,22 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No "_dlt_id": uniq_id(), "_dlt_root_id": uniq_id(), "sender_id": "90238094809sajlkjxoiewjhduuiuehd", - "timestamp": str(pendulum.now()), + "timestamp": pendulum.now(), } + print(client.schema.get_table(user_table_name)["columns"]) with io.BytesIO() as f: write_dataset(client, f, [load_json], client.schema.get_table(user_table_name)["columns"]) - dataset = f.getvalue().decode() + dataset = f.getvalue() job = expect_load_file(client, file_storage, dataset, user_table_name) # now try to retrieve the job # TODO: we should re-create client instance as this call is intended to be run after some disruption ie. stopped loader process - r_job = client.restore_file_load(file_storage.make_full_path(job.file_name())) - assert r_job.state() == "completed" - # use just file name to restore - r_job = client.restore_file_load(job.file_name()) - assert r_job.state() == "completed" + r_job = client.create_load_job( + client.schema.get_table(user_table_name), + file_storage.make_full_path(job.file_name()), + uniq_id(), + restore=True, + ) + assert r_job.state() == "ready" @pytest.mark.parametrize( @@ -808,7 +808,7 @@ def test_get_stored_state( with io.BytesIO() as f: # use normalized columns write_dataset(client, f, [norm_doc], partial["columns"]) - query = f.getvalue().decode() + query = f.getvalue() expect_load_file(client, file_storage, query, partial["name"]) client.complete_load("_load_id") @@ -831,12 +831,20 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: # "_dlt_load_id": "load_id", "event": "user", "sender_id": "sender_id", - "timestamp": str(pendulum.now()), + "timestamp": pendulum.now(), } with io.BytesIO() as f: - write_dataset(_client, f, [user_row], _client.schema.tables["event_user"]["columns"]) - query = f.getvalue().decode() - expect_load_file(_client, file_storage, query, "event_user") + write_dataset( + _client, + f, + [user_row], + _client.schema.tables["event_user"]["columns"], + file_format=destination_config.file_format, + ) + query = f.getvalue() + expect_load_file( + _client, file_storage, query, "event_user", file_format=destination_config.file_format + ) qual_table_name = _client.sql_client.make_qualified_table_name("event_user") db_rows = list(_client.sql_client.execute_sql(f"SELECT * FROM {qual_table_name}")) assert len(db_rows) == expected_rows @@ -889,6 +897,11 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: # no were detected - even if the schema is new. all the tables overlap and change in nullability does not do any updates assert schema_update == {} # 3 rows because we load to the same table + if ( + destination_config.file_format == "parquet" + or client.capabilities.preferred_loader_file_format == "parquet" + ): + event_3_schema.tables["event_user"]["columns"]["input_channel"]["nullable"] = True _load_something(client, 3) # adding new non null column will generate sync error, except for clickhouse, there it will work @@ -927,3 +940,13 @@ def normalize_rows(rows: List[Dict[str, Any]], naming: NamingConvention) -> None for row in rows: for k in list(row.keys()): row[naming.normalize_identifier(k)] = row.pop(k) + + +def get_columns_and_row_all_types(destination_type: str): + return table_update_and_row( + # TIME + parquet is actually a duckdb problem: https://github.com/duckdb/duckdb/pull/13283 + exclude_types=( + ["time"] if destination_type in ["databricks", "clickhouse", "motherduck"] else None + ), + exclude_columns=["col4_precision"] if destination_type in ["motherduck"] else None, + ) diff --git a/tests/load/test_jobs.py b/tests/load/test_jobs.py new file mode 100644 index 0000000000..69f5fb9ddc --- /dev/null +++ b/tests/load/test_jobs.py @@ -0,0 +1,75 @@ +import pytest + +from dlt.common.destination.reference import RunnableLoadJob +from dlt.common.destination.exceptions import DestinationTerminalException +from dlt.destinations.job_impl import FinalizedLoadJob + + +def test_instantiate_job() -> None: + file_name = "table.1234.0.jsonl" + file_path = "/path/" + file_name + + class SomeJob(RunnableLoadJob): + def run(self) -> None: + pass + + j = SomeJob(file_path) + assert j._file_name == file_name + assert j._file_path == file_path + + # providing only a filename is not allowed + with pytest.raises(AssertionError): + SomeJob(file_name) + + +def test_runnable_job_results() -> None: + file_path = "/table.1234.0.jsonl" + + class MockClient: + def prepare_load_job_execution(self, j: RunnableLoadJob): + pass + + class SuccessfulJob(RunnableLoadJob): + def run(self) -> None: + 5 + 5 + + j: RunnableLoadJob = SuccessfulJob(file_path) + assert j.state() == "ready" + j.run_managed(MockClient()) # type: ignore + assert j.state() == "completed" + + class RandomExceptionJob(RunnableLoadJob): + def run(self) -> None: + raise Exception("Oh no!") + + j = RandomExceptionJob(file_path) + assert j.state() == "ready" + j.run_managed(MockClient()) # type: ignore + assert j.state() == "retry" + assert j.exception() == "Oh no!" + + class TerminalJob(RunnableLoadJob): + def run(self) -> None: + raise DestinationTerminalException("Oh no!") + + j = TerminalJob(file_path) + assert j.state() == "ready" + j.run_managed(MockClient()) # type: ignore + assert j.state() == "failed" + assert j.exception() == "Oh no!" + + +def test_finalized_load_job() -> None: + file_name = "table.1234.0.jsonl" + file_path = "/path/" + file_name + j = FinalizedLoadJob(file_path) + assert j.state() == "completed" + assert not j.exception() + + j = FinalizedLoadJob(file_path, "failed", "oh no!") + assert j.state() == "failed" + assert j.exception() == "oh no!" + + # only actionable / terminal states are allowed + with pytest.raises(AssertionError): + FinalizedLoadJob(file_path, "ready") diff --git a/tests/load/test_parallelism_util.py b/tests/load/test_parallelism_util.py index b8f43d0743..3a7159563d 100644 --- a/tests/load/test_parallelism_util.py +++ b/tests/load/test_parallelism_util.py @@ -3,9 +3,9 @@ NOTE: there are tests in custom destination to check parallelism settings are applied """ -from typing import Tuple +from typing import Tuple, Any, cast -from dlt.load.utils import filter_new_jobs +from dlt.load.utils import filter_new_jobs, get_available_worker_slots from dlt.load.configuration import LoaderConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.utils import uniq_id @@ -21,24 +21,35 @@ def get_caps_conf() -> Tuple[DestinationCapabilitiesContext, LoaderConfiguration return DestinationCapabilitiesContext(), LoaderConfiguration() -def test_max_workers() -> None: - job_names = [create_job_name("t1", i) for i in range(100)] +def test_get_available_worker_slots() -> None: caps, conf = get_caps_conf() - # default is 20 - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + conf.workers = 20 + assert get_available_worker_slots(conf, caps, []) == 20 + + # change workers + conf.workers = 30 + assert get_available_worker_slots(conf, caps, []) == 30 + + # check with existing jobs + assert get_available_worker_slots(conf, caps, cast(Any, range(3))) == 27 + assert get_available_worker_slots(conf, caps, cast(Any, range(50))) == 0 + + # table-sequential will not change anything + caps.loader_parallelism_strategy = "table-sequential" + assert get_available_worker_slots(conf, caps, []) == 30 - # we can change it - conf.workers = 35 - assert len(filter_new_jobs(job_names, caps, conf)) == 35 + # caps with lower value will override + caps.max_parallel_load_jobs = 10 + assert get_available_worker_slots(conf, caps, []) == 10 - # destination may override this - caps.max_parallel_load_jobs = 15 - assert len(filter_new_jobs(job_names, caps, conf)) == 15 + # lower conf workers will override aing + conf.workers = 3 + assert get_available_worker_slots(conf, caps, []) == 3 - # lowest value will prevail - conf.workers = 5 - assert len(filter_new_jobs(job_names, caps, conf)) == 5 + # sequential strategy only allows one + caps.loader_parallelism_strategy = "sequential" + assert get_available_worker_slots(conf, caps, []) == 1 def test_table_sequential_parallelism_strategy() -> None: @@ -51,17 +62,16 @@ def test_table_sequential_parallelism_strategy() -> None: caps, conf = get_caps_conf() # default is 20 - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert len(filter_new_jobs(job_names, caps, conf, [], 20)) == 20 # table sequential will give us 8, one for each table conf.parallelism_strategy = "table-sequential" - filtered = filter_new_jobs(job_names, caps, conf) + filtered = filter_new_jobs(job_names, caps, conf, [], 20) assert len(filtered) == 8 assert len({ParsedLoadJobFileName.parse(j).table_name for j in job_names}) == 8 - # max workers also are still applied - conf.workers = 3 - assert len(filter_new_jobs(job_names, caps, conf)) == 3 + # only free available slots are also applied + assert len(filter_new_jobs(job_names, caps, conf, [], 3)) == 3 def test_strategy_preference() -> None: @@ -72,22 +82,37 @@ def test_strategy_preference() -> None: caps, conf = get_caps_conf() # nothing set will default to parallel - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 20 + ) caps.loader_parallelism_strategy = "table-sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 8 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 8 + ) caps.loader_parallelism_strategy = "sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 1 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 1 + ) # config may override (will go back to default 20) conf.parallelism_strategy = "parallel" - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 20 + ) conf.parallelism_strategy = "table-sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 8 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 8 + ) def test_no_input() -> None: caps, conf = get_caps_conf() - assert filter_new_jobs([], caps, conf) == [] + assert filter_new_jobs([], caps, conf, [], 50) == [] diff --git a/tests/load/utils.py b/tests/load/utils.py index 4b6c01c916..d649343c63 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -2,7 +2,20 @@ import contextlib import codecs import os -from typing import Any, Iterator, List, Sequence, IO, Tuple, Optional, Dict, Union, Generator, cast +from typing import ( + Any, + AnyStr, + Iterator, + List, + Sequence, + IO, + Tuple, + Optional, + Dict, + Union, + Generator, + cast, +) import shutil from pathlib import Path from urllib.parse import urlparse @@ -17,6 +30,7 @@ from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, JobClientBase, + RunnableLoadJob, LoadJob, DestinationClientStagingConfiguration, TDestinationReferenceArg, @@ -240,7 +254,8 @@ def destinations_configs( if destination not in ("athena", "synapse", "databricks", "dremio", "clickhouse") ] destination_configs += [ - DestinationTestConfiguration(destination="duckdb", file_format="parquet") + DestinationTestConfiguration(destination="duckdb", file_format="parquet"), + DestinationTestConfiguration(destination="motherduck", file_format="insert_values"), ] # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. destination_configs += [ @@ -682,23 +697,31 @@ def load_table(name: str) -> Dict[str, TTableSchemaColumns]: def expect_load_file( client: JobClientBase, file_storage: FileStorage, - query: str, + query: AnyStr, table_name: str, status="completed", + file_format: TLoaderFileFormat = None, ) -> LoadJob: file_name = ParsedLoadJobFileName( table_name, ParsedLoadJobFileName.new_file_id(), 0, - client.capabilities.preferred_loader_file_format, + file_format or client.capabilities.preferred_loader_file_format, ).file_name() - file_storage.save(file_name, query.encode("utf-8")) + if isinstance(query, str): + query = query.encode("utf-8") # type: ignore[assignment] + file_storage.save(file_name, query) table = client.prepare_load_table(table_name) - job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) + load_id = uniq_id() + job = client.create_load_job(table, file_storage.make_full_path(file_name), load_id) + + if isinstance(job, RunnableLoadJob): + job.set_run_vars(load_id=load_id, schema=client.schema, load_table=table) + job.run_managed(client) while job.state() == "running": sleep(0.5) assert job.file_name() == file_name - assert job.state() == status + assert job.state() == status, f"Got {job.state()} with ({job.exception()})" return job @@ -824,16 +847,15 @@ def write_dataset( f: IO[bytes], rows: Union[List[Dict[str, Any]], List[StrAny]], columns_schema: TTableSchemaColumns, + file_format: TLoaderFileFormat = None, ) -> None: spec = DataWriter.writer_spec_from_file_format( - client.capabilities.preferred_loader_file_format, "object" + file_format or client.capabilities.preferred_loader_file_format, "object" ) # adapt bytes stream to text file format if not spec.is_binary_format and isinstance(f.read(0), bytes): f = codecs.getwriter("utf-8")(f) # type: ignore[assignment] - writer = DataWriter.from_file_format( - client.capabilities.preferred_loader_file_format, "object", f, client.capabilities - ) + writer = DataWriter.from_file_format(spec.file_format, "object", f, client.capabilities) # remove None values for idx, row in enumerate(rows): rows[idx] = {k: v for k, v in row.items() if v is not None} @@ -842,18 +864,37 @@ def write_dataset( def prepare_load_package( - load_storage: LoadStorage, cases: Sequence[str], write_disposition: str = "append" + load_storage: LoadStorage, + cases: Sequence[str], + write_disposition: str = "append", + jobs_per_case: int = 1, ) -> Tuple[str, Schema]: + """ + Create a load package with explicitely provided files + job_per_case multiplies the amount of load jobs, for big packages use small files + """ load_id = uniq_id() load_storage.new_packages.create_package(load_id) for case in cases: path = f"./tests/load/cases/loading/{case}" - shutil.copy( - path, - load_storage.new_packages.storage.make_full_path( + for _ in range(jobs_per_case): + new_path = load_storage.new_packages.storage.make_full_path( load_storage.new_packages.get_job_state_folder_path(load_id, "new_jobs") - ), - ) + ) + shutil.copy( + path, + new_path, + ) + if jobs_per_case > 1: + parsed_name = ParsedLoadJobFileName.parse(case) + new_file_name = ParsedLoadJobFileName( + parsed_name.table_name, + ParsedLoadJobFileName.new_file_id(), + 0, + parsed_name.file_format, + ).file_name() + shutil.move(new_path + "/" + case, new_path + "/" + new_file_name) + schema_path = Path("./tests/load/cases/loading/schema.json") # load without migration data = json.loads(schema_path.read_text(encoding="utf8")) diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index dc2110d2f6..0a249db0fd 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -192,8 +192,8 @@ def test_load_case_sensitive_data(client: WeaviateClient, file_storage: FileStor write_dataset(client, f, [data_clash], table_create) query = f.getvalue().decode() class_name = client.schema.naming.normalize_table_identifier(class_name) - with pytest.raises(PropertyNameConflict): - expect_load_file(client, file_storage, query, class_name) + job = expect_load_file(client, file_storage, query, class_name, "failed") + assert type(job._exception) is PropertyNameConflict # type: ignore def test_load_case_sensitive_data_ci(ci_client: WeaviateClient, file_storage: FileStorage) -> None: diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 7c7dac8e71..0ab1f61d72 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -5,12 +5,14 @@ import logging import os import random +import shutil import threading from time import sleep from typing import Any, List, Tuple, cast from tenacity import retry_if_exception, Retrying, stop_after_attempt import pytest +from dlt.common.storages import FileStorage import dlt from dlt.common import json, pendulum @@ -1759,11 +1761,24 @@ def test_remove_pending_packages() -> None: assert pipeline.has_pending_data is False # partial load os.environ["EXCEPTION_PROB"] = "1.0" - os.environ["FAIL_IN_INIT"] = "False" os.environ["TIMEOUT"] = "1.0" - # should produce partial loads + # will make job go into retry state with pytest.raises(PipelineStepFailed): pipeline.run(airtable_emojis()) + # move job into completed folder manually to simulate partial package + load_storage = pipeline._get_load_storage() + load_id = load_storage.normalized_packages.list_packages()[0] + job = load_storage.normalized_packages.list_new_jobs(load_id)[0] + started_path = load_storage.normalized_packages.start_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) + completed_path = load_storage.normalized_packages.complete_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) + # to test partial loads we need two jobs one completed an one in another state + # to simulate this, we just duplicate the completed job into the started path + shutil.copyfile(completed_path, started_path) + # now "with partial loads" can be tested assert pipeline.has_pending_data pipeline.drop_pending_packages(with_partial_loads=False) assert pipeline.has_pending_data diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index bdb3e3eb22..3239e01bab 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -46,7 +46,7 @@ def inject_tomls( ): @dlt.resource(write_disposition="replace", primary_key="id") def data(): - yield [1, 2, 3] + yield [{"id": 1}, {"id": 2}, {"id": 3}] return data() @@ -362,12 +362,12 @@ def test_trace_telemetry() -> None: with patch("dlt.common.runtime.sentry.before_send", _mock_sentry_before_send), patch( "dlt.common.runtime.anon_tracker.before_send", _mock_anon_tracker_before_send ): - # os.environ["FAIL_PROB"] = "1.0" # make it complete immediately start_test_telemetry() ANON_TRACKER_SENT_ITEMS.clear() SENTRY_SENT_ITEMS.clear() - # default dummy fails all files + # make dummy fail all files + os.environ["FAIL_PROB"] = "1.0" load_info = dlt.pipeline().run( [1, 2, 3], table_name="data", destination="dummy", dataset_name="data_data" ) @@ -397,6 +397,11 @@ def test_trace_telemetry() -> None: # dummy has empty fingerprint assert event["properties"]["destination_fingerprint"] == "" # we have two failed files (state and data) that should be logged by sentry + # TODO: make this work + print(SENTRY_SENT_ITEMS) + for item in SENTRY_SENT_ITEMS: + # print(item) + print(item["logentry"]["message"]) assert len(SENTRY_SENT_ITEMS) == 2 # trace with exception diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index bd62f76dc1..dfdb9c8e40 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -172,16 +172,6 @@ def _load_file(client: FSClientBase, filepath) -> List[Dict[str, Any]]: # # Load table dicts # -def _get_delta_table(client: FilesystemClient, table_name: str) -> "DeltaTable": # type: ignore[name-defined] # noqa: F821 - from deltalake import DeltaTable - from dlt.common.libs.deltalake import _deltalake_storage_options - - table_dir = client.get_table_dir(table_name) - remote_table_dir = f"{client.config.protocol}://{table_dir}" - return DeltaTable( - remote_table_dir, - storage_options=_deltalake_storage_options(client.config), - ) def _load_tables_to_dicts_fs(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, Any]]]: @@ -189,13 +179,20 @@ def _load_tables_to_dicts_fs(p: dlt.Pipeline, *table_names: str) -> Dict[str, Li client = p._fs_client() result: Dict[str, Any] = {} + delta_table_names = [ + table_name + for table_name in table_names + if get_table_format(p.default_schema.tables, table_name) == "delta" + ] + if len(delta_table_names) > 0: + from dlt.common.libs.deltalake import get_delta_tables + + delta_tables = get_delta_tables(p, *table_names) + for table_name in table_names: - if ( - table_name in p.default_schema.data_table_names() - and get_table_format(p.default_schema.tables, table_name) == "delta" - ): + if table_name in p.default_schema.data_table_names() and table_name in delta_table_names: assert isinstance(client, FilesystemClient) - dt = _get_delta_table(client, table_name) + dt = delta_tables[table_name] result[table_name] = dt.to_pyarrow_table().to_pylist() else: table_files = client.list_table_files(table_name) diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index f5de1ec5da..af914bf89d 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -400,7 +400,7 @@ def test_paginate_json_body_without_params(self, rest_client) -> None: posts_skip = (DEFAULT_TOTAL_PAGES - 3) * DEFAULT_PAGE_SIZE class JSONBodyPageCursorPaginator(BaseReferencePaginator): - def update_state(self, response): + def update_state(self, response, data): self._next_reference = response.json().get("next_page") def update_request(self, request): diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index 8a3c136e09..39e3d767a0 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -1,3 +1,4 @@ +from typing import Any, List from unittest.mock import Mock import pytest @@ -16,6 +17,8 @@ from .conftest import assert_pagination +NON_EMPTY_PAGE = [{"some": "data"}] + @pytest.mark.usefixtures("mock_api_server") class TestHeaderLinkPaginator: @@ -241,7 +244,7 @@ class TestOffsetPaginator: def test_update_state(self): paginator = OffsetPaginator(offset=0, limit=10) response = Mock(Response, json=lambda: {"total": 20}) - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) assert paginator.current_value == 10 assert paginator.has_next_page is True @@ -252,7 +255,7 @@ def test_update_state(self): def test_update_state_with_string_total(self): paginator = OffsetPaginator(0, 10) response = Mock(Response, json=lambda: {"total": "20"}) - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) assert paginator.current_value == 10 assert paginator.has_next_page is True @@ -260,13 +263,13 @@ def test_update_state_with_invalid_total(self): paginator = OffsetPaginator(0, 10) response = Mock(Response, json=lambda: {"total": "invalid"}) with pytest.raises(ValueError): - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) def test_update_state_without_total(self): paginator = OffsetPaginator(0, 10) response = Mock(Response, json=lambda: {}) with pytest.raises(ValueError): - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) def test_init_request(self): paginator = OffsetPaginator(offset=123, limit=42) @@ -280,7 +283,7 @@ def test_init_request(self): response = Mock(Response, json=lambda: {"total": 200}) - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) # Test for the next request next_request = Mock(spec=Request) @@ -294,11 +297,11 @@ def test_init_request(self): def test_maximum_offset(self): paginator = OffsetPaginator(offset=0, limit=50, maximum_offset=100, total_path=None) response = Mock(Response, json=lambda: {"items": []}) - paginator.update_state(response) # Offset 0 to 50 + paginator.update_state(response, data=NON_EMPTY_PAGE) # Offset 0 to 50 assert paginator.current_value == 50 assert paginator.has_next_page is True - paginator.update_state(response) # Offset 50 to 100 + paginator.update_state(response, data=NON_EMPTY_PAGE) # Offset 50 to 100 assert paginator.current_value == 100 assert paginator.has_next_page is False @@ -312,28 +315,75 @@ def test_client_pagination(self, rest_client): assert_pagination(pages) + def test_stop_after_empty_page(self): + paginator = OffsetPaginator( + offset=0, + limit=50, + maximum_offset=100, + total_path=None, + stop_after_empty_page=True, + ) + response = Mock(Response, json=lambda: {"items": []}) + no_data_found: List[Any] = [] + paginator.update_state(response, no_data_found) # Page 1 + assert paginator.has_next_page is False + + def test_guarantee_termination(self): + OffsetPaginator( + limit=10, + total_path=None, + ) + + OffsetPaginator( + limit=10, + total_path=None, + maximum_offset=1, + stop_after_empty_page=False, + ) + + with pytest.raises(ValueError) as e: + OffsetPaginator( + limit=10, + total_path=None, + stop_after_empty_page=False, + ) + assert e.match( + "`total_path` or `maximum_offset` or `stop_after_empty_page` must be provided" + ) + + with pytest.raises(ValueError) as e: + OffsetPaginator( + limit=10, + total_path=None, + stop_after_empty_page=False, + maximum_offset=None, + ) + assert e.match( + "`total_path` or `maximum_offset` or `stop_after_empty_page` must be provided" + ) + @pytest.mark.usefixtures("mock_api_server") class TestPageNumberPaginator: def test_update_state(self): paginator = PageNumberPaginator(base_page=1, page=1, total_path="total_pages") response = Mock(Response, json=lambda: {"total_pages": 3}) - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) assert paginator.current_value == 2 assert paginator.has_next_page is True - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) assert paginator.current_value == 3 assert paginator.has_next_page is True # Test for reaching the end - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) assert paginator.has_next_page is False def test_update_state_with_string_total_pages(self): paginator = PageNumberPaginator(base_page=1, page=1) response = Mock(Response, json=lambda: {"total": "3"}) - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) assert paginator.current_value == 2 assert paginator.has_next_page is True @@ -341,37 +391,52 @@ def test_update_state_with_invalid_total_pages(self): paginator = PageNumberPaginator(base_page=1, page=1) response = Mock(Response, json=lambda: {"total_pages": "invalid"}) with pytest.raises(ValueError): - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) def test_update_state_without_total_pages(self): paginator = PageNumberPaginator(base_page=1, page=1) response = Mock(Response, json=lambda: {}) with pytest.raises(ValueError): - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) def test_update_request(self): paginator = PageNumberPaginator(base_page=1, page=1, page_param="page") request = Mock(Request) response = Mock(Response, json=lambda: {"total": 3}) - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) request.params = {} paginator.update_request(request) assert request.params["page"] == 2 - paginator.update_state(response) + paginator.update_state(response, data=NON_EMPTY_PAGE) paginator.update_request(request) assert request.params["page"] == 3 def test_maximum_page(self): paginator = PageNumberPaginator(base_page=1, page=1, maximum_page=3, total_path=None) response = Mock(Response, json=lambda: {"items": []}) - paginator.update_state(response) # Page 1 + paginator.update_state(response, data=NON_EMPTY_PAGE) # Page 1 assert paginator.current_value == 2 assert paginator.has_next_page is True - paginator.update_state(response) # Page 2 + paginator.update_state(response, data=NON_EMPTY_PAGE) # Page 2 assert paginator.current_value == 3 assert paginator.has_next_page is False + def test_stop_after_empty_page(self): + paginator = PageNumberPaginator( + base_page=1, + page=1, + maximum_page=5, + stop_after_empty_page=True, + total_path=None, + ) + response = Mock(Response, json=lambda: {"items": []}) + no_data_found: List[Any] = [] + assert paginator.has_next_page is True + paginator.update_state(response, no_data_found) + assert paginator.current_value == 1 + assert paginator.has_next_page is False + def test_client_pagination_one_based(self, rest_client): pages_iter = rest_client.paginate( "/posts", @@ -402,6 +467,32 @@ def test_client_pagination_zero_based(self, rest_client): assert_pagination(pages) + def test_guarantee_termination(self): + PageNumberPaginator( + total_path=None, + ) + + PageNumberPaginator( + total_path=None, + maximum_page=1, + stop_after_empty_page=False, + ) + + with pytest.raises(ValueError) as e: + PageNumberPaginator( + total_path=None, + stop_after_empty_page=False, + ) + assert e.match("`total_path` or `maximum_page` or `stop_after_empty_page` must be provided") + + with pytest.raises(ValueError) as e: + PageNumberPaginator( + total_path=None, + stop_after_empty_page=False, + maximum_page=None, + ) + assert e.match("`total_path` or `maximum_page` or `stop_after_empty_page` must be provided") + @pytest.mark.usefixtures("mock_api_server") class TestJSONResponseCursorPaginator: diff --git a/tests/utils.py b/tests/utils.py index bf3aafdb77..976a623c0b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -56,7 +56,6 @@ "filesystem", "weaviate", "dummy", - "motherduck", "qdrant", "lancedb", "destination",