From fbf0ef4c76e43dbf243276c601bc854d6a9ccda4 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 30 Aug 2024 11:40:57 +0200 Subject: [PATCH 1/3] Feat/1492 extend timestamp config (#1669) * feat: add timezone flag to configure timestamp data * fix: delete timezone init * test: add duckdb timestamps with timezone * test: fix resource hints for timestamp * test: correct duckdb timestamps * test: timezone tests for parquet files * exp: add notebook with timestamp exploration * test: refactor timestamp tests * test: simplified tests and extended experiments * exp: timestamp exp for duckdb and parquet * fix: add pyarrow reflection for timezone flag * fix lint errors * fix: CI/CD move tests pyarrow module * fix: pyarrow timezone defaults true * refactor: typemapper signatures * fix: duckdb timestamp config * docs: updated duckdb.md timestamps * fix: revert duckdb timestamp defaults * fix: restore duckdb timestamp default * fix: duckdb timestamp mapper * fix: delete notebook * docs: added timestamp and timezone section * refactor: duckdb precision exception message * feat: postgres timestamp timezone config * fix: postgres timestamp precision * fix: postgres timezone false case * feat: add snowflake timezone and precision flag * test: postgres invalid timestamp precision * test: unified timestamp invalid precision * test: unified column flag timezone * chore: add warn log for unsupported timezone or precision flag * docs: timezone and precision flags for timestamps * fix: none case error * docs: add duckdb default precision * fix: typing errors * rebase: formatted files from upstream devel * fix: warning message and reference TODO * test: delete duplicated input_data array * docs: moved timestamp config to data types section * fix: lint and format * fix: lint local errors --- dlt/common/libs/pyarrow.py | 9 +- dlt/common/schema/typing.py | 1 + dlt/destinations/impl/athena/athena.py | 14 +- dlt/destinations/impl/bigquery/bigquery.py | 8 +- .../impl/clickhouse/clickhouse.py | 6 +- .../impl/databricks/databricks.py | 15 +- dlt/destinations/impl/dremio/dremio.py | 12 +- dlt/destinations/impl/duckdb/duck.py | 43 +++-- .../impl/lancedb/lancedb_client.py | 24 +-- dlt/destinations/impl/mssql/mssql.py | 15 +- dlt/destinations/impl/postgres/postgres.py | 42 ++++- dlt/destinations/impl/redshift/redshift.py | 9 +- dlt/destinations/impl/snowflake/snowflake.py | 41 ++++- dlt/destinations/job_client_impl.py | 11 +- dlt/destinations/type_mapping.py | 56 +++++-- .../docs/dlt-ecosystem/destinations/duckdb.md | 39 ++++- .../dlt-ecosystem/destinations/postgres.md | 21 +++ .../dlt-ecosystem/destinations/snowflake.md | 21 +++ tests/load/pipeline/test_pipelines.py | 148 ++++++++++++++++++ tests/pipeline/test_pipeline.py | 17 +- tests/pipeline/test_pipeline_extra.py | 80 ++++++++++ 21 files changed, 536 insertions(+), 96 deletions(-) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 9d3e97421c..e9dcfaf095 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -54,7 +54,10 @@ def get_py_arrow_datatype( elif column_type == "bool": return pyarrow.bool_() elif column_type == "timestamp": - return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz) + # sets timezone to None when timezone hint is false + timezone = tz if column.get("timezone", True) else None + precision = column.get("precision") or caps.timestamp_precision + return get_py_arrow_timestamp(precision, timezone) elif column_type == "bigint": 
return get_pyarrow_int(column.get("precision")) elif column_type == "binary": @@ -139,6 +142,10 @@ def get_column_type_from_py_arrow(dtype: pyarrow.DataType) -> TColumnType: precision = 6 else: precision = 9 + + if dtype.tz is None: + return dict(data_type="timestamp", precision=precision, timezone=False) + return dict(data_type="timestamp", precision=precision) elif pyarrow.types.is_date(dtype): return dict(data_type="date") diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 284c55caac..a81e9046a9 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -94,6 +94,7 @@ class TColumnType(TypedDict, total=False): data_type: Optional[TDataType] precision: Optional[int] scale: Optional[int] + timezone: Optional[bool] class TColumnSchemaBase(TColumnType, total=False): diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index a5a8ae2562..c4a9bab212 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -104,9 +104,9 @@ class AthenaTypeMapper(TypeMapper): def __init__(self, capabilities: DestinationCapabilitiesContext): super().__init__(capabilities) - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") + table_format = table.get("table_format") if precision is None: return "bigint" if precision <= 8: @@ -403,9 +403,9 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(hive_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: return ( - f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" + f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table)}" ) def _iceberg_partition_clause(self, partition_hints: Optional[Dict[str, str]]) -> str: @@ -429,9 +429,9 @@ def _get_table_update_sql( # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries # or if we are in iceberg mode, we create iceberg tables for all tables table = self.prepare_load_table(table_name, self.in_staging_mode) - table_format = table.get("table_format") + is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" - columns = ", ".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + columns = ", ".join([self._get_column_def_sql(c, table) for c in new_columns]) # create unique tag for iceberg table so it is never recreated in the same folder # athena requires some kind of special cleaning (or that is a bug) so we cannot refresh diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 1dd4c727be..9bc555bd0d 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -90,9 +90,9 @@ class BigQueryTypeMapper(TypeMapper): "TIME": "time", } - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> str: + def to_db_decimal_type(self, column: TColumnSchema) -> str: # Use BigQuery's BIGNUMERIC for large precision decimals - precision, scale = self.decimal_precision(precision, scale) + precision, scale = self.decimal_precision(column.get("precision"), 
column.get("scale")) if precision > 38 or scale > 9: return "BIGNUMERIC(%i,%i)" % (precision, scale) return "NUMERIC(%i,%i)" % (precision, scale) @@ -417,10 +417,10 @@ def _get_info_schema_columns_query( return query, folded_table_names - def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, column: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(column["name"]) column_def_sql = ( - f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(column, table)} {self._gen_not_null(column.get('nullable', True))}" ) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 282fbaf338..038735a84b 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -293,7 +293,7 @@ def _create_merge_followup_jobs( ) -> List[FollowupJobRequest]: return [ClickHouseMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: # Build column definition. # The primary key and sort order definition is defined outside column specification. hints_ = " ".join( @@ -307,9 +307,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non # Alter table statements only accept `Nullable` modifiers. # JSON type isn't nullable in ClickHouse. type_with_nullability_modifier = ( - f"Nullable({self.type_mapper.to_db_type(c)})" + f"Nullable({self.type_mapper.to_db_type(c,table)})" if c.get("nullable", True) - else self.type_mapper.to_db_type(c) + else self.type_mapper.to_db_type(c, table) ) return ( diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 614e6e97c5..0c19984b4c 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -68,9 +68,8 @@ class DatabricksTypeMapper(TypeMapper): "wei": "DECIMAL(%i,%i)", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "BIGINT" if precision <= 8: @@ -323,10 +322,12 @@ def _create_merge_followup_jobs( return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause - return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] + return [ + "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) + ] def _get_table_update_sql( self, @@ -351,10 +352,10 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: 
TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def _get_storage_table_query_columns(self) -> List[str]: diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index 149d106dcd..91dc64f113 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -195,10 +195,10 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def _create_merge_followup_jobs( @@ -207,9 +207,13 @@ def _create_merge_followup_jobs( return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: - return ["ADD COLUMNS (" + ", ".join(self._get_column_def_sql(c) for c in new_columns) + ")"] + return [ + "ADD COLUMNS (" + + ", ".join(self._get_column_def_sql(c, table) for c in new_columns) + + ")" + ] def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 3d5905ff40..d5065f5bdd 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -62,9 +62,8 @@ class DuckDbTypeMapper(TypeMapper): "TIMESTAMP_NS": "timestamp", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "BIGINT" # Precision is number of bits @@ -83,19 +82,39 @@ def to_db_integer_type( ) def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone and precision is not None: + raise TerminalValueError( + f"DuckDB does not support both timezone and precision for column '{column_name}' in" + f" table '{table_name}'. To resolve this issue, either set timezone to False or" + " None, or use the default precision." 
+ ) + + if timezone: + return "TIMESTAMP WITH TIME ZONE" + elif timezone is not None: # condition for when timezone is False given that none is falsy + return "TIMESTAMP" + if precision is None or precision == 6: - return super().to_db_datetime_type(precision, table_format) - if precision == 0: + return None + elif precision == 0: return "TIMESTAMP_S" - if precision == 3: + elif precision == 3: return "TIMESTAMP_MS" - if precision == 9: + elif precision == 9: return "TIMESTAMP_NS" + raise TerminalValueError( - f"timestamp with {precision} decimals after seconds cannot be mapped into duckdb" - " TIMESTAMP type" + f"DuckDB does not support precision '{precision}' for '{column_name}' in table" + f" '{table_name}'" ) def from_db_type( @@ -162,7 +181,7 @@ def create_load_job( job = DuckDbCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -170,7 +189,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _from_db_type( diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 78a37952b9..02240b8f93 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -41,7 +41,7 @@ LoadJob, ) from dlt.common.pendulum import timedelta -from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TColumnSchema from dlt.common.schema.typing import ( TColumnType, TTableFormat, @@ -105,21 +105,27 @@ class LanceDBTypeMapper(TypeMapper): pa.date32(): "date", } - def to_db_decimal_type( - self, precision: Optional[int], scale: Optional[int] - ) -> pa.Decimal128Type: - precision, scale = self.decimal_precision(precision, scale) + def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) return pa.decimal128(precision, scale) def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> pa.TimestampType: + column_name = column.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + if timezone is not None or precision is not None: + logger.warning( + "LanceDB does not currently support column flags for timezone or precision." + f" These flags were used in column '{column_name}'." 
+ ) unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.timestamp(unit, "UTC") - def to_db_time_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> pa.Time64Type: + def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> pa.Time64Type: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.time64(unit) diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 750dc93a10..a7e796b2d8 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -59,9 +59,8 @@ class MsSqlTypeMapper(TypeMapper): "int": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" if precision <= 8: @@ -166,20 +165,18 @@ def _create_merge_followup_jobs( return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because mssql requires multiple columns in a single ADD COLUMN clause - return [ - "ADD \n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) - ] + return ["ADD \n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: sc_type = c["data_type"] if sc_type == "text" and c.get("unique"): # MSSQL does not allow index on large TEXT columns db_type = "nvarchar(%i)" % (c.get("precision") or 900) else: - db_type = self.type_mapper.to_db_type(c) + db_type = self.type_mapper.to_db_type(c, table) hints_str = " ".join( self.active_hints.get(h, "") diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index a832bfe07f..5777e46c90 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -66,9 +66,8 @@ class PostgresTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" # Precision is number of bits @@ -82,6 +81,39 @@ def to_db_integer_type( f"bigint with {precision} bits precision cannot be mapped into postgres integer type" ) + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is None and precision is None: + return None + + timestamp = "timestamp" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 6: + timestamp += f" ({precision})" + else: + raise TerminalValueError( + f"Postgres does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + # append timezone part + if timezone is None or timezone: # timezone True and None + timestamp 
+= " with time zone" + else: # timezone is explicitly False + timestamp += " without time zone" + + return timestamp + def from_db_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None ) -> TColumnType: @@ -233,7 +265,7 @@ def create_load_job( job = PostgresCsvCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -241,7 +273,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _create_replace_followup_jobs( diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 0e201dc4e0..9bba60af07 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -82,9 +82,8 @@ class RedshiftTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" if precision <= 16: @@ -243,7 +242,7 @@ def _create_merge_followup_jobs( ) -> List[FollowupJobRequest]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() @@ -251,7 +250,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def create_load_job( diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 6688b5bc17..247b3233d0 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -18,7 +18,7 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat - +from dlt.common.exceptions import TerminalValueError from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.typing import TLoaderFileFormat @@ -77,6 +77,36 @@ def from_db_type( return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(db_type, precision, scale) + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is None and 
precision is None: + return None + + timestamp = "TIMESTAMP_TZ" + + if timezone is not None and not timezone: # explicitaly handles timezone False + timestamp = "TIMESTAMP_NTZ" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 9: + timestamp += f"({precision})" + else: + raise TerminalValueError( + f"Snowflake does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + return timestamp + class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( @@ -289,12 +319,11 @@ def create_load_job( return job def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because snowflake requires multiple columns in a single ADD COLUMN clause return [ - "ADD COLUMN\n" - + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) + "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) ] def _get_table_update_sql( @@ -320,10 +349,10 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 1d6403a2c8..3026baf753 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -517,10 +517,10 @@ def _build_schema_update_sql( return sql_updates, schema_update def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)""" - return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns] + return [f"ADD COLUMN {self._get_column_def_sql(c, table)}" for c in new_columns] def _make_create_table(self, qualified_name: str, table: TTableSchema) -> str: not_exists_clause = " " @@ -537,17 +537,16 @@ def _get_table_update_sql( # build sql qualified_name = self.sql_client.make_qualified_table_name(table_name) table = self.prepare_load_table(table_name) - table_format = table.get("table_format") sql_result: List[str] = [] if not generate_alter: # build CREATE sql = self._make_create_table(qualified_name, table) + " (\n" - sql += ",\n".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns]) sql += ")" sql_result.append(sql) else: sql_base = f"ALTER TABLE {qualified_name}\n" - add_column_statements = self._make_add_column_sql(new_columns, table_format) + add_column_statements = self._make_add_column_sql(new_columns, table) if self.capabilities.alter_add_multi_column: column_sql = ",\n" sql_result.append(sql_base + column_sql.join(add_column_statements)) @@ -582,7 +581,7 @@ def _get_table_update_sql( return sql_result 
@abstractmethod - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: pass @staticmethod diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py index dcd938b33c..5ac43e4f1f 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -1,6 +1,13 @@ from typing import Tuple, ClassVar, Dict, Optional -from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType, TTableFormat +from dlt.common import logger +from dlt.common.schema.typing import ( + TColumnSchema, + TDataType, + TColumnType, + TTableFormat, + TTableSchema, +) from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.utils import without_none @@ -20,39 +27,54 @@ class TypeMapper: def __init__(self, capabilities: DestinationCapabilitiesContext) -> None: self.capabilities = capabilities - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: # Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.) return self.sct_to_unbound_dbt["bigint"] def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> str: # Override in subclass if db supports other timestamp types (e.g. with different time resolutions) + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is not None or precision is not None: + message = ( + "Column flags for timezone or precision are not yet supported in this" + " destination. One or both of these flags were used in column" + f" '{column.get('name')}'." + ) + # TODO: refactor lancedb and wevavite to make table object required + if table: + message += f" in table '{table.get('name')}'." + + logger.warning(message) + return None - def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: + def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: # Override in subclass if db supports other time types (e.g. 
with different time resolutions) return None - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> str: - precision_tup = self.decimal_precision(precision, scale) + def to_db_decimal_type(self, column: TColumnSchema) -> str: + precision_tup = self.decimal_precision(column.get("precision"), column.get("scale")) if not precision_tup or "decimal" not in self.sct_to_dbt: return self.sct_to_unbound_dbt["decimal"] return self.sct_to_dbt["decimal"] % (precision_tup[0], precision_tup[1]) - def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: - precision, scale = column.get("precision"), column.get("scale") + # TODO: refactor lancedb and wevavite to make table object required + def to_db_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: sc_t = column["data_type"] if sc_t == "bigint": - db_t = self.to_db_integer_type(precision, table_format) + db_t = self.to_db_integer_type(column, table) elif sc_t == "timestamp": - db_t = self.to_db_datetime_type(precision, table_format) + db_t = self.to_db_datetime_type(column, table) elif sc_t == "time": - db_t = self.to_db_time_type(precision, table_format) + db_t = self.to_db_time_type(column, table) elif sc_t == "decimal": - db_t = self.to_db_decimal_type(precision, scale) + db_t = self.to_db_decimal_type(column) else: db_t = None if db_t: @@ -61,14 +83,16 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) - bounded_template = self.sct_to_dbt.get(sc_t) if not bounded_template: return self.sct_to_unbound_dbt[sc_t] - precision_tuple = self.precision_tuple_or_default(sc_t, precision, scale) + precision_tuple = self.precision_tuple_or_default(sc_t, column) if not precision_tuple: return self.sct_to_unbound_dbt[sc_t] return self.sct_to_dbt[sc_t] % precision_tuple def precision_tuple_or_default( - self, data_type: TDataType, precision: Optional[int], scale: Optional[int] + self, data_type: TDataType, column: TColumnSchema ) -> Optional[Tuple[int, ...]]: + precision = column.get("precision") + scale = column.get("scale") if data_type in ("timestamp", "time"): if precision is None: return None # Use default which is usually the max diff --git a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md index 19cef92f9d..4b8ecec4ca 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md @@ -35,6 +35,42 @@ All write dispositions are supported. ## Data loading `dlt` will load data using large INSERT VALUES statements by default. Loading is multithreaded (20 threads by default). If you are okay with installing `pyarrow`, we suggest switching to `parquet` as the file format. Loading is faster (and also multithreaded). +### Data types +`duckdb` supports various [timestamp types](https://duckdb.org/docs/sql/data_types/timestamp.html). These can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: supported precision values are 0, 3, 6, and 9 for fractional seconds. Note that `timezone` and `precision` cannot be used together; attempting to combine them will result in an error. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP WITH TIME ZONE` (`TIMESTAMPTZ`). 
+ +#### Example precision: TIMESTAMP_MS + +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="duckdb") +pipeline.run(events()) +``` + +#### Example timezone: TIMESTAMP + +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + +pipeline = dlt.pipeline(destination="duckdb") +pipeline.run(events()) +``` + ### Names normalization `dlt` uses the standard **snake_case** naming convention to keep identical table and column identifiers across all destinations. If you want to use the **duckdb** wide range of characters (i.e., emojis) for table and column names, you can switch to the **duck_case** naming convention, which accepts almost any string as an identifier: * `\n` `\r` and `"` are translated to `_` @@ -77,7 +113,8 @@ to disable tz adjustments. ::: ## Supported column hints -`duckdb` may create unique indexes for all columns with `unique` hints, but this behavior **is disabled by default** because it slows the loading down significantly. + +`duckdb` can create unique indexes for columns with `unique` hints. However, **this feature is disabled by default** as it can significantly slow down data loading. ## Destination Configuration diff --git a/docs/website/docs/dlt-ecosystem/destinations/postgres.md b/docs/website/docs/dlt-ecosystem/destinations/postgres.md index 1281298312..e506eb79fe 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/postgres.md +++ b/docs/website/docs/dlt-ecosystem/destinations/postgres.md @@ -82,6 +82,27 @@ If you set the [`replace` strategy](../../general-usage/full-loading.md) to `sta ## Data loading `dlt` will load data using large INSERT VALUES statements by default. Loading is multithreaded (20 threads by default). +### Data types +`postgres` supports various timestamp types, which can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: allows you to specify the number of decimal places for fractional seconds, ranging from 0 to 6. It can be used in combination with the `timezone` flag. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP WITHOUT TIME ZONE`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP WITH TIME ZONE`. + +#### Example precision and timezone: TIMESTAMP (3) WITHOUT TIME ZONE +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3, "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="postgres") +pipeline.run(events()) +``` + ### Fast loading with arrow tables and csv You can use [arrow tables](../verified-sources/arrow-pandas.md) and [csv](../file-formats/csv.md) to quickly load tabular data. 
Pick the `csv` loader file format like below diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 57e6db311d..f4d5a53d36 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -143,6 +143,27 @@ The data is loaded using an internal Snowflake stage. We use the `PUT` command a keep_staged_files = false ``` +### Data types +`snowflake` supports various timestamp types, which can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: allows you to specify the number of decimal places for fractional seconds, ranging from 0 to 9. It can be used in combination with the `timezone` flag. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP_NTZ`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP_TZ`. + +#### Example precision and timezone: TIMESTAMP_NTZ(3) +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3, "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="snowflake") +pipeline.run(events()) +``` + ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default * [parquet](../file-formats/parquet.md) is supported diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 81c9292570..2792cec085 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -17,6 +17,7 @@ from dlt.common.schema.utils import new_table from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id +from dlt.common.exceptions import TerminalValueError from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations import filesystem, redshift @@ -1146,3 +1147,150 @@ def _data(): dataset_name=dataset_name, ) return p, _data + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "postgres", "snowflake"]), + ids=lambda x: x.name, +) +def test_dest_column_invalid_timestamp_precision( + destination_config: DestinationTestConfiguration, +) -> None: + invalid_precision = 10 + + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": invalid_precision}}, + primary_key="event_id", + ) + def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + + pipeline = destination_config.setup_pipeline(uniq_id()) + + with pytest.raises((TerminalValueError, PipelineStepFailed)): + pipeline.run(events()) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "snowflake", "postgres"]), + ids=lambda x: x.name, +) +def test_dest_column_hint_timezone(destination_config: DestinationTestConfiguration) -> None: + destination = destination_config.destination + + input_data = [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + output_values = [ + "2024-07-30T10:00:00.123000", + "2024-07-30T08:00:00.123456", + "2024-07-30T10:00:00.123456", + ] + + output_map = { + "postgres": { + "tables": { + 
"events_timezone_off": { + "timestamp_type": "timestamp without time zone", + "timestamp_values": output_values, + }, + "events_timezone_on": { + "timestamp_type": "timestamp with time zone", + "timestamp_values": output_values, + }, + "events_timezone_unset": { + "timestamp_type": "timestamp with time zone", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='experiments'" + " AND table_name = '%s' AND column_name = 'event_tstamp'" + ), + }, + "snowflake": { + "tables": { + "EVENTS_TIMEZONE_OFF": { + "timestamp_type": "TIMESTAMP_NTZ", + "timestamp_values": output_values, + }, + "EVENTS_TIMEZONE_ON": { + "timestamp_type": "TIMESTAMP_TZ", + "timestamp_values": output_values, + }, + "EVENTS_TIMEZONE_UNSET": { + "timestamp_type": "TIMESTAMP_TZ", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='EXPERIMENTS'" + " AND table_name = '%s' AND column_name = 'EVENT_TSTAMP'" + ), + }, + "duckdb": { + "tables": { + "events_timezone_off": { + "timestamp_type": "TIMESTAMP", + "timestamp_values": output_values, + }, + "events_timezone_on": { + "timestamp_type": "TIMESTAMP WITH TIME ZONE", + "timestamp_values": output_values, + }, + "events_timezone_unset": { + "timestamp_type": "TIMESTAMP WITH TIME ZONE", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='experiments'" + " AND table_name = '%s' AND column_name = 'event_tstamp'" + ), + }, + } + + # table: events_timezone_off + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", + ) + def events_timezone_off(): + yield input_data + + # table: events_timezone_on + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True}}, + primary_key="event_id", + ) + def events_timezone_on(): + yield input_data + + # table: events_timezone_unset + @dlt.resource( + primary_key="event_id", + ) + def events_timezone_unset(): + yield input_data + + pipeline = destination_config.setup_pipeline( + f"{destination}_" + uniq_id(), dataset_name="experiments" + ) + + pipeline.run([events_timezone_off(), events_timezone_on(), events_timezone_unset()]) + + with pipeline.sql_client() as client: + for t in output_map[destination]["tables"].keys(): # type: ignore + # check data type + column_info = client.execute_sql(output_map[destination]["query_data_type"] % t) + assert column_info[0][0] == output_map[destination]["tables"][t]["timestamp_type"] # type: ignore + # check timestamp data + rows = client.execute_sql(f"SELECT event_tstamp FROM {t} ORDER BY event_id") + + values = [r[0].strftime("%Y-%m-%dT%H:%M:%S.%f") for r in rows] + assert values == output_map[destination]["tables"][t]["timestamp_values"] # type: ignore diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 027a2b4e72..918f9beab9 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -29,7 +29,7 @@ DestinationTerminalException, UnknownDestinationModule, ) -from dlt.common.exceptions import PipelineStateNotAvailable +from dlt.common.exceptions import PipelineStateNotAvailable, TerminalValueError from dlt.common.pipeline import LoadInfo, PipelineContext from dlt.common.runtime.collector import LogCollector from dlt.common.schema.exceptions import TableIdentifiersFrozen @@ -2729,3 +2729,18 @@ def 
assert_imported_file( extract_info.metrics[extract_info.loads_ids[0]][0]["table_metrics"][table_name].items_count == expected_rows ) + + +def test_duckdb_column_invalid_timestamp() -> None: + # DuckDB does not have timestamps with timezone and precision + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True, "precision": 3}}, + primary_key="event_id", + ) + def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + + pipeline = dlt.pipeline(destination="duckdb") + + with pytest.raises((TerminalValueError, PipelineStepFailed)): + pipeline.run(events()) diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index d3e44198b4..c757959bec 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -22,6 +22,7 @@ class BaseModel: # type: ignore[no-redef] from dlt.common import json, pendulum from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.common.runtime.collector import ( AliveCollector, EnlightenCollector, @@ -599,3 +600,82 @@ def test_pick_matching_file_format(test_storage: FileStorage) -> None: files = test_storage.list_folder_files("user_data_csv/object") assert len(files) == 1 assert files[0].endswith("csv") + + +def test_filesystem_column_hint_timezone() -> None: + import pyarrow.parquet as pq + import posixpath + + os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "_storage" + + # talbe: events_timezone_off + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", + ) + def events_timezone_off(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + # talbe: events_timezone_on + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True}}, + primary_key="event_id", + ) + def events_timezone_on(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + # talbe: events_timezone_unset + @dlt.resource( + primary_key="event_id", + ) + def events_timezone_unset(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + pipeline = dlt.pipeline(destination="filesystem") + + pipeline.run( + [events_timezone_off(), events_timezone_on(), events_timezone_unset()], + loader_file_format="parquet", + ) + + client: FilesystemClient = pipeline.destination_client() # type: ignore[assignment] + + expected_results = { + "events_timezone_off": None, + "events_timezone_on": "UTC", + "events_timezone_unset": "UTC", + } + + for t in expected_results.keys(): + events_glob = posixpath.join(client.dataset_path, f"{t}/*") + events_files = client.fs_client.glob(events_glob) + + with open(events_files[0], "rb") as f: + table = pq.read_table(f) + + # convert the timestamps to strings + timestamps = [ + ts.as_py().strftime("%Y-%m-%dT%H:%M:%S.%f") for ts in table.column("event_tstamp") + ] + assert timestamps == [ + 
"2024-07-30T10:00:00.123000", + "2024-07-30T08:00:00.123456", + "2024-07-30T10:00:00.123456", + ] + + # check if the Parquet file contains timezone information + schema = table.schema + field = schema.field("event_tstamp") + assert field.type.tz == expected_results[t] From 1723faa92717090f2c0d28f471d3772f647130e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Willi=20M=C3=BCller?= Date: Fri, 30 Aug 2024 23:01:43 +0530 Subject: [PATCH 2/3] Fix/1571 Incremental: Optionally raise, load, or ignore raise records with cursor_path missing or None value (#1576) * allows specification of what happens on cursor_path missing or cursor_path having the value None: raise differentiated exceptions, exclude row, or include row. * Documents handling None values at the incremental cursor * fixes incremental extract crashing if one record has cursor_path = None * test that add_map can be used to transform items before the incremental function is called * Unifies treating of None values for python Objects (including pydantic), pandas, and arrow --------- Co-authored-by: Marcin Rudolf --- dlt/extract/incremental/__init__.py | 18 +- dlt/extract/incremental/exceptions.py | 19 +- dlt/extract/incremental/transform.py | 89 +++- dlt/extract/incremental/typing.py | 3 +- .../docs/general-usage/incremental-loading.md | 69 ++- tests/extract/test_incremental.py | 460 +++++++++++++++++- tests/pipeline/test_pipeline_extra.py | 1 - 7 files changed, 624 insertions(+), 35 deletions(-) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index c1117370b5..343a737c07 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -35,7 +35,12 @@ IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, ) -from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc +from dlt.extract.incremental.typing import ( + IncrementalColumnState, + TCursorValue, + LastValueFunc, + OnCursorValueMissing, +) from dlt.extract.pipe import Pipe from dlt.extract.items import SupportsPipe, TTableHintTemplate, ItemTransform from dlt.extract.incremental.transform import ( @@ -81,7 +86,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa >>> info = p.run(r, destination="duckdb") Args: - cursor_path: The name or a JSON path to an cursor field. Uses the same names of fields as in your JSON document, before they are normalized to store in the database. + cursor_path: The name or a JSON path to a cursor field. Uses the same names of fields as in your JSON document, before they are normalized to store in the database. initial_value: Optional value used for `last_value` when no state is available, e.g. on the first run of the pipeline. If not provided `last_value` will be `None` on the first run. last_value_func: Callable used to determine which cursor value to save in state. It is called with a list of the stored state value and all cursor vals from currently processing items. Default is `max` primary_key: Optional primary key used to deduplicate data. If not provided, a primary key defined by the resource will be used. Pass a tuple to define a compound key. Pass empty tuple to disable unique checks @@ -95,6 +100,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa specified range of data. Currently Airflow scheduler is detected: "data_interval_start" and "data_interval_end" are taken from the context and passed Incremental class. The values passed explicitly to Incremental will be ignored. 
Note that if logical "end date" is present then also "end_value" will be set which means that resource state is not used and exactly this range of date will be loaded + on_cursor_value_missing: Specify what happens when the cursor_path does not exist in a record or a record has `None` at the cursor_path: raise, include, exclude """ # this is config/dataclass so declare members @@ -104,6 +110,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa end_value: Optional[Any] = None row_order: Optional[TSortOrder] = None allow_external_schedulers: bool = False + on_cursor_value_missing: OnCursorValueMissing = "raise" # incremental acting as empty EMPTY: ClassVar["Incremental[Any]"] = None @@ -118,6 +125,7 @@ def __init__( end_value: Optional[TCursorValue] = None, row_order: Optional[TSortOrder] = None, allow_external_schedulers: bool = False, + on_cursor_value_missing: OnCursorValueMissing = "raise", ) -> None: # make sure that path is valid if cursor_path: @@ -133,6 +141,11 @@ def __init__( self._primary_key: Optional[TTableHintTemplate[TColumnNames]] = primary_key self.row_order = row_order self.allow_external_schedulers = allow_external_schedulers + if on_cursor_value_missing not in ["raise", "include", "exclude"]: + raise ValueError( + f"Unexpected argument for on_cursor_value_missing. Got {on_cursor_value_missing}" + ) + self.on_cursor_value_missing = on_cursor_value_missing self._cached_state: IncrementalColumnState = None """State dictionary cached on first access""" @@ -171,6 +184,7 @@ def _make_transforms(self) -> None: self.last_value_func, self._primary_key, set(self._cached_state["unique_hashes"]), + self.on_cursor_value_missing, ) @classmethod diff --git a/dlt/extract/incremental/exceptions.py b/dlt/extract/incremental/exceptions.py index a5f94c2974..973d3b6585 100644 --- a/dlt/extract/incremental/exceptions.py +++ b/dlt/extract/incremental/exceptions.py @@ -5,12 +5,27 @@ class IncrementalCursorPathMissing(PipeException): - def __init__(self, pipe_name: str, json_path: str, item: TDataItem, msg: str = None) -> None: + def __init__( + self, pipe_name: str, json_path: str, item: TDataItem = None, msg: str = None + ) -> None: + self.json_path = json_path + self.item = item + msg = ( + msg + or f"Cursor element with JSON path `{json_path}` was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document because they can be different from the names you see in database." + ) + super().__init__(pipe_name, msg) + + +class IncrementalCursorPathHasValueNone(PipeException): + def __init__( + self, pipe_name: str, json_path: str, item: TDataItem = None, msg: str = None + ) -> None: self.json_path = json_path self.item = item msg = ( msg - or f"Cursor element with JSON path {json_path} was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document - if those are different from the names you see in database." + or f"Cursor element with JSON path `{json_path}` has the value `None` in extracted data item. All data items must contain a value != None. 
Construct the incremental with on_cursor_value_missing='include' if you want to include such rows"
         )
         super().__init__(pipe_name, msg)
diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py
index 0ac9fdf520..eb448d4266 100644
--- a/dlt/extract/incremental/transform.py
+++ b/dlt/extract/incremental/transform.py
@@ -1,5 +1,5 @@
-from datetime import datetime, date  # noqa: I251
-from typing import Any, Optional, Set, Tuple, List
+from datetime import datetime  # noqa: I251
+from typing import Any, Optional, Set, Tuple, List, Type
 
 from dlt.common.exceptions import MissingDependencyException
 from dlt.common.utils import digest128
@@ -11,8 +11,9 @@
     IncrementalCursorInvalidCoercion,
     IncrementalCursorPathMissing,
     IncrementalPrimaryKeyMissing,
+    IncrementalCursorPathHasValueNone,
 )
-from dlt.extract.incremental.typing import TCursorValue, LastValueFunc
+from dlt.extract.incremental.typing import TCursorValue, LastValueFunc, OnCursorValueMissing
 from dlt.extract.utils import resolve_column_value
 from dlt.extract.items import TTableHintTemplate
 from dlt.common.schema.typing import TColumnNames
@@ -55,6 +56,7 @@ def __init__(
         last_value_func: LastValueFunc[TCursorValue],
         primary_key: Optional[TTableHintTemplate[TColumnNames]],
         unique_hashes: Set[str],
+        on_cursor_value_missing: OnCursorValueMissing = "raise",
     ) -> None:
         self.resource_name = resource_name
         self.cursor_path = cursor_path
@@ -67,6 +69,7 @@ def __init__(
         self.primary_key = primary_key
         self.unique_hashes = unique_hashes
         self.start_unique_hashes = set(unique_hashes)
+        self.on_cursor_value_missing = on_cursor_value_missing
 
         # compile jsonpath
         self._compiled_cursor_path = compile_path(cursor_path)
@@ -116,21 +119,39 @@ class JsonIncremental(IncrementalTransform):
     def find_cursor_value(self, row: TDataItem) -> Any:
         """Finds value in row at cursor defined by self.cursor_path.
 
-        Will use compiled JSONPath if present, otherwise it reverts to column search if row is dict
+        Will use compiled JSONPath if present.
+        Otherwise, reverts to field access if row is dict, Pydantic model, or of other class.
         """
-        row_value: Any = None
+        key_exc: Type[Exception] = IncrementalCursorPathHasValueNone
         if self._compiled_cursor_path:
-            row_values = find_values(self._compiled_cursor_path, row)
-            if row_values:
-                row_value = row_values[0]
+            # ignores the other found values, e.g. 
when the path is $data.items[*].created_at + try: + row_value = find_values(self._compiled_cursor_path, row)[0] + except IndexError: + # empty list so raise a proper exception + row_value = None + key_exc = IncrementalCursorPathMissing else: try: - row_value = row[self.cursor_path] - except Exception: - pass - if row_value is None: - raise IncrementalCursorPathMissing(self.resource_name, self.cursor_path, row) - return row_value + try: + row_value = row[self.cursor_path] + except TypeError: + # supports Pydantic models and other classes + row_value = getattr(row, self.cursor_path) + except (KeyError, AttributeError): + # attr not found so raise a proper exception + row_value = None + key_exc = IncrementalCursorPathMissing + + # if we have a value - return it + if row_value is not None: + return row_value + + if self.on_cursor_value_missing == "raise": + # raise missing path or None value exception + raise key_exc(self.resource_name, self.cursor_path, row) + elif self.on_cursor_value_missing == "exclude": + return None def __call__( self, @@ -144,6 +165,12 @@ def __call__( return row, False, False row_value = self.find_cursor_value(row) + if row_value is None: + if self.on_cursor_value_missing == "exclude": + return None, False, False + else: + return row, False, False + last_value = self.last_value last_value_func = self.last_value_func @@ -299,6 +326,7 @@ def __call__( # TODO: Json path support. For now assume the cursor_path is a column name cursor_path = self.cursor_path + # The new max/min value try: # NOTE: datetimes are always pendulum in UTC @@ -310,11 +338,16 @@ def __call__( self.resource_name, cursor_path, tbl, - f"Column name {cursor_path} was not found in the arrow table. Not nested JSON paths" + f"Column name `{cursor_path}` was not found in the arrow table. 
Nested JSON paths" " are not supported for arrow tables and dataframes, the incremental cursor_path" " must be a column name.", ) from e + if tbl.schema.field(cursor_path).nullable: + tbl_without_null, tbl_with_null = self._process_null_at_cursor_path(tbl) + + tbl = tbl_without_null + # If end_value is provided, filter to include table rows that are "less" than end_value if self.end_value is not None: try: @@ -396,12 +429,28 @@ def __call__( ) ) + # drop the temp unique index before concat and returning + if "_dlt_index" in tbl.schema.names: + tbl = pyarrow.remove_columns(tbl, ["_dlt_index"]) + + if self.on_cursor_value_missing == "include": + if isinstance(tbl, pa.RecordBatch): + assert isinstance(tbl_with_null, pa.RecordBatch) + tbl = pa.Table.from_batches([tbl, tbl_with_null]) + else: + tbl = pa.concat_tables([tbl, tbl_with_null]) + if len(tbl) == 0: return None, start_out_of_range, end_out_of_range - try: - tbl = pyarrow.remove_columns(tbl, ["_dlt_index"]) - except KeyError: - pass if is_pandas: - return tbl.to_pandas(), start_out_of_range, end_out_of_range + tbl = tbl.to_pandas() return tbl, start_out_of_range, end_out_of_range + + def _process_null_at_cursor_path(self, tbl: "pa.Table") -> Tuple["pa.Table", "pa.Table"]: + mask = pa.compute.is_valid(tbl[self.cursor_path]) + rows_without_null = tbl.filter(mask) + rows_with_null = tbl.filter(pa.compute.invert(mask)) + if self.on_cursor_value_missing == "raise": + if rows_with_null.num_rows > 0: + raise IncrementalCursorPathHasValueNone(self.resource_name, self.cursor_path) + return rows_without_null, rows_with_null diff --git a/dlt/extract/incremental/typing.py b/dlt/extract/incremental/typing.py index 9cec97d34d..a5e2612db4 100644 --- a/dlt/extract/incremental/typing.py +++ b/dlt/extract/incremental/typing.py @@ -1,8 +1,9 @@ -from typing import TypedDict, Optional, Any, List, TypeVar, Callable, Sequence +from typing import TypedDict, Optional, Any, List, Literal, TypeVar, Callable, Sequence TCursorValue = TypeVar("TCursorValue", bound=Any) LastValueFunc = Callable[[Sequence[TCursorValue]], Any] +OnCursorValueMissing = Literal["raise", "include", "exclude"] class IncrementalColumnState(TypedDict): diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index 68fc46e6dc..5ff587f20e 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -689,7 +689,7 @@ than `end_value`. :::caution In rare cases when you use Incremental with a transformer, `dlt` will not be able to automatically close -generator associated with a row that is out of range. You can still use still call `can_close()` method on +generator associated with a row that is out of range. You can still call the `can_close()` method on incremental and exit yield loop when true. ::: @@ -907,22 +907,75 @@ Consider the example below for reading incremental loading parameters from "conf ``` `id_after` incrementally stores the latest `cursor_path` value for future pipeline runs. -### Loading NULL values in the incremental cursor field +### Loading when incremental cursor path is missing or value is None/NULL -When loading incrementally with a cursor field, each row is expected to contain a value at the cursor field that is not `None`. -For example, the following source data will raise an error: +You can customize the incremental processing of dlt by setting the parameter `on_cursor_value_missing`. 
+ +When loading incrementally with the default settings, there are two assumptions: +1. each row contains the cursor path +2. each row is expected to contain a value at the cursor path that is not `None`. + +For example, the two following source data will raise an error: ```py @dlt.resource -def some_data(updated_at=dlt.sources.incremental("updated_at")): +def some_data_without_cursor_path(updated_at=dlt.sources.incremental("updated_at")): yield [ {"id": 1, "created_at": 1, "updated_at": 1}, - {"id": 2, "created_at": 2, "updated_at": 2}, + {"id": 2, "created_at": 2}, # cursor field is missing + ] + +list(some_data_without_cursor_path()) + +@dlt.resource +def some_data_without_cursor_value(updated_at=dlt.sources.incremental("updated_at")): + yield [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 3, "created_at": 4, "updated_at": None}, # value at cursor field is None + ] + +list(some_data_without_cursor_value()) +``` + + +To process a data set where some records do not include the incremental cursor path or where the values at the cursor path are `None,` there are the following four options: + +1. Configure the incremental load to raise an exception in case there is a row where the cursor path is missing or has the value `None` using `incremental(..., on_cursor_value_missing="raise")`. This is the default behavior. +2. Configure the incremental load to tolerate the missing cursor path and `None` values using `incremental(..., on_cursor_value_missing="include")`. +3. Configure the incremental load to exclude the missing cursor path and `None` values using `incremental(..., on_cursor_value_missing="exclude")`. +4. Before the incremental processing begins: Ensure that the incremental field is present and transform the values at the incremental cursor to a value different from `None`. [See docs below](#transform-records-before-incremental-processing) + +Here is an example of including rows where the incremental cursor value is missing or `None`: +```py +@dlt.resource +def some_data(updated_at=dlt.sources.incremental("updated_at", on_cursor_value_missing="include")): + yield [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 2, "created_at": 2}, + {"id": 3, "created_at": 4, "updated_at": None}, + ] + +result = list(some_data()) +assert len(result) == 3 +assert result[1] == {"id": 2, "created_at": 2} +assert result[2] == {"id": 3, "created_at": 4, "updated_at": None} +``` + +If you do not want to import records without the cursor path or where the value at the cursor path is `None` use the following incremental configuration: + +```py +@dlt.resource +def some_data(updated_at=dlt.sources.incremental("updated_at", on_cursor_value_missing="exclude")): + yield [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 2, "created_at": 2}, {"id": 3, "created_at": 4, "updated_at": None}, ] -list(some_data()) +result = list(some_data()) +assert len(result) == 1 ``` +### Transform records before incremental processing If you want to load data that includes `None` values you can transform the records before the incremental processing. You can add steps to the pipeline that [filter, transform, or pivot your data](../general-usage/resource.md#filter-transform-and-pivot-data). @@ -1162,4 +1215,4 @@ sources: } ``` -Verify that the `last_value` is updated between pipeline runs. \ No newline at end of file +Verify that the `last_value` is updated between pipeline runs. 
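Option 4 above (transforming the records before the incremental step) is only linked from the docs hunk; here is a minimal sketch of that approach, assuming the same `add_map(..., insert_at=1)` pattern the tests below exercise, with the fallback to `created_at` as an illustrative choice rather than part of the patch:

```py
import dlt
import pendulum


@dlt.resource
def some_data(updated_at=dlt.sources.incremental("updated_at")):
    yield {"id": 1, "created_at": 1, "updated_at": 1}
    yield {"id": 2, "created_at": 4, "updated_at": None}  # would otherwise raise


def set_default_updated_at(record):
    # fill a missing or None cursor value before the incremental step sees the row
    if record.get("updated_at") is None:
        record["updated_at"] = record.get("created_at", pendulum.now().int_timestamp)
    return record


# insert_at=1 places the map step ahead of the incremental step in the resource pipe
result = list(some_data().add_map(set_default_updated_at, insert_at=1))
assert result[1]["updated_at"] == 4
```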
diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index c401552fb2..a9867aa54b 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -33,6 +33,7 @@ IncrementalCursorInvalidCoercion, IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, + IncrementalCursorPathHasValueNone, ) from dlt.pipeline.exceptions import PipelineStepFailed @@ -44,6 +45,10 @@ ALL_TEST_DATA_ITEM_FORMATS, ) +from tests.pipeline.utils import assert_query_data + +import pyarrow as pa + @pytest.fixture(autouse=True) def switch_to_fifo(): @@ -167,8 +172,9 @@ def some_data(created_at=dlt.sources.incremental("created_at")): p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - p.extract(some_data()) + assert values == [None] + p.extract(some_data()) assert values == [None, 5] @@ -635,6 +641,458 @@ def some_data(last_timestamp=dlt.sources.incremental("item.timestamp")): assert pip_ex.value.__context__.json_path == "item.timestamp" +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_and_updates_incremental_cursor_1( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": None}, + {"id": 2, "created_at": 1}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(id) from some_data", [3]) + assert_query_data(p, "select count(created_at) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_does_not_include_overlapping_records( + item_type: TestDataItemFormat, +) -> None: + @dlt.resource + def some_data( + invocation: int, + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include"), + ): + if invocation == 1: + yield data_to_item_format( + item_type, + [ + {"id": 1, "created_at": None}, + {"id": 2, "created_at": 1}, + {"id": 3, "created_at": 2}, + ], + ) + elif invocation == 2: + yield data_to_item_format( + item_type, + [ + {"id": 4, "created_at": 1}, + {"id": 5, "created_at": None}, + {"id": 6, "created_at": 3}, + ], + ) + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(1), destination="duckdb") + p.run(some_data(2), destination="duckdb") + + assert_query_data(p, "select id from some_data order by id", [1, 2, 3, 5, 6]) + assert_query_data( + p, "select created_at from some_data order by created_at", [1, 2, 3, None, None] + ) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 3 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_and_updates_incremental_cursor_2( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": None}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + 
p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(id) from some_data", [3]) + assert_query_data(p, "select count(created_at) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_and_updates_incremental_cursor_3( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": 2}, + {"id": 3, "created_at": None}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + assert_query_data(p, "select count(id) from some_data", [3]) + assert_query_data(p, "select count(created_at) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_includes_records_without_cursor_path( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="include") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + assert_query_data(p, "select count(id) from some_data", [2]) + assert_query_data(p, "select count(created_at) from some_data", [1]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 1 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_excludes_records_and_updates_incremental_cursor( + item_type: TestDataItemFormat, +) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": 2}, + {"id": 3, "created_at": None}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="exclude") + ): + yield source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + assert_query_data(p, "select count(id) from some_data", [2]) + + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "created_at" + ] + assert s["last_value"] == 2 + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_can_raise_on_none_1(item_type: TestDataItemFormat) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2, "created_at": None}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="raise") + ): + yield source_items + + with pytest.raises(IncrementalCursorPathHasValueNone) as py_ex: + list(some_data()) + assert py_ex.value.json_path == "created_at" + + # same thing when run in pipeline + with pytest.raises(PipelineStepFailed) as pip_ex: + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data()) + + assert 
isinstance(pip_ex.value.__context__, IncrementalCursorPathHasValueNone) + assert pip_ex.value.__context__.json_path == "created_at" + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_path_none_can_raise_on_none_2(item_type: TestDataItemFormat) -> None: + data = [ + {"id": 1, "created_at": 1}, + {"id": 2}, + {"id": 3, "created_at": 2}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="raise") + ): + yield source_items + + # there is no fixed, error because cursor path is missing + if item_type == "object": + with pytest.raises(IncrementalCursorPathMissing) as ex: + list(some_data()) + assert ex.value.json_path == "created_at" + # there is a fixed schema, error because value is null + else: + with pytest.raises(IncrementalCursorPathHasValueNone) as e: + list(some_data()) + assert e.value.json_path == "created_at" + + # same thing when run in pipeline + with pytest.raises(PipelineStepFailed) as e: # type: ignore[assignment] + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data()) + if item_type == "object": + assert isinstance(e.value.__context__, IncrementalCursorPathMissing) + else: + assert isinstance(e.value.__context__, IncrementalCursorPathHasValueNone) + assert e.value.__context__.json_path == "created_at" # type: ignore[attr-defined] + + +@pytest.mark.parametrize("item_type", ["arrow-table", "arrow-batch", "pandas"]) +def test_cursor_path_none_can_raise_on_column_missing(item_type: TestDataItemFormat) -> None: + data = [ + {"id": 1}, + {"id": 2}, + {"id": 3}, + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data( + created_at=dlt.sources.incremental("created_at", on_cursor_value_missing="raise") + ): + yield source_items + + with pytest.raises(IncrementalCursorPathMissing) as py_ex: + list(some_data()) + assert py_ex.value.json_path == "created_at" + + # same thing when run in pipeline + with pytest.raises(PipelineStepFailed) as pip_ex: + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data()) + assert pip_ex.value.__context__.json_path == "created_at" # type: ignore[attr-defined] + assert isinstance(pip_ex.value.__context__, IncrementalCursorPathMissing) + + +def test_cursor_path_none_nested_can_raise_on_none_1() -> None: + # No nested json path support for pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[0].created_at", on_cursor_value_missing="raise" + ) + ): + yield {"data": {"items": [{"created_at": None}, {"created_at": 1}]}} + + with pytest.raises(IncrementalCursorPathHasValueNone) as e: + list(some_data()) + assert e.value.json_path == "data.items[0].created_at" + + +def test_cursor_path_none_nested_can_raise_on_none_2() -> None: + # No pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[*].created_at", on_cursor_value_missing="raise" + ) + ): + yield {"data": {"items": [{"created_at": None}, {"created_at": 1}]}} + + with pytest.raises(IncrementalCursorPathHasValueNone) as e: + list(some_data()) + assert e.value.json_path == "data.items[*].created_at" + + +def test_cursor_path_none_nested_can_include_on_none_1() -> None: + # No nested json path support for pandas and arrow. 
See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[*].created_at", on_cursor_value_missing="include" + ) + ): + yield { + "data": { + "items": [ + {"created_at": None}, + {"created_at": 1}, + ] + } + } + + results = list(some_data()) + assert results[0]["data"]["items"] == [ + {"created_at": None}, + {"created_at": 1}, + ] + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(*) from some_data__data__items", [2]) + + +def test_cursor_path_none_nested_can_include_on_none_2() -> None: + # No nested json path support for pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[0].created_at", on_cursor_value_missing="include" + ) + ): + yield { + "data": { + "items": [ + {"created_at": None}, + {"created_at": 1}, + ] + } + } + + results = list(some_data()) + assert results[0]["data"]["items"] == [ + {"created_at": None}, + {"created_at": 1}, + ] + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(*) from some_data__data__items", [2]) + + +def test_cursor_path_none_nested_includes_rows_without_cursor_path() -> None: + # No nested json path support for pandas and arrow. See test_nested_cursor_path_arrow_fails + @dlt.resource + def some_data( + created_at=dlt.sources.incremental( + "data.items[*].created_at", on_cursor_value_missing="include" + ) + ): + yield { + "data": { + "items": [ + {"id": 1}, + {"id": 2, "created_at": 2}, + ] + } + } + + results = list(some_data()) + assert results[0]["data"]["items"] == [ + {"id": 1}, + {"id": 2, "created_at": 2}, + ] + + p = dlt.pipeline(pipeline_name=uniq_id()) + p.run(some_data(), destination="duckdb") + + assert_query_data(p, "select count(*) from some_data__data__items", [2]) + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_set_default_value_for_incremental_cursor(item_type: TestDataItemFormat) -> None: + @dlt.resource + def some_data(created_at=dlt.sources.incremental("updated_at")): + yield data_to_item_format( + item_type, + [ + {"id": 1, "created_at": 1, "updated_at": 1}, + {"id": 2, "created_at": 4, "updated_at": None}, + {"id": 3, "created_at": 3, "updated_at": 3}, + ], + ) + + def set_default_updated_at(record): + if record.get("updated_at") is None: + record["updated_at"] = record.get("created_at", pendulum.now().int_timestamp) + return record + + def set_default_updated_at_pandas(df): + df["updated_at"] = df["updated_at"].fillna(df["created_at"]) + return df + + def set_default_updated_at_arrow(records): + updated_at_is_null = pa.compute.is_null(records.column("updated_at")) + updated_at_filled = pa.compute.if_else( + updated_at_is_null, records.column("created_at"), records.column("updated_at") + ) + if item_type == "arrow-table": + records = records.set_column( + records.schema.get_field_index("updated_at"), + pa.field("updated_at", records.column("updated_at").type), + updated_at_filled, + ) + elif item_type == "arrow-batch": + columns = [records.column(i) for i in range(records.num_columns)] + columns[2] = updated_at_filled + records = pa.RecordBatch.from_arrays(columns, schema=records.schema) + return records + + if item_type == "object": + func = set_default_updated_at + elif item_type == "pandas": + func = set_default_updated_at_pandas + elif item_type in ["arrow-table", "arrow-batch"]: + func = 
set_default_updated_at_arrow + + result = list(some_data().add_map(func, insert_at=1)) + values = data_item_to_list(item_type, result) + assert data_item_length(values) == 3 + assert values[1]["updated_at"] == 4 + + # same for pipeline run + p = dlt.pipeline(pipeline_name=uniq_id()) + p.extract(some_data().add_map(func, insert_at=1)) + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ + "updated_at" + ] + assert s["last_value"] == 4 + + def test_json_path_cursor() -> None: @dlt.resource def some_data(last_timestamp=dlt.sources.incremental("item.timestamp|modifiedAt")): diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index c757959bec..af3a6c239e 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -246,7 +246,6 @@ class TestRow(BaseModel): example_string: str # yield model in resource so incremental fails when looking for "id" - # TODO: support pydantic models in incremental @dlt.resource(name="table_name", primary_key="id", write_disposition="replace") def generate_rows_incremental( From 36c0d140ba7c94807b498a0d3f77182e6c24354e Mon Sep 17 00:00:00 2001 From: novica Date: Mon, 2 Sep 2024 09:27:10 +0200 Subject: [PATCH 3/3] fix installation command" (#1741) --- docs/examples/postgres_to_postgres/postgres_to_postgres.py | 2 +- docs/website/blog/2024-01-10-dlt-mode.md | 2 +- docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/examples/postgres_to_postgres/postgres_to_postgres.py b/docs/examples/postgres_to_postgres/postgres_to_postgres.py index c6502f236a..3e88cb7ee8 100644 --- a/docs/examples/postgres_to_postgres/postgres_to_postgres.py +++ b/docs/examples/postgres_to_postgres/postgres_to_postgres.py @@ -33,7 +33,7 @@ Install `dlt` with `duckdb` as extra, also `connectorx`, Postgres adapter and progress bar tool: ```sh -pip install dlt[duckdb] connectorx pyarrow psycopg2-binary alive-progress +pip install "dlt[duckdb]" connectorx pyarrow psycopg2-binary alive-progress ``` Run the example: diff --git a/docs/website/blog/2024-01-10-dlt-mode.md b/docs/website/blog/2024-01-10-dlt-mode.md index 1d6bf8ca0e..232124df45 100644 --- a/docs/website/blog/2024-01-10-dlt-mode.md +++ b/docs/website/blog/2024-01-10-dlt-mode.md @@ -124,7 +124,7 @@ With the model we just created, called Products, a chart can be instantly create In this demo, we’ll forego the authentication issues of connecting to a data warehouse, and choose the DuckDB destination to show how the Python environment within Mode can be used to initialize a data pipeline and dump normalized data into a destination. In order to see how it works, we first install dlt[duckdb] into the Python environment. ```sh -!pip install dlt[duckdb] +!pip install "dlt[duckdb]" ``` Next, we initialize the dlt pipeline: diff --git a/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md b/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md index 0e342a3fea..41ba5926c4 100644 --- a/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md +++ b/docs/website/docs/walkthroughs/dispatch-to-multiple-tables.md @@ -12,7 +12,7 @@ We'll use the [GitHub API](https://docs.github.com/en/rest) to fetch the events 1. Install dlt with duckdb support: ```sh -pip install dlt[duckdb] +pip install "dlt[duckdb]" ``` 2. Create a new a new file `github_events_dispatch.py` and paste the following code:
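
The quoting fix above matters because some shells treat the unquoted brackets in `dlt[duckdb]` as a glob pattern; zsh in particular aborts when the pattern matches nothing, so quoting keeps the command portable:

```sh
# unquoted: zsh attempts glob expansion on [duckdb] and may fail with "no matches found"
pip install dlt[duckdb]

# quoted: the extras specifier is passed to pip verbatim in any shell
pip install "dlt[duckdb]"
```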