diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index 84a8f95d71..a9306c2f9c 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -79,9 +79,6 @@ jobs: - name: Install dependencies run: poetry install --no-interaction -E redshift -E postgis -E postgres -E gs -E s3 -E az -E parquet -E duckdb -E cli -E filesystem --with sentry-sdk --with pipeline,ibis -E deltalake -E pyiceberg - - name: enable certificates for azure and duckdb - run: sudo mkdir -p /etc/pki/tls/certs && sudo ln -s /etc/ssl/certs/ca-certificates.crt /etc/pki/tls/certs/ca-bundle.crt - - name: Upgrade sqlalchemy run: poetry run pip install sqlalchemy==2.0.18 # minimum version required by `pyiceberg` diff --git a/dlt/cli/command_wrappers.py b/dlt/cli/command_wrappers.py index 0e6491688e..847b5daabb 100644 --- a/dlt/cli/command_wrappers.py +++ b/dlt/cli/command_wrappers.py @@ -43,14 +43,14 @@ def init_command_wrapper( destination_type: str, repo_location: str, branch: str, - omit_core_sources: bool = False, + eject_source: bool = False, ) -> None: init_command( source_name, destination_type, repo_location, branch, - omit_core_sources, + eject_source, ) diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index ac8adcc588..e81fa80c36 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -157,7 +157,7 @@ def _list_core_sources() -> Dict[str, SourceConfiguration]: sources: Dict[str, SourceConfiguration] = {} for source_name in files_ops.get_sources_names(core_sources_storage, source_type="core"): sources[source_name] = files_ops.get_core_source_configuration( - core_sources_storage, source_name + core_sources_storage, source_name, eject_source=False ) return sources @@ -295,7 +295,7 @@ def init_command( destination_type: str, repo_location: str, branch: str = None, - omit_core_sources: bool = False, + eject_source: bool = False, ) -> None: # try to import the destination and get config spec destination_reference = Destination.from_reference(destination_type) @@ -310,13 +310,9 @@ def init_command( # discover type of source source_type: files_ops.TSourceType = "template" - if ( - source_name in files_ops.get_sources_names(core_sources_storage, source_type="core") - ) and not omit_core_sources: + if source_name in files_ops.get_sources_names(core_sources_storage, source_type="core"): source_type = "core" else: - if omit_core_sources: - fmt.echo("Omitting dlt core sources.") verified_sources_storage = _clone_and_get_verified_sources_storage(repo_location, branch) if source_name in files_ops.get_sources_names( verified_sources_storage, source_type="verified" @@ -380,7 +376,7 @@ def init_command( else: if source_type == "core": source_configuration = files_ops.get_core_source_configuration( - core_sources_storage, source_name + core_sources_storage, source_name, eject_source ) from importlib.metadata import Distribution @@ -392,6 +388,9 @@ def init_command( if canonical_source_name in extras: source_configuration.requirements.update_dlt_extras(canonical_source_name) + + # create remote modified index to copy files when ejecting + remote_modified = {file_name: None for file_name in source_configuration.files} else: if not is_valid_schema_name(source_name): raise InvalidSchemaName(source_name) @@ -536,11 +535,17 @@ def init_command( "Creating a new pipeline with the dlt core source %s (%s)" % (fmt.bold(source_name), source_configuration.doc) ) - fmt.echo( - "NOTE: Beginning with dlt 1.0.0, the source %s will no longer be 
copied from the" - " verified sources repo but imported from dlt.sources. You can provide the" - " --omit-core-sources flag to revert to the old behavior." % (fmt.bold(source_name)) - ) + if eject_source: + fmt.echo( + "NOTE: Source code of %s will be ejected. Remember to modify the pipeline " + "example script to import the ejected source." % (fmt.bold(source_name)) + ) + else: + fmt.echo( + "NOTE: Beginning with dlt 1.0.0, the source %s will no longer be copied from" + " the verified sources repo but imported from dlt.sources. You can provide the" + " --eject flag to revert to the old behavior." % (fmt.bold(source_name)) + ) elif source_configuration.source_type == "verified": fmt.echo( "Creating and configuring a new pipeline with the verified source %s (%s)" diff --git a/dlt/cli/pipeline_files.py b/dlt/cli/pipeline_files.py index b6f8f85271..c0139fe2a7 100644 --- a/dlt/cli/pipeline_files.py +++ b/dlt/cli/pipeline_files.py @@ -226,11 +226,31 @@ def get_template_configuration( ) +def _get_source_files(sources_storage: FileStorage, source_name: str) -> List[str]: + """Get all files that belong to source `source_name`""" + files: List[str] = [] + for root, subdirs, _files in os.walk(sources_storage.make_full_path(source_name)): + # filter unwanted files + for subdir in list(subdirs): + if any(fnmatch.fnmatch(subdir, ignore) for ignore in IGNORE_FILES): + subdirs.remove(subdir) + rel_root = sources_storage.to_relative_path(root) + files.extend( + [ + os.path.join(rel_root, file) + for file in _files + if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES) + ] + ) + return files + + def get_core_source_configuration( - sources_storage: FileStorage, source_name: str + sources_storage: FileStorage, source_name: str, eject_source: bool ) -> SourceConfiguration: src_pipeline_file = CORE_SOURCE_TEMPLATE_MODULE_NAME + "/" + source_name + PIPELINE_FILE_SUFFIX dest_pipeline_file = source_name + PIPELINE_FILE_SUFFIX + files: List[str] = _get_source_files(sources_storage, source_name) if eject_source else [] return SourceConfiguration( "core", @@ -238,7 +258,7 @@ def get_core_source_configuration( sources_storage, src_pipeline_file, dest_pipeline_file, - [".gitignore"], + files, SourceRequirements([]), _get_docstring_for_module(sources_storage, source_name), False, @@ -259,21 +279,7 @@ def get_verified_source_configuration( f"Pipeline example script {example_script} could not be found in the repository", source_name, ) - # get all files recursively - files: List[str] = [] - for root, subdirs, _files in os.walk(sources_storage.make_full_path(source_name)): - # filter unwanted files - for subdir in list(subdirs): - if any(fnmatch.fnmatch(subdir, ignore) for ignore in IGNORE_FILES): - subdirs.remove(subdir) - rel_root = sources_storage.to_relative_path(root) - files.extend( - [ - os.path.join(rel_root, file) - for file in _files - if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES) - ] - ) + files = _get_source_files(sources_storage, source_name) # read requirements requirements_path = os.path.join(source_name, utils.REQUIREMENTS_TXT) if sources_storage.has_file(requirements_path): diff --git a/dlt/cli/plugins.py b/dlt/cli/plugins.py index cc2d4594b9..1712efbbd7 100644 --- a/dlt/cli/plugins.py +++ b/dlt/cli/plugins.py @@ -84,14 +84,10 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None: ) parser.add_argument( - "--omit-core-sources", + "--eject", default=False, action="store_true", - help=( - "When present, will not create the new pipeline with a core 
source of the given" - " name but will take a source of this name from the default or provided" - " location." - ), + help="Ejects the source code of the core source like sql_database", ) def execute(self, args: argparse.Namespace) -> None: @@ -107,7 +103,7 @@ def execute(self, args: argparse.Namespace) -> None: args.destination, args.location, args.branch, - args.omit_core_sources, + args.eject, ) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 038abdc4d0..4f9e0eb42e 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -457,16 +457,8 @@ def diff_table( * when columns with the same name have different data types * when table links to different parent tables """ - if tab_a["name"] != tab_b["name"]: - raise TablePropertiesConflictException( - schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"] - ) - table_name = tab_a["name"] - # check if table properties can be merged - if tab_a.get("parent") != tab_b.get("parent"): - raise TablePropertiesConflictException( - schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent") - ) + # allow for columns to differ + ensure_compatible_tables(schema_name, tab_a, tab_b, ensure_columns=False) # get new columns, changes in the column data type or other properties are not allowed tab_a_columns = tab_a["columns"] @@ -474,18 +466,6 @@ def diff_table( for col_b_name, col_b in tab_b["columns"].items(): if col_b_name in tab_a_columns: col_a = tab_a_columns[col_b_name] - # we do not support changing data types of columns - if is_complete_column(col_a) and is_complete_column(col_b): - if not compare_complete_columns(tab_a_columns[col_b_name], col_b): - # attempt to update to incompatible columns - raise CannotCoerceColumnException( - schema_name, - table_name, - col_b_name, - col_b["data_type"], - tab_a_columns[col_b_name]["data_type"], - None, - ) # all other properties can change merged_column = merge_column(copy(col_a), col_b) if merged_column != col_a: @@ -494,6 +474,8 @@ def diff_table( new_columns.append(col_b) # return partial table containing only name and properties that differ (column, filters etc.) + table_name = tab_a["name"] + partial_table: TPartialTableSchema = { "name": table_name, "columns": {} if new_columns is None else {c["name"]: c for c in new_columns}, @@ -519,6 +501,50 @@ def diff_table( return partial_table +def ensure_compatible_tables( + schema_name: str, tab_a: TTableSchema, tab_b: TPartialTableSchema, ensure_columns: bool = True +) -> None: + """Ensures that `tab_a` and `tab_b` can be merged without conflicts. 
Conflicts are detected when + + - tables have different names + - nested tables have different parents + - tables have any column with incompatible types + + Note: all the identifiers must be already normalized + + """ + if tab_a["name"] != tab_b["name"]: + raise TablePropertiesConflictException( + schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"] + ) + table_name = tab_a["name"] + # check if table properties can be merged + if tab_a.get("parent") != tab_b.get("parent"): + raise TablePropertiesConflictException( + schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent") + ) + + if not ensure_columns: + return + + tab_a_columns = tab_a["columns"] + for col_b_name, col_b in tab_b["columns"].items(): + if col_b_name in tab_a_columns: + col_a = tab_a_columns[col_b_name] + # we do not support changing data types of columns + if is_complete_column(col_a) and is_complete_column(col_b): + if not compare_complete_columns(tab_a_columns[col_b_name], col_b): + # attempt to update to incompatible columns + raise CannotCoerceColumnException( + schema_name, + table_name, + col_b_name, + col_b["data_type"], + tab_a_columns[col_b_name]["data_type"], + None, + ) + + # def compare_tables(tab_a: TTableSchema, tab_b: TTableSchema) -> bool: # try: # table_name = tab_a["name"] diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index f4d2b1f302..e832833428 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -3,7 +3,6 @@ from dlt.common.exceptions import DltException from dlt.common.utils import get_callable_name -from dlt.extract.items import ValidateItem, TDataItems class ExtractorException(DltException): diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 000e5c4cdb..22a0062acf 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -37,7 +37,8 @@ InconsistentTableTemplate, ) from dlt.extract.incremental import Incremental, TIncrementalConfig -from dlt.extract.items import TFunHintTemplate, TTableHintTemplate, TableNameMeta, ValidateItem +from dlt.extract.items import TFunHintTemplate, TTableHintTemplate, TableNameMeta +from dlt.extract.items_transform import ValidateItem from dlt.extract.utils import ensure_table_schema_columns, ensure_table_schema_columns_hint from dlt.extract.validation import create_item_validator diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index 5e7bae49c6..ce06292864 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -44,7 +44,8 @@ IncrementalArgs, TIncrementalRange, ) -from dlt.extract.items import SupportsPipe, TTableHintTemplate, ItemTransform +from dlt.extract.items import SupportsPipe, TTableHintTemplate +from dlt.extract.items_transform import ItemTransform from dlt.extract.incremental.transform import ( JsonIncremental, ArrowIncremental, diff --git a/dlt/extract/items.py b/dlt/extract/items.py index 888787e6b7..ad7447c163 100644 --- a/dlt/extract/items.py +++ b/dlt/extract/items.py @@ -1,21 +1,16 @@ -import inspect from abc import ABC, abstractmethod from typing import ( Any, Callable, - ClassVar, - Generic, Iterator, Iterable, Literal, Optional, Protocol, - TypeVar, Union, Awaitable, TYPE_CHECKING, NamedTuple, - Generator, ) from concurrent.futures import Future @@ -28,7 +23,6 @@ TDynHintType, ) - TDecompositionStrategy = Literal["none", "scc"] TDeferredDataItems = Callable[[], TDataItems] TAwaitableDataItems = Awaitable[TDataItems] @@ -113,6 +107,10 @@ def gen(self) -> TPipeStep: """A data generating 
step""" ... + def replace_gen(self, gen: TPipeStep) -> None: + """Replaces data generating step. Assumes that you know what are you doing""" + ... + def __getitem__(self, i: int) -> TPipeStep: """Get pipe step at index""" ... @@ -129,112 +127,3 @@ def has_parent(self) -> bool: def close(self) -> None: """Closes pipe generator""" ... - - -ItemTransformFunctionWithMeta = Callable[[TDataItem, str], TAny] -ItemTransformFunctionNoMeta = Callable[[TDataItem], TAny] -ItemTransformFunc = Union[ItemTransformFunctionWithMeta[TAny], ItemTransformFunctionNoMeta[TAny]] - - -class ItemTransform(ABC, Generic[TAny]): - _f_meta: ItemTransformFunctionWithMeta[TAny] = None - _f: ItemTransformFunctionNoMeta[TAny] = None - - placement_affinity: ClassVar[float] = 0 - """Tell how strongly an item sticks to start (-1) or end (+1) of pipe.""" - - def __init__(self, transform_f: ItemTransformFunc[TAny]) -> None: - # inspect the signature - sig = inspect.signature(transform_f) - # TODO: use TypeGuard here to get rid of type ignore - if len(sig.parameters) == 1: - self._f = transform_f # type: ignore - else: # TODO: do better check - self._f_meta = transform_f # type: ignore - - def bind(self: "ItemTransform[TAny]", pipe: SupportsPipe) -> "ItemTransform[TAny]": - return self - - @abstractmethod - def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: - """Transforms `item` (a list of TDataItem or a single TDataItem) and returns or yields TDataItems. Returns None to consume item (filter out)""" - pass - - -class FilterItem(ItemTransform[bool]): - # mypy needs those to type correctly - _f_meta: ItemTransformFunctionWithMeta[bool] - _f: ItemTransformFunctionNoMeta[bool] - - def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: - if isinstance(item, list): - # preserve empty lists - if len(item) == 0: - return item - - if self._f_meta: - item = [i for i in item if self._f_meta(i, meta)] - else: - item = [i for i in item if self._f(i)] - if not item: - # item was fully consumed by the filter - return None - return item - else: - if self._f_meta: - return item if self._f_meta(item, meta) else None - else: - return item if self._f(item) else None - - -class MapItem(ItemTransform[TDataItem]): - # mypy needs those to type correctly - _f_meta: ItemTransformFunctionWithMeta[TDataItem] - _f: ItemTransformFunctionNoMeta[TDataItem] - - def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: - if isinstance(item, list): - if self._f_meta: - return [self._f_meta(i, meta) for i in item] - else: - return [self._f(i) for i in item] - else: - if self._f_meta: - return self._f_meta(item, meta) - else: - return self._f(item) - - -class YieldMapItem(ItemTransform[Iterator[TDataItem]]): - # mypy needs those to type correctly - _f_meta: ItemTransformFunctionWithMeta[TDataItem] - _f: ItemTransformFunctionNoMeta[TDataItem] - - def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: - if isinstance(item, list): - for i in item: - if self._f_meta: - yield from self._f_meta(i, meta) - else: - yield from self._f(i) - else: - if self._f_meta: - yield from self._f_meta(item, meta) - else: - yield from self._f(item) - - -class ValidateItem(ItemTransform[TDataItem]): - """Base class for validators of data items. - - Subclass should implement the `__call__` method to either return the data item(s) or raise `extract.exceptions.ValidationError`. - See `PydanticValidator` for possible implementation. 
- """ - - placement_affinity: ClassVar[float] = 0.9 # stick to end but less than incremental - - table_name: str - - def bind(self, pipe: SupportsPipe) -> ItemTransform[TDataItem]: - self.table_name = pipe.name - return self diff --git a/dlt/extract/items_transform.py b/dlt/extract/items_transform.py new file mode 100644 index 0000000000..12375640bc --- /dev/null +++ b/dlt/extract/items_transform.py @@ -0,0 +1,179 @@ +import inspect +import time + +from abc import ABC, abstractmethod +from typing import ( + Any, + Callable, + ClassVar, + Generic, + Iterator, + Optional, + Union, +) +from concurrent.futures import Future + +from dlt.common.typing import ( + TAny, + TDataItem, + TDataItems, +) + +from dlt.extract.utils import ( + wrap_iterator, +) + +from dlt.extract.items import SupportsPipe + + +ItemTransformFunctionWithMeta = Callable[[TDataItem, str], TAny] +ItemTransformFunctionNoMeta = Callable[[TDataItem], TAny] +ItemTransformFunc = Union[ItemTransformFunctionWithMeta[TAny], ItemTransformFunctionNoMeta[TAny]] + + +class ItemTransform(ABC, Generic[TAny]): + _f_meta: ItemTransformFunctionWithMeta[TAny] = None + _f: ItemTransformFunctionNoMeta[TAny] = None + + placement_affinity: ClassVar[float] = 0 + """Tell how strongly an item sticks to start (-1) or end (+1) of pipe.""" + + def __init__(self, transform_f: ItemTransformFunc[TAny]) -> None: + # inspect the signature + sig = inspect.signature(transform_f) + # TODO: use TypeGuard here to get rid of type ignore + if len(sig.parameters) == 1: + self._f = transform_f # type: ignore + else: # TODO: do better check + self._f_meta = transform_f # type: ignore + + def bind(self: "ItemTransform[TAny]", pipe: SupportsPipe) -> "ItemTransform[TAny]": + return self + + @abstractmethod + def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: + """Transforms `item` (a list of TDataItem or a single TDataItem) and returns or yields TDataItems. 
Returns None to consume item (filter out)""" + pass + + +class FilterItem(ItemTransform[bool]): + # mypy needs those to type correctly + _f_meta: ItemTransformFunctionWithMeta[bool] + _f: ItemTransformFunctionNoMeta[bool] + + def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: + if isinstance(item, list): + # preserve empty lists + if len(item) == 0: + return item + + if self._f_meta: + item = [i for i in item if self._f_meta(i, meta)] + else: + item = [i for i in item if self._f(i)] + if not item: + # item was fully consumed by the filter + return None + return item + else: + if self._f_meta: + return item if self._f_meta(item, meta) else None + else: + return item if self._f(item) else None + + +class MapItem(ItemTransform[TDataItem]): + # mypy needs those to type correctly + _f_meta: ItemTransformFunctionWithMeta[TDataItem] + _f: ItemTransformFunctionNoMeta[TDataItem] + + def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: + if isinstance(item, list): + if self._f_meta: + return [self._f_meta(i, meta) for i in item] + else: + return [self._f(i) for i in item] + else: + if self._f_meta: + return self._f_meta(item, meta) + else: + return self._f(item) + + +class YieldMapItem(ItemTransform[Iterator[TDataItem]]): + # mypy needs those to type correctly + _f_meta: ItemTransformFunctionWithMeta[TDataItem] + _f: ItemTransformFunctionNoMeta[TDataItem] + + def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: + if isinstance(item, list): + for i in item: + if self._f_meta: + yield from self._f_meta(i, meta) + else: + yield from self._f(i) + else: + if self._f_meta: + yield from self._f_meta(item, meta) + else: + yield from self._f(item) + + +class ValidateItem(ItemTransform[TDataItem]): + """Base class for validators of data items. + + Subclass should implement the `__call__` method to either return the data item(s) or raise `extract.exceptions.ValidationError`. + See `PydanticValidator` for possible implementation. 
+ """ + + placement_affinity: ClassVar[float] = 0.9 # stick to end but less than incremental + + table_name: str + + def bind(self, pipe: SupportsPipe) -> ItemTransform[TDataItem]: + self.table_name = pipe.name + return self + + +class LimitItem(ItemTransform[TDataItem]): + placement_affinity: ClassVar[float] = 1.1 # stick to end right behind incremental + + def __init__(self, max_items: Optional[int], max_time: Optional[float]) -> None: + self.max_items = max_items if max_items is not None else -1 + self.max_time = max_time + + def bind(self, pipe: SupportsPipe) -> "LimitItem": + # we also wrap iterators to make them stoppable + if isinstance(pipe.gen, Iterator): + pipe.replace_gen(wrap_iterator(pipe.gen)) + + self.gen = pipe.gen + self.count = 0 + self.exhausted = False + self.start_time = time.time() + + return self + + def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: + self.count += 1 + + # detect when the limit is reached, max time or yield count + if ( + (self.count == self.max_items) + or (self.max_time and time.time() - self.start_time > self.max_time) + or self.max_items == 0 + ): + self.exhausted = True + if inspect.isgenerator(self.gen): + self.gen.close() + + # if max items is not 0, we return the last item + # otherwise never return anything + if self.max_items != 0: + return item + + # do not return any late arriving items + if self.exhausted: + return None + + return item diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 02b52c4623..e70365b4f4 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -27,12 +27,12 @@ UnclosablePipe, ) from dlt.extract.items import ( - ItemTransform, ResolvablePipeItem, SupportsPipe, TPipeStep, TPipedDataItems, ) +from dlt.extract.items_transform import ItemTransform from dlt.extract.utils import ( check_compat_transformer, simulate_func_call, @@ -122,7 +122,23 @@ def steps(self) -> List[TPipeStep]: def find(self, *step_type: AnyType) -> int: """Finds a step with object of type `step_type`""" - return next((i for i, v in enumerate(self._steps) if isinstance(v, step_type)), -1) + found = self.find_all(step_type) + return found[0] if found else -1 + + def find_all(self, *step_type: AnyType) -> List[int]: + """Finds all steps with object of type `step_type`""" + return [i for i, v in enumerate(self._steps) if isinstance(v, step_type)] + + def get_by_type(self, *step_type: AnyType) -> TPipeStep: + """Gets first step found with object of type `step_type`""" + return next((v for v in self._steps if isinstance(v, step_type)), None) + + def remove_by_type(self, *step_type: AnyType) -> int: + """Deletes first step found with object of type `step_type`, returns previous index""" + step_index = self.find(*step_type) + if step_index >= 0: + self.remove_step(step_index) + return step_index def __getitem__(self, i: int) -> TPipeStep: return self._steps[i] diff --git a/dlt/extract/pipe_iterator.py b/dlt/extract/pipe_iterator.py index 465040f9f4..38641c0626 100644 --- a/dlt/extract/pipe_iterator.py +++ b/dlt/extract/pipe_iterator.py @@ -24,7 +24,11 @@ ) from dlt.common.configuration.container import Container from dlt.common.exceptions import PipelineException -from dlt.common.pipeline import unset_current_pipe_name, set_current_pipe_name +from dlt.common.pipeline import ( + unset_current_pipe_name, + set_current_pipe_name, + get_current_pipe_name, +) from dlt.common.utils import get_callable_name from dlt.extract.exceptions import ( @@ -180,7 +184,6 @@ def __next__(self) -> PipeItem: item = pipe_item.item # if 
item is iterator, then add it as a new source if isinstance(item, Iterator): - # print(f"adding iterable {item}") self._sources.append( SourcePipeItem(item, pipe_item.step, pipe_item.pipe, pipe_item.meta) ) @@ -291,7 +294,6 @@ def _get_source_item(self) -> ResolvablePipeItem: first_evaluated_index = self._current_source_index # always go round robin if None was returned or item is to be run as future self._current_source_index = (self._current_source_index - 1) % sources_count - except StopIteration: # remove empty iterator and try another source self._sources.pop(self._current_source_index) diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index 42e3905162..366e6e1a88 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -2,7 +2,7 @@ from functools import partial from typing import ( AsyncIterable, - AsyncIterator, + cast, ClassVar, Callable, Iterable, @@ -34,13 +34,16 @@ from dlt.extract.items import ( DataItemWithMeta, - ItemTransformFunc, - ItemTransformFunctionWithMeta, TableNameMeta, +) +from dlt.extract.items_transform import ( FilterItem, MapItem, YieldMapItem, ValidateItem, + LimitItem, + ItemTransformFunc, + ItemTransformFunctionWithMeta, ) from dlt.extract.pipe_iterator import ManagedPipeIterator from dlt.extract.pipe import Pipe, TPipeStep @@ -214,29 +217,22 @@ def requires_args(self) -> bool: return True @property - def incremental(self) -> IncrementalResourceWrapper: + def incremental(self) -> Optional[IncrementalResourceWrapper]: """Gets incremental transform if it is in the pipe""" - incremental: IncrementalResourceWrapper = None - step_no = self._pipe.find(IncrementalResourceWrapper, Incremental) - if step_no >= 0: - incremental = self._pipe.steps[step_no] # type: ignore - return incremental + return cast( + Optional[IncrementalResourceWrapper], + self._pipe.get_by_type(IncrementalResourceWrapper, Incremental), + ) @property def validator(self) -> Optional[ValidateItem]: """Gets validator transform if it is in the pipe""" - validator: ValidateItem = None - step_no = self._pipe.find(ValidateItem) - if step_no >= 0: - validator = self._pipe.steps[step_no] # type: ignore[assignment] - return validator + return cast(Optional[ValidateItem], self._pipe.get_by_type(ValidateItem)) @validator.setter def validator(self, validator: Optional[ValidateItem]) -> None: """Add/remove or replace the validator in pipe""" - step_no = self._pipe.find(ValidateItem) - if step_no >= 0: - self._pipe.remove_step(step_no) + step_no = self._pipe.remove_by_type(ValidateItem) if validator: self.add_step(validator, insert_at=step_no if step_no >= 0 else None) @@ -347,72 +343,37 @@ def add_filter( self._pipe.insert_step(FilterItem(item_filter), insert_at) return self - def add_limit(self: TDltResourceImpl, max_items: int) -> TDltResourceImpl: # noqa: A003 + def add_limit( + self: TDltResourceImpl, + max_items: Optional[int] = None, + max_time: Optional[float] = None, + ) -> TDltResourceImpl: # noqa: A003 """Adds a limit `max_items` to the resource pipe. - This mutates the encapsulated generator to stop after `max_items` items are yielded. This is useful for testing and debugging. + This mutates the encapsulated generator to stop after `max_items` items are yielded. This is useful for testing and debugging. - Notes: - 1. Transformers won't be limited. They should process all the data they receive fully to avoid inconsistencies in generated datasets. - 2. Each yielded item may contain several records. 
`add_limit` only limits the "number of yields", not the total number of records. - 3. Async resources with a limit added may occasionally produce one item more than the limit on some runs. This behavior is not deterministic. + Notes: + 1. Transformers won't be limited. They should process all the data they receive fully to avoid inconsistencies in generated datasets. + 2. Each yielded item may contain several records. `add_limit` only limits the "number of yields", not the total number of records. + 3. Async resources with a limit added may occasionally produce one item more than the limit on some runs. This behavior is not deterministic. Args: - max_items (int): The maximum number of items to yield - Returns: - "DltResource": returns self + max_items (int): The maximum number of items to yield, set to None for no limit + max_time (float): The maximum number of seconds for this generator to run after it was opened, set to None for no limit + Returns: + "DltResource": returns self """ - # make sure max_items is a number, to allow "None" as value for unlimited - if max_items is None: - max_items = -1 - - def _gen_wrap(gen: TPipeStep) -> TPipeStep: - """Wrap a generator to take the first `max_items` records""" - - # zero items should produce empty generator - if max_items == 0: - return - - count = 0 - is_async_gen = False - if callable(gen): - gen = gen() # type: ignore - - # wrap async gen already here - if isinstance(gen, AsyncIterator): - gen = wrap_async_iterator(gen) - is_async_gen = True - - try: - for i in gen: # type: ignore # TODO: help me fix this later - yield i - if i is not None: - count += 1 - # async gen yields awaitable so we must count one awaitable more - # so the previous one is evaluated and yielded. - # new awaitable will be cancelled - if count == max_items + int(is_async_gen): - return - finally: - if inspect.isgenerator(gen): - gen.close() - return - - # transformers should be limited by their input, so we only limit non-transformers - if not self.is_transformer: - gen = self._pipe.gen - # wrap gen directly - if inspect.isgenerator(gen): - self._pipe.replace_gen(_gen_wrap(gen)) - else: - # keep function as function to not evaluate generators before pipe starts - self._pipe.replace_gen(partial(_gen_wrap, gen)) - else: + if self.is_transformer: logger.warning( f"Setting add_limit to a transformer {self.name} has no effect. Set the limit on" " the top level resource." 
) + else: + # remove existing limit if any + self._pipe.remove_by_type(LimitItem) + self.add_step(LimitItem(max_items=max_items, max_time=max_time)) + return self def parallelize(self: TDltResourceImpl) -> TDltResourceImpl: @@ -445,9 +406,7 @@ def add_step( return self def _remove_incremental_step(self) -> None: - step_no = self._pipe.find(Incremental, IncrementalResourceWrapper) - if step_no >= 0: - self._pipe.remove_step(step_no) + self._pipe.remove_by_type(Incremental, IncrementalResourceWrapper) def set_incremental( self, diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py index 68570d0995..0bcd13155e 100644 --- a/dlt/extract/utils.py +++ b/dlt/extract/utils.py @@ -183,6 +183,17 @@ def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> in return meta_arg +def wrap_iterator(gen: Iterator[TDataItems]) -> Iterator[TDataItems]: + """Wraps an iterator into a generator""" + if inspect.isgenerator(gen): + return gen + + def wrapped_gen() -> Iterator[TDataItems]: + yield from gen + + return wrapped_gen() + + def wrap_async_iterator( gen: AsyncIterator[TDataItems], ) -> Generator[Awaitable[TDataItems], None, None]: diff --git a/dlt/extract/validation.py b/dlt/extract/validation.py index 4cd321b88c..d9fe70a90b 100644 --- a/dlt/extract/validation.py +++ b/dlt/extract/validation.py @@ -8,7 +8,8 @@ from dlt.common.typing import TDataItems from dlt.common.schema.typing import TAnySchemaColumns, TSchemaContract, TSchemaEvolutionMode -from dlt.extract.items import TTableHintTemplate, ValidateItem +from dlt.extract.items import TTableHintTemplate +from dlt.extract.items_transform import ValidateItem _TPydanticModel = TypeVar("_TPydanticModel", bound=PydanticBaseModel) diff --git a/dlt/helpers/ibis.py b/dlt/helpers/ibis.py index ed4264dac7..e15bb9bc16 100644 --- a/dlt/helpers/ibis.py +++ b/dlt/helpers/ibis.py @@ -10,7 +10,7 @@ import sqlglot from ibis import BaseBackend, Expr except ModuleNotFoundError: - raise MissingDependencyException("dlt ibis Helpers", ["ibis"]) + raise MissingDependencyException("dlt ibis helpers", ["ibis-framework"]) SUPPORTED_DESTINATIONS = [ @@ -123,18 +123,22 @@ def create_ibis_backend( ) from dlt.destinations.impl.duckdb.factory import DuckDbCredentials - # we create an in memory duckdb and create all tables on there - duck = duckdb.connect(":memory:") + # we create an in memory duckdb and create the ibis backend from it fs_client = cast(FilesystemClient, client) - creds = DuckDbCredentials(duck) sql_client = FilesystemSqlClient( - fs_client, dataset_name=fs_client.dataset_name, credentials=creds + fs_client, + dataset_name=fs_client.dataset_name, + credentials=DuckDbCredentials(duckdb.connect()), ) - + # do not use context manager to not return and close the cloned connection + duckdb_conn = sql_client.open_connection() + # make all tables available here # NOTE: we should probably have the option for the user to only select a subset of tables here - with sql_client as _: - sql_client.create_views_for_all_tables() - con = ibis.duckdb.from_connection(duck) + sql_client.create_views_for_all_tables() + # why this works now: whenever a clone of connection is made, all SET commands + # apply only to it. 
old code was setting `curl` on the internal clone of sql_client + # now we export this clone directly to ibis to it works + con = ibis.duckdb.from_connection(duckdb_conn) return con diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 32db5034b4..1d81d70b10 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -20,7 +20,7 @@ LoadStorage, ParsedLoadJobFileName, ) -from dlt.common.schema import TSchemaUpdate, Schema +from dlt.common.schema import Schema from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.pipeline import ( NormalizeInfo, @@ -34,7 +34,7 @@ from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV -from dlt.normalize.validate import verify_normalized_table +from dlt.normalize.validate import validate_and_update_schema, verify_normalized_table # normalize worker wrapping function signature @@ -80,16 +80,6 @@ def create_storages(self) -> None: config=self.config._load_storage_config, ) - def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: - for schema_update in schema_updates: - for table_name, table_updates in schema_update.items(): - logger.info( - f"Updating schema for table {table_name} with {len(table_updates)} deltas" - ) - for partial_table in table_updates: - # merge columns where we expect identifiers to be normalized - schema.update_table(partial_table, normalize_identifiers=False) - def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: workers: int = getattr(self.pool, "_max_workers", 1) chunk_files = group_worker_files(files, workers) @@ -123,7 +113,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW result: TWorkerRV = pending.result() try: # gather schema from all manifests, validate consistency and combine - self.update_schema(schema, result[0]) + validate_and_update_schema(schema, result[0]) summary.schema_updates.extend(result.schema_updates) summary.file_metrics.extend(result.file_metrics) # update metrics @@ -162,7 +152,7 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWor load_id, files, ) - self.update_schema(schema, result.schema_updates) + validate_and_update_schema(schema, result.schema_updates) self.collector.update("Files", len(result.file_metrics)) self.collector.update( "Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count @@ -237,23 +227,11 @@ def spool_schema_files(self, load_id: str, schema: Schema, files: Sequence[str]) self.load_storage.import_extracted_package( load_id, self.normalize_storage.extracted_packages ) - logger.info(f"Created new load package {load_id} on loading volume") - try: - # process parallel - self.spool_files( - load_id, schema.clone(update_normalizers=True), self.map_parallel, files - ) - except CannotCoerceColumnException as exc: - # schema conflicts resulting from parallel executing - logger.warning( - f"Parallel schema update conflict, switching to single thread ({str(exc)}" - ) - # start from scratch - self.load_storage.new_packages.delete_package(load_id) - self.load_storage.import_extracted_package( - load_id, self.normalize_storage.extracted_packages - ) - self.spool_files(load_id, schema.clone(update_normalizers=True), self.map_single, files) + logger.info(f"Created new load package {load_id} on loading volume with ") + # get number of workers with 
default == 1 if not set (ie. NullExecutor) + workers: int = getattr(self.pool, "_max_workers", 1) + map_f: TMapFuncType = self.map_parallel if workers > 1 else self.map_single + self.spool_files(load_id, schema.clone(update_normalizers=True), map_f, files) return load_id diff --git a/dlt/normalize/validate.py b/dlt/normalize/validate.py index 648deb5da9..868ba3115b 100644 --- a/dlt/normalize/validate.py +++ b/dlt/normalize/validate.py @@ -1,7 +1,10 @@ +from typing import List + from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import Schema -from dlt.common.schema.typing import TTableSchema +from dlt.common.schema.typing import TTableSchema, TSchemaUpdate from dlt.common.schema.utils import ( + ensure_compatible_tables, find_incomplete_columns, get_first_column_name_with_prop, is_nested_table, @@ -10,6 +13,21 @@ from dlt.common import logger +def validate_and_update_schema(schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: + """Updates `schema` tables with partial tables in `schema_updates`""" + for schema_update in schema_updates: + for table_name, table_updates in schema_update.items(): + logger.info(f"Updating schema for table {table_name} with {len(table_updates)} deltas") + for partial_table in table_updates: + # ensure updates will pass + if existing_table := schema.tables.get(partial_table["name"]): + ensure_compatible_tables(schema.name, existing_table, partial_table) + + for partial_table in table_updates: + # merge columns where we expect identifiers to be normalized + schema.update_table(partial_table, normalize_identifiers=False) + + def verify_normalized_table( schema: Schema, table: TTableSchema, capabilities: DestinationCapabilitiesContext ) -> None: diff --git a/dlt/sources/helpers/transform.py b/dlt/sources/helpers/transform.py index 32843e2aa2..45738fe4fb 100644 --- a/dlt/sources/helpers/transform.py +++ b/dlt/sources/helpers/transform.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Sequence, Union from dlt.common.typing import TDataItem -from dlt.extract.items import ItemTransformFunctionNoMeta +from dlt.extract.items_transform import ItemTransformFunctionNoMeta import jsonpath_ng diff --git a/docs/examples/backfill_in_chunks/__init__.py b/docs/examples/backfill_in_chunks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/backfill_in_chunks/backfill_in_chunks.py b/docs/examples/backfill_in_chunks/backfill_in_chunks.py new file mode 100644 index 0000000000..a758d67f7b --- /dev/null +++ b/docs/examples/backfill_in_chunks/backfill_in_chunks.py @@ -0,0 +1,85 @@ +""" +--- +title: Backfilling in chunks +description: Learn how to backfill in chunks of defined size +keywords: [incremental loading, backfilling, chunks,example] +--- + +In this example, you'll find a Python script that will load from a sql_database source in chunks of defined size. This is useful for backfilling in multiple pipeline runs as +opposed to backfilling in one very large pipeline run which may fail due to memory issues on ephemeral storage or just take a very long time to complete without seeing any +progress in the destination. 
+ +We'll learn how to: + +- Connect to a mysql database with the sql_database source +- Select one table to load and apply incremental loading hints as well as the primary key +- Set the chunk size and limit the number of chunks to load in one pipeline run +- Create a pipeline and backfill the table in the defined chunks +- Use the datasets accessor to inspect and assert the load progress + +""" + +import pandas as pd + +import dlt +from dlt.sources.sql_database import sql_database + + +if __name__ == "__main__": + # NOTE: this is a live table in the rfam database, so the number of final rows may change + TOTAL_TABLE_ROWS = 4178 + RFAM_CONNECTION_STRING = "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam" + + # create sql database source that only loads the family table in chunks of 1000 rows + source = sql_database(RFAM_CONNECTION_STRING, table_names=["family"], chunk_size=1000) + + # we apply some hints to the table, we know the rfam_id is unique and that we can order + # and load incrementally on the created datetime column + source.family.apply_hints( + primary_key="rfam_id", + incremental=dlt.sources.incremental( + cursor_path="created", initial_value=None, row_order="asc" + ), + ) + + # with limit we can limit the number of chunks to load, with a chunk size of 1000 and a limit of 1 + # we will load 1000 rows per pipeline run + source.add_limit(1) + + # create pipeline + pipeline = dlt.pipeline( + pipeline_name="rfam", destination="duckdb", dataset_name="rfam_data", dev_mode=True + ) + + def _assert_unique_row_count(df: pd.DataFrame, num_rows: int) -> None: + """Assert that a dataframe has the correct number of unique rows""" + # NOTE: this check is dependent on reading the full table back from the destination into memory, + # so it is only useful for testing before you do a large backfill. + assert len(df) == num_rows + assert len(set(df.rfam_id.tolist())) == num_rows + + # after the first run, the family table in the destination should contain the first 1000 rows + pipeline.run(source) + _assert_unique_row_count(pipeline.dataset().family.df(), 1000) + + # after the second run, the family table in the destination should contain 1999 rows + # there is some overlap on the incremental to prevent skipping rows + pipeline.run(source) + _assert_unique_row_count(pipeline.dataset().family.df(), 1999) + + # ... + pipeline.run(source) + _assert_unique_row_count(pipeline.dataset().family.df(), 2998) + + # ... + pipeline.run(source) + _assert_unique_row_count(pipeline.dataset().family.df(), 3997) + + # the final run will load all the rows until the end of the table + pipeline.run(source) + _assert_unique_row_count(pipeline.dataset().family.df(), TOTAL_TABLE_ROWS) + + # NOTE: in a production environment you will likely: + # * be using much larger chunk sizes and limits + # * run the pipeline in a loop to load all the rows + # * and programmatically check if the table is fully loaded and abort the loop if this is the case. diff --git a/docs/website/docs/general-usage/resource.md b/docs/website/docs/general-usage/resource.md index 199eaf9b5d..b8d51caf75 100644 --- a/docs/website/docs/general-usage/resource.md +++ b/docs/website/docs/general-usage/resource.md @@ -405,11 +405,26 @@ dlt.pipeline(destination="duckdb").run(my_resource().add_limit(10)) The code above will extract `15*10=150` records. This is happening because in each iteration, 15 records are yielded, and we're limiting the number of iterations to 10. 
 :::
-Some constraints of `add_limit` include:
+Alternatively, you can also apply a time limit to the resource. The code below will run the extraction for 10 seconds and extract however many items are yielded in that time. In combination with incrementals, this can be useful for batched loading or for loading on machines that have a runtime limit.
+
+```py
+dlt.pipeline(destination="duckdb").run(my_resource().add_limit(max_time=10))
+```
+
+You can also apply a combination of both limits. In this case, the extraction will stop as soon as either limit is reached.
+
+```py
+dlt.pipeline(destination="duckdb").run(my_resource().add_limit(max_items=10, max_time=10))
+```
+
+
+Some notes about `add_limit`:
 1. `add_limit` does not skip any items. It closes the iterator/generator that produces data after the limit is reached.
 2. You cannot limit transformers. They should process all the data they receive fully to avoid inconsistencies in generated datasets.
 3. Async resources with a limit added may occasionally produce one item more than the limit on some runs. This behavior is not deterministic.
+4. Calling `add_limit` on a resource will replace any previously set limits.
+5. For time-limited resources, the timer starts when the first item is processed. When resources are processed sequentially (FIFO mode), each resource's time limit also applies sequentially. In the default round-robin mode, the time limits will usually run concurrently.
 :::tip
 If you are parameterizing the value of `add_limit` and sometimes need it to be disabled, you can set `None` or `-1` to disable the limiting.
diff --git a/docs/website/docs/general-usage/source.md b/docs/website/docs/general-usage/source.md
index 87c07a3e44..9c6c2aac13 100644
--- a/docs/website/docs/general-usage/source.md
+++ b/docs/website/docs/general-usage/source.md
@@ -107,8 +107,20 @@ load_info = pipeline.run(pipedrive_source().add_limit(10))
 print(load_info)
 ```
+You can also apply a time limit to the source:
+
+```py
+pipeline.run(pipedrive_source().add_limit(max_time=10))
+```
+
+Or limit by both; whichever limit is reached first will stop the extraction:
+
+```py
+pipeline.run(pipedrive_source().add_limit(max_items=10, max_time=10))
+```
+
 :::note
-Note that `add_limit` **does not limit the number of records** but rather the "number of yields". `dlt` will close the iterator/generator that produces data after the limit is reached.
+Note that `add_limit` **does not limit the number of records** but rather the "number of yields". `dlt` will close the iterator/generator that produces data after the limit is reached. Read more about `add_limit` on the [resource page](resource.md).
 :::
 Find more on sampling data [here](resource.md#sample-from-large-data).
diff --git a/docs/website/docs/reference/command-line-interface.md b/docs/website/docs/reference/command-line-interface.md
index 825d33d548..2af750f43c 100644
--- a/docs/website/docs/reference/command-line-interface.md
+++ b/docs/website/docs/reference/command-line-interface.md
@@ -20,9 +20,22 @@ This command creates a new dlt pipeline script that loads data from `source` to
 This command can be used several times in the same folder to add more sources, destinations, and pipelines. It will also update the verified source code to the newest version if run again with an existing `source` name. You are warned if files will be overwritten or if the `dlt` version needs an upgrade to run a particular pipeline.
+### Ejecting the source code of core sources like `sql_database`
+We merged a few sources to the core library. You can still eject source code and hack them with the `--eject` flag: +```sh +dlt init sql_database duckdb --eject +``` +will copy the source code of `sql_database` to your project. Remember to modify the pipeline example script to import from the local folder! + ### Specify your own "verified sources" repository You can use the `--location ` option to specify your own repository with sources. Typically, you would [fork ours](https://github.com/dlt-hub/verified-sources) and start customizing and adding sources, e.g., to use them for your team or organization. You can also specify a branch with `--branch `, e.g., to test a version being developed. +### Using dlt 0.5.x sources +Use `--branch 0.5` if you are still on `dlt` `0.5.x` ie. +```sh +dlt init --branch 0.5 +``` + ### List all sources ```sh dlt init --list-sources diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index 8e1affd164..d81cd8c858 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -262,6 +262,21 @@ def test_init_all_sources_isolated(cloned_init_repo: FileStorage) -> None: assert_index_version_constraint(files, candidate) +def test_init_core_sources_ejected(cloned_init_repo: FileStorage) -> None: + repo_dir = get_repo_dir(cloned_init_repo) + # ensure we test both sources form verified sources and core sources + source_candidates = set(CORE_SOURCES) + for candidate in source_candidates: + clean_test_storage() + repo_dir = get_repo_dir(cloned_init_repo) + files = get_project_files(clear_all_sources=False) + with set_working_dir(files.storage_path): + init_command.init_command(candidate, "bigquery", repo_dir, eject_source=True) + assert_requirements_txt(files, "bigquery") + # check if files copied + assert files.has_folder(candidate) + + @pytest.mark.parametrize("destination_name", IMPLEMENTED_DESTINATIONS) def test_init_all_destinations( destination_name: str, project_files: FileStorage, repo_dir: str @@ -279,25 +294,6 @@ def test_custom_destination_note(repo_dir: str, project_files: FileStorage): assert "to add a destination function that will consume your data" in _out -@pytest.mark.parametrize("omit", [True, False]) -# this will break if we have new core sources that are not in verified sources anymore -@pytest.mark.parametrize("source", set(CORE_SOURCES) - {"rest_api"}) -def test_omit_core_sources( - source: str, omit: bool, project_files: FileStorage, repo_dir: str -) -> None: - with io.StringIO() as buf, contextlib.redirect_stdout(buf): - init_command.init_command(source, "destination", repo_dir, omit_core_sources=omit) - _out = buf.getvalue() - - # check messaging - assert ("Omitting dlt core sources" in _out) == omit - assert ("will no longer be copied from the" in _out) == (not omit) - - # if we omit core sources, there will be a folder with the name of the source from the verified sources repo - assert project_files.has_folder(source) == omit - assert (f"dlt.sources.{source}" in project_files.load(f"{source}_pipeline.py")) == (not omit) - - def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) -> None: sources_storage = FileStorage(os.path.join(repo_dir, SOURCES_MODULE_NAME)) new_content = '"""New docstrings"""' diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index 7f06cdb71e..adbb34b1f0 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -441,27 +441,6 @@ def 
test_update_schema_table_prop_conflict(schema: Schema) -> None: assert exc_val.value.val2 == "tab_parent" -def test_update_schema_column_conflict(schema: Schema) -> None: - tab1 = utils.new_table( - "tab1", - write_disposition="append", - columns=[ - {"name": "col1", "data_type": "text", "nullable": False}, - ], - ) - schema.update_table(tab1) - tab1_u1 = deepcopy(tab1) - # simulate column that had other datatype inferred - tab1_u1["columns"]["col1"]["data_type"] = "bool" - with pytest.raises(CannotCoerceColumnException) as exc_val: - schema.update_table(tab1_u1) - assert exc_val.value.column_name == "col1" - assert exc_val.value.from_type == "bool" - assert exc_val.value.to_type == "text" - # whole column mismatch - assert exc_val.value.coerced_value is None - - def _add_preferred_types(schema: Schema) -> None: schema._settings["preferred_types"] = {} schema._settings["preferred_types"][TSimpleRegex("timestamp")] = "timestamp" diff --git a/tests/common/schema/test_merges.py b/tests/common/schema/test_merges.py index 8e0c350e7c..b76fe944b5 100644 --- a/tests/common/schema/test_merges.py +++ b/tests/common/schema/test_merges.py @@ -353,7 +353,7 @@ def test_diff_tables() -> None: assert "test" in partial["columns"] -def test_diff_tables_conflicts() -> None: +def test_tables_conflicts() -> None: # conflict on parents table: TTableSchema = { # type: ignore[typeddict-unknown-key] "name": "table", @@ -366,6 +366,8 @@ def test_diff_tables_conflicts() -> None: other = utils.new_table("table") with pytest.raises(TablePropertiesConflictException) as cf_ex: utils.diff_table("schema", table, other) + with pytest.raises(TablePropertiesConflictException) as cf_ex: + utils.ensure_compatible_tables("schema", table, other) assert cf_ex.value.table_name == "table" assert cf_ex.value.prop_name == "parent" @@ -373,6 +375,8 @@ def test_diff_tables_conflicts() -> None: other = utils.new_table("other_name") with pytest.raises(TablePropertiesConflictException) as cf_ex: utils.diff_table("schema", table, other) + with pytest.raises(TablePropertiesConflictException) as cf_ex: + utils.ensure_compatible_tables("schema", table, other) assert cf_ex.value.table_name == "table" assert cf_ex.value.prop_name == "name" @@ -380,7 +384,10 @@ def test_diff_tables_conflicts() -> None: changed = deepcopy(table) changed["columns"]["test"]["data_type"] = "bigint" with pytest.raises(CannotCoerceColumnException): - utils.diff_table("schema", table, changed) + utils.ensure_compatible_tables("schema", table, changed) + # but diff now accepts different data types + merged_table = utils.diff_table("schema", table, changed) + assert merged_table["columns"]["test"]["data_type"] == "bigint" def test_merge_tables() -> None: diff --git a/tests/extract/test_extract_pipe.py b/tests/extract/test_extract_pipe.py index d40639a594..659888269a 100644 --- a/tests/extract/test_extract_pipe.py +++ b/tests/extract/test_extract_pipe.py @@ -10,7 +10,8 @@ from dlt.common import sleep from dlt.common.typing import TDataItems from dlt.extract.exceptions import CreatePipeException, ResourceExtractionError, UnclosablePipe -from dlt.extract.items import DataItemWithMeta, FilterItem, MapItem, YieldMapItem +from dlt.extract.items import DataItemWithMeta +from dlt.extract.items_transform import FilterItem, MapItem, YieldMapItem from dlt.extract.pipe import Pipe from dlt.extract.pipe_iterator import PipeIterator, ManagedPipeIterator, PipeItem diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index d63dac93f2..9ad7d28e88 100644 --- 
a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -7,6 +7,7 @@ from time import sleep from typing import Any, Optional, Literal, Sequence, Dict, Iterable from unittest import mock +import itertools import duckdb import pyarrow as pa @@ -35,7 +36,7 @@ IncrementalPrimaryKeyMissing, ) from dlt.extract.incremental.lag import apply_lag -from dlt.extract.items import ValidateItem +from dlt.extract.items_transform import ValidateItem from dlt.extract.resource import DltResource from dlt.pipeline.exceptions import PipelineStepFailed from dlt.sources.helpers.transform import take_first @@ -3960,3 +3961,58 @@ def some_data( # Includes values 5-10 inclusive assert items == expected_items + + +@pytest.mark.parametrize("offset_by_last_value", [True, False]) +def test_incremental_and_limit(offset_by_last_value: bool): + resource_called = 0 + + # here we check incremental and limit when incremental once when last value cannot be used + # to offset the source, and once when it can. + + @dlt.resource( + table_name="items", + ) + def resource( + incremental=dlt.sources.incremental(cursor_path="id", initial_value=-1, row_order="asc") + ): + range_iterator = ( + range(incremental.start_value + 1, 1000) if offset_by_last_value else range(1000) + ) + for i in range_iterator: + nonlocal resource_called + resource_called += 1 + yield { + "id": i, + "value": str(i), + } + + resource.add_limit(10) + + p = dlt.pipeline(pipeline_name="incremental_limit", destination="duckdb", dev_mode=True) + + p.run(resource()) + + # check we have the right number of items + assert len(p.dataset().items.df()) == 10 + assert resource_called == 10 + # check that we have items 0-9 + assert p.dataset().items.df().id.tolist() == list(range(10)) + + # run the next ten + p.run(resource()) + + # check we have the right number of items + assert len(p.dataset().items.df()) == 20 + assert resource_called == 20 if offset_by_last_value else 30 + # check that we have items 0-19 + assert p.dataset().items.df().id.tolist() == list(range(20)) + + # run the next batch + p.run(resource()) + + # check we have the right number of items + assert len(p.dataset().items.df()) == 30 + assert resource_called == 30 if offset_by_last_value else 60 + # check that we have items 0-29 + assert p.dataset().items.df().id.tolist() == list(range(30)) diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index 3d021d5d10..86646e6369 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -1,4 +1,6 @@ import itertools +import time + from typing import Iterator import pytest @@ -837,7 +839,7 @@ def test_limit_infinite_counter() -> None: @pytest.mark.parametrize("limit", (None, -1, 0, 10)) def test_limit_edge_cases(limit: int) -> None: - r = dlt.resource(range(20), name="infinity").add_limit(limit) # type: ignore + r = dlt.resource(range(20), name="resource").add_limit(limit) # type: ignore @dlt.resource() async def r_async(): @@ -845,22 +847,62 @@ async def r_async(): await asyncio.sleep(0.01) yield i + @dlt.resource(parallelized=True) + def parallelized_resource(): + for i in range(20): + yield i + sync_list = list(r) async_list = list(r_async().add_limit(limit)) + parallelized_list = list(parallelized_resource().add_limit(limit)) + + # all lists should be the same + assert sync_list == async_list == parallelized_list if limit == 10: assert sync_list == list(range(10)) - # we have edge cases where the async list will have one extra item - # possibly due to timing issues, maybe some other 
implementation problem - assert (async_list == list(range(10))) or (async_list == list(range(11))) elif limit in [None, -1]: - assert sync_list == async_list == list(range(20)) + assert sync_list == list(range(20)) elif limit == 0: - assert sync_list == async_list == [] + assert sync_list == [] else: raise AssertionError(f"Unexpected limit: {limit}") +def test_various_limit_setups() -> None: + # basic test + r = dlt.resource([1, 2, 3, 4, 5], name="test").add_limit(3) + assert list(r) == [1, 2, 3] + + # yield map test + r = ( + dlt.resource([1, 2, 3, 4, 5], name="test") + .add_map(lambda i: str(i) * i, 1) + .add_yield_map(lambda i: (yield from i)) + .add_limit(3) + ) + # limit is applied at the end + assert list(r) == ["1", "2", "2"] # "3" ,"3" ,"3" ,"4" ,"4" ,"4" ,"4", ...] + + # nested lists test (limit only applied to yields, not actual items) + r = dlt.resource([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], name="test").add_limit(3) + assert list(r) == [1, 2, 3, 4, 5, 6, 7, 8, 9] + + # transformer test + r = dlt.resource([1, 2, 3, 4, 5], name="test").add_limit(4) + t = dlt.transformer(lambda i: i * 2, name="test") + assert list(r) == [1, 2, 3, 4] + assert list(r | t) == [2, 4, 6, 8] + + # adding limit to transformer is disregarded + t = t.add_limit(2) + assert list(r | t) == [2, 4, 6, 8] + + # limits are fully replaced (more genereous limit applied later takes precedence) + r = dlt.resource([1, 2, 3, 4, 5], name="test").add_limit(3).add_limit(4) + assert list(r) == [1, 2, 3, 4] + + def test_limit_source() -> None: def mul_c(item): yield from "A" * (item + 2) @@ -876,6 +918,30 @@ def infinite_source(): assert list(infinite_source().add_limit(2)) == ["A", "A", 0, "A", "A", "A", 1] * 3 +def test_limit_max_time() -> None: + @dlt.resource() + def r(): + for i in range(100): + time.sleep(0.1) + yield i + + @dlt.resource() + async def r_async(): + for i in range(100): + await asyncio.sleep(0.1) + yield i + + sync_list = list(r().add_limit(max_time=1)) + async_list = list(r_async().add_limit(max_time=1)) + + # we should have extracted 10 items within 1 second, sleep is included in the resource + # we allow for some variance in the number of items, as the sleep is not super precise + # on mac os we even sometimes just get 4 items... 
+ allowed_results = [list(range(i)) for i in [12, 11, 10, 9, 8, 7, 6, 5, 4]] + assert sync_list in allowed_results + assert async_list in allowed_results + + def test_source_state() -> None: @dlt.source def test_source(expected_state): diff --git a/tests/extract/test_validation.py b/tests/extract/test_validation.py index 138589bb06..3800f333f6 100644 --- a/tests/extract/test_validation.py +++ b/tests/extract/test_validation.py @@ -10,7 +10,7 @@ from dlt.common.libs.pydantic import BaseModel from dlt.extract import DltResource -from dlt.extract.items import ValidateItem +from dlt.extract.items_transform import ValidateItem from dlt.extract.validation import PydanticValidator from dlt.extract.exceptions import ResourceExtractionError from dlt.pipeline.exceptions import PipelineStepFailed diff --git a/tests/extract/utils.py b/tests/extract/utils.py index 7364ef7243..f1de3de093 100644 --- a/tests/extract/utils.py +++ b/tests/extract/utils.py @@ -6,7 +6,7 @@ from dlt.common.typing import TDataItem, TDataItems from dlt.extract.extract import ExtractStorage -from dlt.extract.items import ItemTransform +from dlt.extract.items_transform import ItemTransform from tests.utils import TestDataItemFormat diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 9190225a8c..b998b78471 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -10,7 +10,6 @@ from dlt.common.pipeline import SupportsPipeline from dlt.common.destination import Destination from dlt.common.destination.reference import WithStagingDataset -from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.schema.utils import new_table diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index 29ad21941e..e09582f8a8 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -127,177 +127,6 @@ def test_pipeline_explicit_destination_credentials( ) -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_pipeline_with_sources_sharing_schema( - destination_config: DestinationTestConfiguration, -) -> None: - schema = Schema("shared") - - @dlt.source(schema=schema, max_table_nesting=1) - def source_1(): - @dlt.resource(primary_key="user_id") - def gen1(): - dlt.current.source_state()["source_1"] = True - dlt.current.resource_state()["source_1"] = True - yield {"id": "Y", "user_id": "user_y"} - - @dlt.resource(columns={"col": {"data_type": "bigint"}}) - def conflict(): - yield "conflict" - - return gen1, conflict - - @dlt.source(schema=schema, max_table_nesting=2) - def source_2(): - @dlt.resource(primary_key="id") - def gen1(): - dlt.current.source_state()["source_2"] = True - dlt.current.resource_state()["source_2"] = True - yield {"id": "X", "user_id": "user_X"} - - def gen2(): - yield from "CDE" - - @dlt.resource(columns={"col": {"data_type": "bool"}}, selected=False) - def conflict(): - yield "conflict" - - return gen2, gen1, conflict - - # all selected tables with hints should be there - discover_1 = source_1().discover_schema() - assert "gen1" in discover_1.tables - assert discover_1.tables["gen1"]["columns"]["user_id"]["primary_key"] is True - assert "data_type" not in 
discover_1.tables["gen1"]["columns"]["user_id"] - assert "conflict" in discover_1.tables - assert discover_1.tables["conflict"]["columns"]["col"]["data_type"] == "bigint" - - discover_2 = source_2().discover_schema() - assert "gen1" in discover_2.tables - assert "gen2" in discover_2.tables - # conflict deselected - assert "conflict" not in discover_2.tables - - p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) - p.extract([source_1(), source_2()], table_format=destination_config.table_format) - default_schema = p.default_schema - gen1_table = default_schema.tables["gen1"] - assert "user_id" in gen1_table["columns"] - assert "id" in gen1_table["columns"] - assert "conflict" in default_schema.tables - assert "gen2" in default_schema.tables - p.normalize(loader_file_format=destination_config.file_format) - assert "gen2" in default_schema.tables - p.load() - table_names = [t["name"] for t in default_schema.data_tables()] - counts = load_table_counts(p, *table_names) - assert counts == {"gen1": 2, "gen2": 3, "conflict": 1} - # both sources share the same state - assert p.state["sources"] == { - "shared": { - "source_1": True, - "resources": {"gen1": {"source_1": True, "source_2": True}}, - "source_2": True, - } - } - drop_active_pipeline_data() - - # same pipeline but enable conflict - p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) - with pytest.raises(PipelineStepFailed) as py_ex: - p.extract([source_1(), source_2().with_resources("conflict")]) - assert isinstance(py_ex.value.__context__, CannotCoerceColumnException) - - -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_many_pipelines_single_dataset(destination_config: DestinationTestConfiguration) -> None: - schema = Schema("shared") - - @dlt.source(schema=schema, max_table_nesting=1) - def source_1(): - @dlt.resource(primary_key="user_id") - def gen1(): - dlt.current.source_state()["source_1"] = True - dlt.current.resource_state()["source_1"] = True - yield {"id": "Y", "user_id": "user_y"} - - return gen1 - - @dlt.source(schema=schema, max_table_nesting=2) - def source_2(): - @dlt.resource(primary_key="id") - def gen1(): - dlt.current.source_state()["source_2"] = True - dlt.current.resource_state()["source_2"] = True - yield {"id": "X", "user_id": "user_X"} - - def gen2(): - yield from "CDE" - - return gen2, gen1 - - # load source_1 to common dataset - p = dlt.pipeline( - pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" - ) - p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") - counts = load_table_counts(p, *p.default_schema.tables.keys()) - assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() - p._wipe_working_folder() - p.deactivate() - - p = dlt.pipeline( - pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" - ) - p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") - # table_names = [t["name"] for t in p.default_schema.data_tables()] - counts = load_table_counts(p, *p.default_schema.tables.keys()) - # gen1: one record comes from source_1, 1 record from source_2 - assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() - # assert counts == {'gen1': 2, 'gen2': 3} - p._wipe_working_folder() - p.deactivate() - - # restore from destination, 
check state - p = dlt.pipeline( - pipeline_name="source_1_pipeline", - destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), - dataset_name="shared_dataset", - ) - p.sync_destination() - # we have our separate state - assert p.state["sources"]["shared"] == { - "source_1": True, - "resources": {"gen1": {"source_1": True}}, - } - # but the schema was common so we have the earliest one - assert "gen2" in p.default_schema.tables - p._wipe_working_folder() - p.deactivate() - - p = dlt.pipeline( - pipeline_name="source_2_pipeline", - destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), - dataset_name="shared_dataset", - ) - p.sync_destination() - # we have our separate state - assert p.state["sources"]["shared"] == { - "source_2": True, - "resources": {"gen1": {"source_2": True}}, - } - - # TODO: uncomment and finalize when we implement encoding for psycopg2 # @pytest.mark.parametrize( # "destination_config", diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 7463184be7..84e22af9ff 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -1,3 +1,4 @@ +from copy import deepcopy import pytest from fnmatch import fnmatch from typing import Dict, Iterator, List, Sequence, Tuple @@ -5,6 +6,7 @@ from dlt.common import json from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.schema import Schema from dlt.common.schema.utils import new_table from dlt.common.storages.exceptions import SchemaNotFoundError @@ -16,6 +18,7 @@ from dlt.extract.extract import ExtractStorage from dlt.normalize import Normalize +from dlt.normalize.validate import validate_and_update_schema from dlt.normalize.worker import group_worker_files from dlt.normalize.exceptions import NormalizeJobFailed @@ -284,6 +287,8 @@ def test_multiprocessing_row_counting( extract_cases(raw_normalize, ["github.events.load_page_1_duck"]) # use real process pool in tests with ProcessPoolExecutor(max_workers=4) as p: + # test if we get correct number of workers + assert getattr(p, "_max_workers", None) == 4 raw_normalize.run(p) # get step info step_info = raw_normalize.get_step_info(MockPipeline("multiprocessing_pipeline", True)) # type: ignore[abstract] @@ -712,6 +717,71 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType) assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type +def test_update_schema_column_conflict(rasa_normalize: Normalize) -> None: + extract_cases( + rasa_normalize, + [ + "event.event.many_load_2", + "event.event.user_load_1", + ], + ) + extract_cases( + rasa_normalize, + [ + "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2", + ], + ) + # use real process pool in tests + with ProcessPoolExecutor(max_workers=4) as p: + rasa_normalize.run(p) + + schema = rasa_normalize.schema_storage.load_schema("event") + tab1 = new_table( + "event_user", + write_disposition="append", + columns=[ + {"name": "col1", "data_type": "text", "nullable": False}, + ], + ) + validate_and_update_schema(schema, [{"event_user": [deepcopy(tab1)]}]) + assert schema.tables["event_user"]["columns"]["col1"]["data_type"] == "text" + + tab1["columns"]["col1"]["data_type"] = "bool" + tab1["columns"]["col2"] = {"name": "col2", "data_type": "text", "nullable": False} + with pytest.raises(CannotCoerceColumnException) as exc_val: + 
validate_and_update_schema(schema, [{"event_user": [deepcopy(tab1)]}]) + assert exc_val.value.column_name == "col1" + assert exc_val.value.from_type == "bool" + assert exc_val.value.to_type == "text" + # whole column mismatch + assert exc_val.value.coerced_value is None + # make sure col2 is not added + assert "col2" not in schema.tables["event_user"]["columns"] + + # add two updates that are conflicting + tab2 = new_table( + "event_slot", + write_disposition="append", + columns=[ + {"name": "col1", "data_type": "text", "nullable": False}, + {"name": "col2", "data_type": "text", "nullable": False}, + ], + ) + tab3 = new_table( + "event_slot", + write_disposition="append", + columns=[ + {"name": "col1", "data_type": "bool", "nullable": False}, + ], + ) + with pytest.raises(CannotCoerceColumnException) as exc_val: + validate_and_update_schema( + schema, [{"event_slot": [deepcopy(tab2)]}, {"event_slot": [deepcopy(tab3)]}] + ) + # col2 is added from first update + assert "col2" in schema.tables["event_slot"]["columns"] + + def test_removal_of_normalizer_schema_section_and_add_seen_data(raw_normalize: Normalize) -> None: extract_cases( raw_normalize, diff --git a/tests/pipeline/test_import_export_schema.py b/tests/pipeline/test_import_export_schema.py index eb36d36ba3..5eb9c664d0 100644 --- a/tests/pipeline/test_import_export_schema.py +++ b/tests/pipeline/test_import_export_schema.py @@ -1,4 +1,4 @@ -import dlt, os, pytest +import dlt, os from dlt.common.utils import uniq_id @@ -6,8 +6,6 @@ from tests.utils import TEST_STORAGE_ROOT from dlt.common.schema import Schema from dlt.common.storages.schema_storage import SchemaStorage -from dlt.common.schema.exceptions import CannotCoerceColumnException -from dlt.pipeline.exceptions import PipelineStepFailed from dlt.destinations import dummy diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 2d72e23462..95d464d48a 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -52,7 +52,7 @@ from dlt.pipeline.pipeline import Pipeline from tests.common.utils import TEST_SENTRY_DSN -from tests.utils import TEST_STORAGE_ROOT +from tests.utils import TEST_STORAGE_ROOT, load_table_counts from tests.extract.utils import expect_extracted_file from tests.pipeline.utils import ( assert_data_table_counts, @@ -3011,3 +3011,171 @@ def test_push_table_with_upfront_schema() -> None: copy_pipeline = dlt.pipeline(pipeline_name="push_table_copy_pipeline", destination="duckdb") info = copy_pipeline.run(data, table_name="events", schema=copy_schema) assert copy_pipeline.default_schema.version_hash != infer_hash + + +def test_pipeline_with_sources_sharing_schema() -> None: + schema = Schema("shared") + + @dlt.source(schema=schema, max_table_nesting=1) + def source_1(): + @dlt.resource(primary_key="user_id") + def gen1(): + dlt.current.source_state()["source_1"] = True + dlt.current.resource_state()["source_1"] = True + yield {"id": "Y", "user_id": "user_y"} + + @dlt.resource(columns={"value": {"data_type": "bool"}}) + def conflict(): + yield True + + return gen1, conflict + + @dlt.source(schema=schema, max_table_nesting=2) + def source_2(): + @dlt.resource(primary_key="id") + def gen1(): + dlt.current.source_state()["source_2"] = True + dlt.current.resource_state()["source_2"] = True + yield {"id": "X", "user_id": "user_X"} + + def gen2(): + yield from "CDE" + + @dlt.resource(columns={"value": {"data_type": "text"}}, selected=False) + def conflict(): + yield "indeed" + + return gen2, gen1, conflict + + # all 
selected tables with hints should be there + discover_1 = source_1().discover_schema() + assert "gen1" in discover_1.tables + assert discover_1.tables["gen1"]["columns"]["user_id"]["primary_key"] is True + assert "data_type" not in discover_1.tables["gen1"]["columns"]["user_id"] + assert "conflict" in discover_1.tables + assert discover_1.tables["conflict"]["columns"]["value"]["data_type"] == "bool" + + discover_2 = source_2().discover_schema() + assert "gen1" in discover_2.tables + assert "gen2" in discover_2.tables + # conflict deselected + assert "conflict" not in discover_2.tables + + p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) + p.extract([source_1(), source_2()]) + default_schema = p.default_schema + gen1_table = default_schema.tables["gen1"] + assert "user_id" in gen1_table["columns"] + assert "id" in gen1_table["columns"] + assert "conflict" in default_schema.tables + assert "gen2" in default_schema.tables + p.normalize() + assert "gen2" in default_schema.tables + assert default_schema.tables["conflict"]["columns"]["value"]["data_type"] == "bool" + p.load() + table_names = [t["name"] for t in default_schema.data_tables()] + counts = load_table_counts(p, *table_names) + assert counts == {"gen1": 2, "gen2": 3, "conflict": 1} + # both sources share the same state + assert p.state["sources"] == { + "shared": { + "source_1": True, + "resources": {"gen1": {"source_1": True, "source_2": True}}, + "source_2": True, + } + } + + # same pipeline but enable conflict + p.extract([source_2().with_resources("conflict")]) + p.normalize() + assert default_schema.tables["conflict"]["columns"]["value"]["data_type"] == "text" + with pytest.raises(PipelineStepFailed): + # will generate failed job on type that does not match + p.load() + counts = load_table_counts(p, "conflict") + assert counts == {"conflict": 1} + + # alter table in duckdb + with p.sql_client() as client: + client.execute_sql("ALTER TABLE conflict ALTER value TYPE VARCHAR;") + p.run([source_2().with_resources("conflict")]) + counts = load_table_counts(p, "conflict") + assert counts == {"conflict": 2} + + +def test_many_pipelines_single_dataset() -> None: + schema = Schema("shared") + + @dlt.source(schema=schema, max_table_nesting=1) + def source_1(): + @dlt.resource(primary_key="user_id") + def gen1(): + dlt.current.source_state()["source_1"] = True + dlt.current.resource_state()["source_1"] = True + yield {"id": "Y", "user_id": "user_y"} + + return gen1 + + @dlt.source(schema=schema, max_table_nesting=2) + def source_2(): + @dlt.resource(primary_key="id") + def gen1(): + dlt.current.source_state()["source_2"] = True + dlt.current.resource_state()["source_2"] = True + yield {"id": "X", "user_id": "user_X"} + + def gen2(): + yield from "CDE" + + return gen2, gen1 + + # load source_1 to common dataset + p = dlt.pipeline( + pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) + p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") + counts = load_table_counts(p, *p.default_schema.tables.keys()) + assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() + p._wipe_working_folder() + p.deactivate() + + p = dlt.pipeline( + pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) + p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") + # table_names = [t["name"] for t in p.default_schema.data_tables()] + counts = load_table_counts(p, *p.default_schema.tables.keys()) + # gen1: 
one record comes from source_1, 1 record from source_2 + assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() + # assert counts == {'gen1': 2, 'gen2': 3} + p._wipe_working_folder() + p.deactivate() + + # restore from destination, check state + p = dlt.pipeline( + pipeline_name="source_1_pipeline", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), + dataset_name="shared_dataset", + ) + p.sync_destination() + # we have our separate state + assert p.state["sources"]["shared"] == { + "source_1": True, + "resources": {"gen1": {"source_1": True}}, + } + # but the schema was common so we have the earliest one + assert "gen2" in p.default_schema.tables + p._wipe_working_folder() + p.deactivate() + + p = dlt.pipeline( + pipeline_name="source_2_pipeline", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), + dataset_name="shared_dataset", + ) + p.sync_destination() + # we have our separate state + assert p.state["sources"]["shared"] == { + "source_2": True, + "resources": {"gen1": {"source_2": True}}, + }
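
For quick reference, below is a minimal sketch (not part of the diff) of the `add_limit` behaviors asserted by the new tests above: a count-based limit, a later `add_limit` fully replacing an earlier one, and the `max_time` option. The `numbers` resource and the sleep interval are illustrative assumptions.

```python
import time

import dlt


@dlt.resource(name="numbers")
def numbers():
    for i in range(100):
        time.sleep(0.01)  # pacing only, so max_time has something to cut off
        yield i


# count-based limit: only the first three items are yielded
assert list(numbers().add_limit(3)) == [0, 1, 2]

# a later add_limit() replaces the earlier one entirely (see test_various_limit_setups)
assert list(numbers().add_limit(3).add_limit(5)) == [0, 1, 2, 3, 4]

# time-based limit: extraction stops after roughly 0.1s; the exact item count varies
items = list(numbers().add_limit(max_time=0.1))
assert 0 < len(items) < 100
```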
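The new `test_incremental_and_limit` combines `dlt.sources.incremental` with `add_limit` so that every run picks up the next batch. A condensed usage sketch of that pairing, assuming the duckdb destination is available; the resource, table, and pipeline names are illustrative.

```python
import dlt


@dlt.resource(table_name="items")
def items(
    incremental=dlt.sources.incremental(cursor_path="id", initial_value=-1, row_order="asc")
):
    # resume right after the last loaded id
    for i in range(incremental.start_value + 1, 1000):
        yield {"id": i, "value": str(i)}


# cap every extraction at 10 items
items.add_limit(10)

p = dlt.pipeline(pipeline_name="incremental_limit_demo", destination="duckdb", dev_mode=True)
p.run(items())  # loads ids 0-9
p.run(items())  # loads ids 10-19
assert p.dataset().items.df().id.tolist() == list(range(20))
```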
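On the schema side, the tests above pin down the new split between `diff_table`, which now tolerates a changed column data type and carries the new type in the returned partial table, and `ensure_compatible_tables` / `validate_and_update_schema`, which still raise `CannotCoerceColumnException`. A minimal sketch of that contract, based on the behavior asserted in `test_tables_conflicts`; the table and column names are illustrative.

```python
from copy import deepcopy

from dlt.common.schema import utils
from dlt.common.schema.exceptions import CannotCoerceColumnException

table = utils.new_table(
    "items", columns=[{"name": "value", "data_type": "text", "nullable": False}]
)
changed = deepcopy(table)
changed["columns"]["value"]["data_type"] = "bigint"

# diff_table no longer raises on a data type change; the partial carries the new type
partial = utils.diff_table("schema", table, changed)
assert partial["columns"]["value"]["data_type"] == "bigint"

# the compatibility check still rejects the coercion
try:
    utils.ensure_compatible_tables("schema", table, changed)
    raise AssertionError("expected CannotCoerceColumnException")
except CannotCoerceColumnException as exc:
    assert exc.column_name == "value"
```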