Skip to content

Commit

Permalink
Escape partitioning column names
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed May 24, 2024
1 parent ef25591 commit 55ec9d9
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 25 deletions.
14 changes: 11 additions & 3 deletions dlt/destinations/impl/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Callable,
Iterable,
Type,
cast,
)
from copy import deepcopy
import re
Expand Down Expand Up @@ -404,10 +405,15 @@ def _from_db_type(
def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"

def _iceberg_partition_clause(self, partition_hints: Optional[List[str]]) -> str:
def _iceberg_partition_clause(self, partition_hints: Optional[Dict[str, str]]) -> str:
if not partition_hints:
return ""
return f"PARTITIONED BY ({', '.join(partition_hints)})"
formatted_strings = []
for column_name, template in partition_hints.items():
formatted_strings.append(
template.format(column_name=self.sql_client.escape_ddl_identifier(column_name))
)
return f"PARTITIONED BY ({', '.join(formatted_strings)})"

def _get_table_update_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
Expand Down Expand Up @@ -435,7 +441,9 @@ def _get_table_update_sql(
sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""")
else:
if is_iceberg:
partition_clause = self._iceberg_partition_clause(table.get(PARTITION_HINT))
partition_clause = self._iceberg_partition_clause(
cast(Optional[Dict[str, str]], table.get(PARTITION_HINT))
)
sql.append(
f"""CREATE TABLE {qualified_table_name}
({columns})
Expand Down
65 changes: 46 additions & 19 deletions dlt/destinations/impl/athena/athena_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,60 @@
PARTITION_HINT: Final[str] = "x-athena-partition"


class PartitionTransformation:
    """An iceberg partition transformation tied to a single column.

    Attributes:
        template: Template string of the transformation including the column
            name placeholder, e.g. `bucket(16, {column_name})`.
        column_name: Name of the column the transformation is applied to.
    """

    template: str
    column_name: str

    def __init__(self, template: str, column_name: str) -> None:
        # The column name is stored separately from the template and is not
        # substituted here; the `{column_name}` placeholder stays intact.
        self.template, self.column_name = template, column_name


class athena_partition:
"""Helper class to generate iceberg partition transform strings.
"""Helper class to generate iceberg partition transformations
E.g. `athena_partition.bucket(16, "id")` will return `bucket(16, "id")`.
E.g. `athena_partition.bucket(16, "id")` will return a transformation with template `bucket(16, {column_name})`
This can be correctly rendered by the athena loader with escaped column name.
"""

@staticmethod
def year(column_name: str) -> str:
def year(column_name: str) -> PartitionTransformation:
"""Partition by year part of a date or timestamp column."""
return f"year({column_name})"
return PartitionTransformation("year({column_name})", column_name)

@staticmethod
def month(column_name: str) -> str:
def month(column_name: str) -> PartitionTransformation:
"""Partition by month part of a date or timestamp column."""
return f"month({column_name})"
return PartitionTransformation("month({column_name})", column_name)

@staticmethod
def day(column_name: str) -> str:
def day(column_name: str) -> PartitionTransformation:
"""Partition by day part of a date or timestamp column."""
return f"day({column_name})"
return PartitionTransformation("day({column_name})", column_name)

@staticmethod
def hour(column_name: str) -> str:
def hour(column_name: str) -> PartitionTransformation:
"""Partition by hour part of a date or timestamp column."""
return f"hour({column_name})"
return PartitionTransformation("hour({column_name})", column_name)

@staticmethod
def bucket(n: int, column_name: str) -> str:
def bucket(n: int, column_name: str) -> PartitionTransformation:
"""Partition by hashed value to n buckets."""
return f"bucket({n}, {column_name})"
return PartitionTransformation(f"bucket({n}, {{column_name}})", column_name)

@staticmethod
def truncate(length: int, column_name: str) -> str:
def truncate(length: int, column_name: str) -> PartitionTransformation:
"""Partition by value truncated to length."""
return f"truncate({length}, {column_name})"
return PartitionTransformation(f"truncate({length}, {{column_name}})", column_name)


def athena_adapter(
data: Any,
partition: Union[str, Sequence[str]] = None,
partition: Union[
str, PartitionTransformation, Sequence[Union[str, PartitionTransformation]]
] = None,
) -> DltResource:
"""
Prepares data for loading into Athena
Expand All @@ -60,7 +74,9 @@ def athena_adapter(
data: The data to be transformed.
This can be raw data or an instance of DltResource.
If raw data is provided, the function will wrap it into a `DltResource` object.
partition: Column name(s) partition transform string(s) to partition table by
partition: Column name(s) or instances of `PartitionTransformation` to partition the table by.
To use a transformation it's best to use the methods of the helper class `athena_partition`
to generate correctly escaped SQL in the loader.
Returns:
A `DltResource` object that is ready to be loaded into Athena.
Expand All @@ -77,11 +93,22 @@ def athena_adapter(
additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {}

if partition:
if isinstance(partition, str):
if isinstance(partition, str) or not isinstance(partition, Sequence):
partition = [partition]

# Note: PARTITIONED BY clause identifiers are not allowed to be quoted. They are added as-is.
additional_table_hints[PARTITION_HINT] = list(partition)
# Partition hint is `{column_name: template}`, e.g. `{"department": "{column_name}", "date_hired": "year({column_name})"}`
# Use one dict for all hints instead of storing on column so order is preserved
partition_hint: Dict[str, str] = {}

for item in partition:
if isinstance(item, PartitionTransformation):
# Client will generate the final SQL string with escaped column name injected
partition_hint[item.column_name] = item.template
else:
# Item is the column name
partition_hint[item] = "{column_name}"

additional_table_hints[PARTITION_HINT] = partition_hint

if additional_table_hints:
resource.apply_hints(additional_table_hints=additional_table_hints)
Expand Down
4 changes: 1 addition & 3 deletions tests/load/athena_iceberg/test_athena_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def not_partitioned_table():
)[0]

# Partition clause is generated with original order
expected_clause = (
"PARTITIONED BY (category, month(created_at), bucket(10, product_id), truncate(2, name))"
)
expected_clause = "PARTITIONED BY (`category`, month(`created_at`), bucket(10, `product_id`), truncate(2, `name`))"
assert expected_clause in sql_partitioned

# No partition clause otherwise
Expand Down

0 comments on commit 55ec9d9

Please sign in to comment.