From d81a35b39aee8e01d3db8c5a483e5662fcc0dce0 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 15 Dec 2024 23:53:04 +0100 Subject: [PATCH] allows data type diff and ensures valid migration separately --- dlt/common/schema/utils.py | 70 +++++--- dlt/normalize/normalize.py | 40 ++--- dlt/normalize/validate.py | 20 ++- tests/common/schema/test_inference.py | 21 --- tests/common/schema/test_merges.py | 11 +- tests/load/pipeline/test_pipelines.py | 1 - tests/load/pipeline/test_postgres.py | 171 -------------------- tests/normalize/test_normalize.py | 70 ++++++++ tests/pipeline/test_import_export_schema.py | 4 +- tests/pipeline/test_pipeline.py | 171 +++++++++++++++++++- 10 files changed, 326 insertions(+), 253 deletions(-) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 038abdc4d0..4f9e0eb42e 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -457,16 +457,8 @@ def diff_table( * when columns with the same name have different data types * when table links to different parent tables """ - if tab_a["name"] != tab_b["name"]: - raise TablePropertiesConflictException( - schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"] - ) - table_name = tab_a["name"] - # check if table properties can be merged - if tab_a.get("parent") != tab_b.get("parent"): - raise TablePropertiesConflictException( - schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent") - ) + # allow for columns to differ + ensure_compatible_tables(schema_name, tab_a, tab_b, ensure_columns=False) # get new columns, changes in the column data type or other properties are not allowed tab_a_columns = tab_a["columns"] @@ -474,18 +466,6 @@ def diff_table( for col_b_name, col_b in tab_b["columns"].items(): if col_b_name in tab_a_columns: col_a = tab_a_columns[col_b_name] - # we do not support changing data types of columns - if is_complete_column(col_a) and is_complete_column(col_b): - if not compare_complete_columns(tab_a_columns[col_b_name], col_b): - # attempt to update to incompatible columns - raise CannotCoerceColumnException( - schema_name, - table_name, - col_b_name, - col_b["data_type"], - tab_a_columns[col_b_name]["data_type"], - None, - ) # all other properties can change merged_column = merge_column(copy(col_a), col_b) if merged_column != col_a: @@ -494,6 +474,8 @@ def diff_table( new_columns.append(col_b) # return partial table containing only name and properties that differ (column, filters etc.) + table_name = tab_a["name"] + partial_table: TPartialTableSchema = { "name": table_name, "columns": {} if new_columns is None else {c["name"]: c for c in new_columns}, @@ -519,6 +501,50 @@ def diff_table( return partial_table +def ensure_compatible_tables( + schema_name: str, tab_a: TTableSchema, tab_b: TPartialTableSchema, ensure_columns: bool = True +) -> None: + """Ensures that `tab_a` and `tab_b` can be merged without conflicts. Conflicts are detected when + + - tables have different names + - nested tables have different parents + - tables have any column with incompatible types + + Note: all the identifiers must be already normalized + + """ + if tab_a["name"] != tab_b["name"]: + raise TablePropertiesConflictException( + schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"] + ) + table_name = tab_a["name"] + # check if table properties can be merged + if tab_a.get("parent") != tab_b.get("parent"): + raise TablePropertiesConflictException( + schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent") + ) + + if not ensure_columns: + return + + tab_a_columns = tab_a["columns"] + for col_b_name, col_b in tab_b["columns"].items(): + if col_b_name in tab_a_columns: + col_a = tab_a_columns[col_b_name] + # we do not support changing data types of columns + if is_complete_column(col_a) and is_complete_column(col_b): + if not compare_complete_columns(tab_a_columns[col_b_name], col_b): + # attempt to update to incompatible columns + raise CannotCoerceColumnException( + schema_name, + table_name, + col_b_name, + col_b["data_type"], + tab_a_columns[col_b_name]["data_type"], + None, + ) + + # def compare_tables(tab_a: TTableSchema, tab_b: TTableSchema) -> bool: # try: # table_name = tab_a["name"] diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 32db5034b4..1d81d70b10 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -20,7 +20,7 @@ LoadStorage, ParsedLoadJobFileName, ) -from dlt.common.schema import TSchemaUpdate, Schema +from dlt.common.schema import Schema from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.pipeline import ( NormalizeInfo, @@ -34,7 +34,7 @@ from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV -from dlt.normalize.validate import verify_normalized_table +from dlt.normalize.validate import validate_and_update_schema, verify_normalized_table # normalize worker wrapping function signature @@ -80,16 +80,6 @@ def create_storages(self) -> None: config=self.config._load_storage_config, ) - def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: - for schema_update in schema_updates: - for table_name, table_updates in schema_update.items(): - logger.info( - f"Updating schema for table {table_name} with {len(table_updates)} deltas" - ) - for partial_table in table_updates: - # merge columns where we expect identifiers to be normalized - schema.update_table(partial_table, normalize_identifiers=False) - def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: workers: int = getattr(self.pool, "_max_workers", 1) chunk_files = group_worker_files(files, workers) @@ -123,7 +113,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW result: TWorkerRV = pending.result() try: # gather schema from all manifests, validate consistency and combine - self.update_schema(schema, result[0]) + validate_and_update_schema(schema, result[0]) summary.schema_updates.extend(result.schema_updates) summary.file_metrics.extend(result.file_metrics) # update metrics @@ -162,7 +152,7 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWor load_id, files, ) - self.update_schema(schema, result.schema_updates) + validate_and_update_schema(schema, result.schema_updates) self.collector.update("Files", len(result.file_metrics)) self.collector.update( "Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count @@ -237,23 +227,11 @@ def spool_schema_files(self, load_id: str, schema: Schema, files: Sequence[str]) self.load_storage.import_extracted_package( load_id, self.normalize_storage.extracted_packages ) - logger.info(f"Created new load package {load_id} on loading volume") - try: - # process parallel - self.spool_files( - load_id, schema.clone(update_normalizers=True), self.map_parallel, files - ) - except CannotCoerceColumnException as exc: - # schema conflicts resulting from parallel executing - logger.warning( - f"Parallel schema update conflict, switching to single thread ({str(exc)}" - ) - # start from scratch - self.load_storage.new_packages.delete_package(load_id) - self.load_storage.import_extracted_package( - load_id, self.normalize_storage.extracted_packages - ) - self.spool_files(load_id, schema.clone(update_normalizers=True), self.map_single, files) + logger.info(f"Created new load package {load_id} on loading volume with ") + # get number of workers with default == 1 if not set (ie. NullExecutor) + workers: int = getattr(self.pool, "_max_workers", 1) + map_f: TMapFuncType = self.map_parallel if workers > 1 else self.map_single + self.spool_files(load_id, schema.clone(update_normalizers=True), map_f, files) return load_id diff --git a/dlt/normalize/validate.py b/dlt/normalize/validate.py index 648deb5da9..868ba3115b 100644 --- a/dlt/normalize/validate.py +++ b/dlt/normalize/validate.py @@ -1,7 +1,10 @@ +from typing import List + from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import Schema -from dlt.common.schema.typing import TTableSchema +from dlt.common.schema.typing import TTableSchema, TSchemaUpdate from dlt.common.schema.utils import ( + ensure_compatible_tables, find_incomplete_columns, get_first_column_name_with_prop, is_nested_table, @@ -10,6 +13,21 @@ from dlt.common import logger +def validate_and_update_schema(schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: + """Updates `schema` tables with partial tables in `schema_updates`""" + for schema_update in schema_updates: + for table_name, table_updates in schema_update.items(): + logger.info(f"Updating schema for table {table_name} with {len(table_updates)} deltas") + for partial_table in table_updates: + # ensure updates will pass + if existing_table := schema.tables.get(partial_table["name"]): + ensure_compatible_tables(schema.name, existing_table, partial_table) + + for partial_table in table_updates: + # merge columns where we expect identifiers to be normalized + schema.update_table(partial_table, normalize_identifiers=False) + + def verify_normalized_table( schema: Schema, table: TTableSchema, capabilities: DestinationCapabilitiesContext ) -> None: diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index 7f06cdb71e..adbb34b1f0 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -441,27 +441,6 @@ def test_update_schema_table_prop_conflict(schema: Schema) -> None: assert exc_val.value.val2 == "tab_parent" -def test_update_schema_column_conflict(schema: Schema) -> None: - tab1 = utils.new_table( - "tab1", - write_disposition="append", - columns=[ - {"name": "col1", "data_type": "text", "nullable": False}, - ], - ) - schema.update_table(tab1) - tab1_u1 = deepcopy(tab1) - # simulate column that had other datatype inferred - tab1_u1["columns"]["col1"]["data_type"] = "bool" - with pytest.raises(CannotCoerceColumnException) as exc_val: - schema.update_table(tab1_u1) - assert exc_val.value.column_name == "col1" - assert exc_val.value.from_type == "bool" - assert exc_val.value.to_type == "text" - # whole column mismatch - assert exc_val.value.coerced_value is None - - def _add_preferred_types(schema: Schema) -> None: schema._settings["preferred_types"] = {} schema._settings["preferred_types"][TSimpleRegex("timestamp")] = "timestamp" diff --git a/tests/common/schema/test_merges.py b/tests/common/schema/test_merges.py index 8e0c350e7c..b76fe944b5 100644 --- a/tests/common/schema/test_merges.py +++ b/tests/common/schema/test_merges.py @@ -353,7 +353,7 @@ def test_diff_tables() -> None: assert "test" in partial["columns"] -def test_diff_tables_conflicts() -> None: +def test_tables_conflicts() -> None: # conflict on parents table: TTableSchema = { # type: ignore[typeddict-unknown-key] "name": "table", @@ -366,6 +366,8 @@ def test_diff_tables_conflicts() -> None: other = utils.new_table("table") with pytest.raises(TablePropertiesConflictException) as cf_ex: utils.diff_table("schema", table, other) + with pytest.raises(TablePropertiesConflictException) as cf_ex: + utils.ensure_compatible_tables("schema", table, other) assert cf_ex.value.table_name == "table" assert cf_ex.value.prop_name == "parent" @@ -373,6 +375,8 @@ def test_diff_tables_conflicts() -> None: other = utils.new_table("other_name") with pytest.raises(TablePropertiesConflictException) as cf_ex: utils.diff_table("schema", table, other) + with pytest.raises(TablePropertiesConflictException) as cf_ex: + utils.ensure_compatible_tables("schema", table, other) assert cf_ex.value.table_name == "table" assert cf_ex.value.prop_name == "name" @@ -380,7 +384,10 @@ def test_diff_tables_conflicts() -> None: changed = deepcopy(table) changed["columns"]["test"]["data_type"] = "bigint" with pytest.raises(CannotCoerceColumnException): - utils.diff_table("schema", table, changed) + utils.ensure_compatible_tables("schema", table, changed) + # but diff now accepts different data types + merged_table = utils.diff_table("schema", table, changed) + assert merged_table["columns"]["test"]["data_type"] == "bigint" def test_merge_tables() -> None: diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 9190225a8c..b998b78471 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -10,7 +10,6 @@ from dlt.common.pipeline import SupportsPipeline from dlt.common.destination import Destination from dlt.common.destination.reference import WithStagingDataset -from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.schema.utils import new_table diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index 29ad21941e..e09582f8a8 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -127,177 +127,6 @@ def test_pipeline_explicit_destination_credentials( ) -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_pipeline_with_sources_sharing_schema( - destination_config: DestinationTestConfiguration, -) -> None: - schema = Schema("shared") - - @dlt.source(schema=schema, max_table_nesting=1) - def source_1(): - @dlt.resource(primary_key="user_id") - def gen1(): - dlt.current.source_state()["source_1"] = True - dlt.current.resource_state()["source_1"] = True - yield {"id": "Y", "user_id": "user_y"} - - @dlt.resource(columns={"col": {"data_type": "bigint"}}) - def conflict(): - yield "conflict" - - return gen1, conflict - - @dlt.source(schema=schema, max_table_nesting=2) - def source_2(): - @dlt.resource(primary_key="id") - def gen1(): - dlt.current.source_state()["source_2"] = True - dlt.current.resource_state()["source_2"] = True - yield {"id": "X", "user_id": "user_X"} - - def gen2(): - yield from "CDE" - - @dlt.resource(columns={"col": {"data_type": "bool"}}, selected=False) - def conflict(): - yield "conflict" - - return gen2, gen1, conflict - - # all selected tables with hints should be there - discover_1 = source_1().discover_schema() - assert "gen1" in discover_1.tables - assert discover_1.tables["gen1"]["columns"]["user_id"]["primary_key"] is True - assert "data_type" not in discover_1.tables["gen1"]["columns"]["user_id"] - assert "conflict" in discover_1.tables - assert discover_1.tables["conflict"]["columns"]["col"]["data_type"] == "bigint" - - discover_2 = source_2().discover_schema() - assert "gen1" in discover_2.tables - assert "gen2" in discover_2.tables - # conflict deselected - assert "conflict" not in discover_2.tables - - p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) - p.extract([source_1(), source_2()], table_format=destination_config.table_format) - default_schema = p.default_schema - gen1_table = default_schema.tables["gen1"] - assert "user_id" in gen1_table["columns"] - assert "id" in gen1_table["columns"] - assert "conflict" in default_schema.tables - assert "gen2" in default_schema.tables - p.normalize(loader_file_format=destination_config.file_format) - assert "gen2" in default_schema.tables - p.load() - table_names = [t["name"] for t in default_schema.data_tables()] - counts = load_table_counts(p, *table_names) - assert counts == {"gen1": 2, "gen2": 3, "conflict": 1} - # both sources share the same state - assert p.state["sources"] == { - "shared": { - "source_1": True, - "resources": {"gen1": {"source_1": True, "source_2": True}}, - "source_2": True, - } - } - drop_active_pipeline_data() - - # same pipeline but enable conflict - p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) - with pytest.raises(PipelineStepFailed) as py_ex: - p.extract([source_1(), source_2().with_resources("conflict")]) - assert isinstance(py_ex.value.__context__, CannotCoerceColumnException) - - -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_many_pipelines_single_dataset(destination_config: DestinationTestConfiguration) -> None: - schema = Schema("shared") - - @dlt.source(schema=schema, max_table_nesting=1) - def source_1(): - @dlt.resource(primary_key="user_id") - def gen1(): - dlt.current.source_state()["source_1"] = True - dlt.current.resource_state()["source_1"] = True - yield {"id": "Y", "user_id": "user_y"} - - return gen1 - - @dlt.source(schema=schema, max_table_nesting=2) - def source_2(): - @dlt.resource(primary_key="id") - def gen1(): - dlt.current.source_state()["source_2"] = True - dlt.current.resource_state()["source_2"] = True - yield {"id": "X", "user_id": "user_X"} - - def gen2(): - yield from "CDE" - - return gen2, gen1 - - # load source_1 to common dataset - p = dlt.pipeline( - pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" - ) - p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") - counts = load_table_counts(p, *p.default_schema.tables.keys()) - assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() - p._wipe_working_folder() - p.deactivate() - - p = dlt.pipeline( - pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" - ) - p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") - # table_names = [t["name"] for t in p.default_schema.data_tables()] - counts = load_table_counts(p, *p.default_schema.tables.keys()) - # gen1: one record comes from source_1, 1 record from source_2 - assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() - # assert counts == {'gen1': 2, 'gen2': 3} - p._wipe_working_folder() - p.deactivate() - - # restore from destination, check state - p = dlt.pipeline( - pipeline_name="source_1_pipeline", - destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), - dataset_name="shared_dataset", - ) - p.sync_destination() - # we have our separate state - assert p.state["sources"]["shared"] == { - "source_1": True, - "resources": {"gen1": {"source_1": True}}, - } - # but the schema was common so we have the earliest one - assert "gen2" in p.default_schema.tables - p._wipe_working_folder() - p.deactivate() - - p = dlt.pipeline( - pipeline_name="source_2_pipeline", - destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), - dataset_name="shared_dataset", - ) - p.sync_destination() - # we have our separate state - assert p.state["sources"]["shared"] == { - "source_2": True, - "resources": {"gen1": {"source_2": True}}, - } - - # TODO: uncomment and finalize when we implement encoding for psycopg2 # @pytest.mark.parametrize( # "destination_config", diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 7463184be7..84e22af9ff 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -1,3 +1,4 @@ +from copy import deepcopy import pytest from fnmatch import fnmatch from typing import Dict, Iterator, List, Sequence, Tuple @@ -5,6 +6,7 @@ from dlt.common import json from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.schema import Schema from dlt.common.schema.utils import new_table from dlt.common.storages.exceptions import SchemaNotFoundError @@ -16,6 +18,7 @@ from dlt.extract.extract import ExtractStorage from dlt.normalize import Normalize +from dlt.normalize.validate import validate_and_update_schema from dlt.normalize.worker import group_worker_files from dlt.normalize.exceptions import NormalizeJobFailed @@ -284,6 +287,8 @@ def test_multiprocessing_row_counting( extract_cases(raw_normalize, ["github.events.load_page_1_duck"]) # use real process pool in tests with ProcessPoolExecutor(max_workers=4) as p: + # test if we get correct number of workers + assert getattr(p, "_max_workers", None) == 4 raw_normalize.run(p) # get step info step_info = raw_normalize.get_step_info(MockPipeline("multiprocessing_pipeline", True)) # type: ignore[abstract] @@ -712,6 +717,71 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType) assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type +def test_update_schema_column_conflict(rasa_normalize: Normalize) -> None: + extract_cases( + rasa_normalize, + [ + "event.event.many_load_2", + "event.event.user_load_1", + ], + ) + extract_cases( + rasa_normalize, + [ + "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2", + ], + ) + # use real process pool in tests + with ProcessPoolExecutor(max_workers=4) as p: + rasa_normalize.run(p) + + schema = rasa_normalize.schema_storage.load_schema("event") + tab1 = new_table( + "event_user", + write_disposition="append", + columns=[ + {"name": "col1", "data_type": "text", "nullable": False}, + ], + ) + validate_and_update_schema(schema, [{"event_user": [deepcopy(tab1)]}]) + assert schema.tables["event_user"]["columns"]["col1"]["data_type"] == "text" + + tab1["columns"]["col1"]["data_type"] = "bool" + tab1["columns"]["col2"] = {"name": "col2", "data_type": "text", "nullable": False} + with pytest.raises(CannotCoerceColumnException) as exc_val: + validate_and_update_schema(schema, [{"event_user": [deepcopy(tab1)]}]) + assert exc_val.value.column_name == "col1" + assert exc_val.value.from_type == "bool" + assert exc_val.value.to_type == "text" + # whole column mismatch + assert exc_val.value.coerced_value is None + # make sure col2 is not added + assert "col2" not in schema.tables["event_user"]["columns"] + + # add two updates that are conflicting + tab2 = new_table( + "event_slot", + write_disposition="append", + columns=[ + {"name": "col1", "data_type": "text", "nullable": False}, + {"name": "col2", "data_type": "text", "nullable": False}, + ], + ) + tab3 = new_table( + "event_slot", + write_disposition="append", + columns=[ + {"name": "col1", "data_type": "bool", "nullable": False}, + ], + ) + with pytest.raises(CannotCoerceColumnException) as exc_val: + validate_and_update_schema( + schema, [{"event_slot": [deepcopy(tab2)]}, {"event_slot": [deepcopy(tab3)]}] + ) + # col2 is added from first update + assert "col2" in schema.tables["event_slot"]["columns"] + + def test_removal_of_normalizer_schema_section_and_add_seen_data(raw_normalize: Normalize) -> None: extract_cases( raw_normalize, diff --git a/tests/pipeline/test_import_export_schema.py b/tests/pipeline/test_import_export_schema.py index eb36d36ba3..5eb9c664d0 100644 --- a/tests/pipeline/test_import_export_schema.py +++ b/tests/pipeline/test_import_export_schema.py @@ -1,4 +1,4 @@ -import dlt, os, pytest +import dlt, os from dlt.common.utils import uniq_id @@ -6,8 +6,6 @@ from tests.utils import TEST_STORAGE_ROOT from dlt.common.schema import Schema from dlt.common.storages.schema_storage import SchemaStorage -from dlt.common.schema.exceptions import CannotCoerceColumnException -from dlt.pipeline.exceptions import PipelineStepFailed from dlt.destinations import dummy diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 2d72e23462..aebf83d9b0 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -52,7 +52,7 @@ from dlt.pipeline.pipeline import Pipeline from tests.common.utils import TEST_SENTRY_DSN -from tests.utils import TEST_STORAGE_ROOT +from tests.utils import TEST_STORAGE_ROOT, load_table_counts from tests.extract.utils import expect_extracted_file from tests.pipeline.utils import ( assert_data_table_counts, @@ -3011,3 +3011,172 @@ def test_push_table_with_upfront_schema() -> None: copy_pipeline = dlt.pipeline(pipeline_name="push_table_copy_pipeline", destination="duckdb") info = copy_pipeline.run(data, table_name="events", schema=copy_schema) assert copy_pipeline.default_schema.version_hash != infer_hash + + +def test_pipeline_with_sources_sharing_schema() -> None: + schema = Schema("shared") + + @dlt.source(schema=schema, max_table_nesting=1) + def source_1(): + @dlt.resource(primary_key="user_id") + def gen1(): + dlt.current.source_state()["source_1"] = True + dlt.current.resource_state()["source_1"] = True + yield {"id": "Y", "user_id": "user_y"} + + @dlt.resource(columns={"value": {"data_type": "bool"}}) + def conflict(): + yield True + + return gen1, conflict + + @dlt.source(schema=schema, max_table_nesting=2) + def source_2(): + @dlt.resource(primary_key="id") + def gen1(): + dlt.current.source_state()["source_2"] = True + dlt.current.resource_state()["source_2"] = True + yield {"id": "X", "user_id": "user_X"} + + def gen2(): + yield from "CDE" + + @dlt.resource(columns={"value": {"data_type": "text"}}, selected=False) + def conflict(): + yield "indeed" + + return gen2, gen1, conflict + + # all selected tables with hints should be there + discover_1 = source_1().discover_schema() + assert "gen1" in discover_1.tables + assert discover_1.tables["gen1"]["columns"]["user_id"]["primary_key"] is True + assert "data_type" not in discover_1.tables["gen1"]["columns"]["user_id"] + assert "conflict" in discover_1.tables + assert discover_1.tables["conflict"]["columns"]["value"]["data_type"] == "bool" + + discover_2 = source_2().discover_schema() + assert "gen1" in discover_2.tables + assert "gen2" in discover_2.tables + # conflict deselected + assert "conflict" not in discover_2.tables + + p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) + p.extract([source_1(), source_2()]) + default_schema = p.default_schema + gen1_table = default_schema.tables["gen1"] + assert "user_id" in gen1_table["columns"] + assert "id" in gen1_table["columns"] + assert "conflict" in default_schema.tables + assert "gen2" in default_schema.tables + p.normalize() + assert "gen2" in default_schema.tables + assert default_schema.tables["conflict"]["columns"]["value"]["data_type"] == "bool" + p.load() + table_names = [t["name"] for t in default_schema.data_tables()] + counts = load_table_counts(p, *table_names) + assert counts == {"gen1": 2, "gen2": 3, "conflict": 1} + # both sources share the same state + assert p.state["sources"] == { + "shared": { + "source_1": True, + "resources": {"gen1": {"source_1": True, "source_2": True}}, + "source_2": True, + } + } + + # same pipeline but enable conflict + p.extract([source_2().with_resources("conflict")]) + p.normalize() + assert default_schema.tables["conflict"]["columns"]["value"]["data_type"] == "text" + with pytest.raises(PipelineStepFailed): + # will generate failed job on type that does not match + p.load() + counts = load_table_counts(p, "conflict") + assert counts == {"conflict": 1} + + # alter table in duckdb + with p.sql_client() as client: + client.execute_sql("ALTER TABLE conflict ALTER value TYPE VARCHAR;") + p.run([source_2().with_resources("conflict")]) + counts = load_table_counts(p, "conflict") + assert counts == {"conflict": 2} + print(p.dataset().conflict.df()) + + +def test_many_pipelines_single_dataset() -> None: + schema = Schema("shared") + + @dlt.source(schema=schema, max_table_nesting=1) + def source_1(): + @dlt.resource(primary_key="user_id") + def gen1(): + dlt.current.source_state()["source_1"] = True + dlt.current.resource_state()["source_1"] = True + yield {"id": "Y", "user_id": "user_y"} + + return gen1 + + @dlt.source(schema=schema, max_table_nesting=2) + def source_2(): + @dlt.resource(primary_key="id") + def gen1(): + dlt.current.source_state()["source_2"] = True + dlt.current.resource_state()["source_2"] = True + yield {"id": "X", "user_id": "user_X"} + + def gen2(): + yield from "CDE" + + return gen2, gen1 + + # load source_1 to common dataset + p = dlt.pipeline( + pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) + p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") + counts = load_table_counts(p, *p.default_schema.tables.keys()) + assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() + p._wipe_working_folder() + p.deactivate() + + p = dlt.pipeline( + pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) + p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") + # table_names = [t["name"] for t in p.default_schema.data_tables()] + counts = load_table_counts(p, *p.default_schema.tables.keys()) + # gen1: one record comes from source_1, 1 record from source_2 + assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() + # assert counts == {'gen1': 2, 'gen2': 3} + p._wipe_working_folder() + p.deactivate() + + # restore from destination, check state + p = dlt.pipeline( + pipeline_name="source_1_pipeline", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), + dataset_name="shared_dataset", + ) + p.sync_destination() + # we have our separate state + assert p.state["sources"]["shared"] == { + "source_1": True, + "resources": {"gen1": {"source_1": True}}, + } + # but the schema was common so we have the earliest one + assert "gen2" in p.default_schema.tables + p._wipe_working_folder() + p.deactivate() + + p = dlt.pipeline( + pipeline_name="source_2_pipeline", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), + dataset_name="shared_dataset", + ) + p.sync_destination() + # we have our separate state + assert p.state["sources"]["shared"] == { + "source_2": True, + "resources": {"gen1": {"source_2": True}}, + }