refactors column and constraint sql in job client
rudolfix committed Dec 16, 2024
1 parent a6fcb85 commit f2bac96
Showing 7 changed files with 29 additions and 60 deletions.
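
In short: the Databricks, Dremio, DuckDB, Postgres and Redshift clients drop their per-destination _get_column_def_sql overrides in favor of a single implementation on the base job client in dlt/destinations/job_client_impl.py, which composes the escaped column name, mapped destination type, hint attributes and NOT NULL clause; MSSQL keeps its override but delegates hint rendering to the new _get_column_hints_sql helper. A _get_constraints_sql hook is also added to _get_table_update_sql. Destination clients now only configure self.active_hints and self.type_mapper. A minimal sketch of that pattern, assuming the base class is SqlJobClientBase and using an invented ExampleClient and hint value (neither appears in this commit):

# Sketch only: ExampleClient and the hint value are hypothetical.
# Real clients (e.g. the Redshift client below) set these in __init__
# and inherit _get_column_def_sql from the base class.
from typing import Dict

from dlt.common.schema import TColumnHint
from dlt.destinations.job_client_impl import SqlJobClientBase


class ExampleClient(SqlJobClientBase):
    def __init__(self, schema, config, sql_client) -> None:
        super().__init__(schema, config, sql_client)
        # map boolean column hints to destination-specific attributes (value is made up)
        self.active_hints: Dict[TColumnHint, str] = {"unique": "UNIQUE"}
        self.type_mapper = self.capabilities.get_type_mapper()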
6 changes: 0 additions & 6 deletions dlt/destinations/impl/databricks/databricks.py
@@ -264,12 +264,6 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def _get_storage_table_query_columns(self) -> List[str]:
fields = super()._get_storage_table_query_columns()
fields[2] = ( # Override because this is the only way to get data type with precision
6 changes: 0 additions & 6 deletions dlt/destinations/impl/dremio/dremio.py
@@ -151,12 +151,6 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def _create_merge_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
) -> List[FollowupJobRequest]:
11 changes: 0 additions & 11 deletions dlt/destinations/impl/duckdb/duck.py
@@ -74,17 +74,6 @@ def create_load_job(
job = DuckDbCopyJob(file_path)
return job

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
)

def _from_db_type(
self, pq_t: str, precision: Optional[int], scale: Optional[int]
) -> TColumnType:
6 changes: 1 addition & 5 deletions dlt/destinations/impl/mssql/mssql.py
@@ -115,11 +115,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = Non
else:
db_type = self.type_mapper.to_destination_type(c, table)

hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
hints_str = self._get_column_hints_sql(c)
column_name = self.sql_client.escape_column_name(c["name"])
return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}"

12 changes: 0 additions & 12 deletions dlt/destinations/impl/postgres/postgres.py
@@ -161,18 +161,6 @@ def create_load_job(
job = PostgresCsvCopyJob(file_path)
return job

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_ = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
nullability = self._gen_not_null(c.get("nullable", True))
column_type = self.type_mapper.to_destination_type(c, table)

return f"{column_name} {column_type} {hints_} {nullability}"

def _create_replace_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
) -> List[FollowupJobRequest]:
12 changes: 1 addition & 11 deletions dlt/destinations/impl/redshift/redshift.py
@@ -153,6 +153,7 @@ def __init__(
capabilities,
)
super().__init__(schema, config, sql_client)
self.active_hints = HINT_TO_REDSHIFT_ATTR
self.sql_client = sql_client
self.config: RedshiftClientConfiguration = config
self.type_mapper = self.capabilities.get_type_mapper()
@@ -162,17 +163,6 @@ def _create_merge_followup_jobs(
) -> List[FollowupJobRequest]:
return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)]

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_str = " ".join(
HINT_TO_REDSHIFT_ATTR.get(h, "")
for h in HINT_TO_REDSHIFT_ATTR.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
)

def create_load_job(
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
) -> LoadJob:
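
For illustration: the Redshift client now assigns HINT_TO_REDSHIFT_ATTR to self.active_hints in __init__ rather than reading the mapping inside its own _get_column_def_sql, and the shared _get_column_hints_sql (added to job_client_impl.py below) expands boolean column hints into attribute keywords. The mapping contents are not shown in this diff, so the values in this snippet are assumptions:

# Assumed hint values -- the real mapping is HINT_TO_REDSHIFT_ATTR in redshift.py.
active_hints = {"cluster": "DISTKEY", "sort": "SORTKEY"}
column = {"name": "user_id", "data_type": "bigint", "cluster": True}

hints_str = " ".join(
    active_hints.get(h, "") for h in active_hints if column.get(h, False) is True
)
print(hints_str)  # prints: DISTKEY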
36 changes: 27 additions & 9 deletions dlt/destinations/job_client_impl.py
@@ -7,28 +7,26 @@
from typing import (
Any,
ClassVar,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Iterable,
Iterator,
Generator,
)
import zlib
import re
from contextlib import contextmanager
from contextlib import suppress

from dlt.common import pendulum, logger
from dlt.common.destination.capabilities import DataTypeMapper
from dlt.common.json import json
from dlt.common.schema.typing import (
C_DLT_LOAD_ID,
COLUMN_HINTS,
TColumnType,
TColumnSchemaBase,
TTableFormat,
)
from dlt.common.schema.utils import (
get_inherited_table_hint,
@@ -40,11 +38,11 @@
from dlt.common.storages import FileStorage
from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
from dlt.common.schema import TColumnHint
from dlt.common.destination.reference import (
PreparedTableSchema,
StateInfo,
StorageSchemaInfo,
SupportsReadableDataset,
WithStateSync,
DestinationClientConfiguration,
DestinationClientDwhConfiguration,
@@ -55,9 +53,7 @@
JobClientBase,
HasFollowupJobs,
CredentialsConfiguration,
SupportsReadableRelation,
)
from dlt.destinations.dataset import ReadableDBAPIDataset

from dlt.destinations.exceptions import DatabaseUndefinedRelation
from dlt.destinations.job_impl import (
@@ -154,6 +150,8 @@ def __init__(
self.state_table_columns = ", ".join(
sql_client.escape_column_name(col) for col in state_table_["columns"]
)
self.active_hints: Dict[TColumnHint, str] = {}
self.type_mapper: DataTypeMapper = None
super().__init__(schema, config, sql_client.capabilities)
self.sql_client = sql_client
assert isinstance(config, DestinationClientDwhConfiguration)
@@ -569,6 +567,7 @@ def _get_table_update_sql(
# build CREATE
sql = self._make_create_table(qualified_name, table) + " (\n"
sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns])
sql += self._get_constraints_sql(table_name, new_columns, generate_alter)
sql += ")"
sql_result.append(sql)
else:
@@ -582,8 +581,16 @@
sql_result.extend(
[sql_base + col_statement for col_statement in add_column_statements]
)
constraints_sql = self._get_constraints_sql(table_name, new_columns, generate_alter)
if constraints_sql:
sql_result.append(constraints_sql)
return sql_result

def _get_constraints_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> str:
return ""

def _check_table_update_hints(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> None:
@@ -613,9 +620,20 @@ def _check_table_update_hints(
" existing tables."
)

@abstractmethod
def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
pass
hints_ = self._get_column_hints_sql(c)
column_name = self.sql_client.escape_column_name(c["name"])
nullability = self._gen_not_null(c.get("nullable", True))
column_type = self.type_mapper.to_destination_type(c, table)

return f"{column_name} {column_type} {hints_} {nullability}"

def _get_column_hints_sql(self, c: TColumnSchema) -> str:
return " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True # use ColumnPropInfos to get default value
)

@staticmethod
def _gen_not_null(nullable: bool) -> str:
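
The new _get_constraints_sql hook is a no-op in the base class: _get_table_update_sql appends its return value after the column definitions of CREATE TABLE, or emits it as a separate statement when altering an existing table. A hedged sketch of how a destination client might override it to add a composite primary key; the subclass and its behavior are illustrative, not part of this commit:

# Hypothetical override of the new hook -- not part of this commit.
from typing import Sequence

from dlt.common.schema import TColumnSchema
from dlt.destinations.job_client_impl import SqlJobClientBase


class ExamplePkClient(SqlJobClientBase):
    def _get_constraints_sql(
        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
    ) -> str:
        if generate_alter:
            # adding constraints to existing tables would need ALTER TABLE statements
            return ""
        pk_cols = [
            self.sql_client.escape_column_name(c["name"])
            for c in new_columns
            if c.get("primary_key")
        ]
        if not pk_cols:
            return ""
        # leading ",\n" so the clause joins cleanly after the column definitions
        return ",\nPRIMARY KEY (" + ", ".join(pk_cols) + ")"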
