Drop cmd use package state, refactoring
steinitzu committed Apr 17, 2024
1 parent ebe4def commit 7ebced2
Showing 9 changed files with 178 additions and 120 deletions.
11 changes: 11 additions & 0 deletions dlt/common/schema/schema.py
@@ -438,6 +438,17 @@ def update_schema(self, schema: "Schema") -> None:
         self._settings = deepcopy(schema.settings)
         self._compile_settings()
 
+    def drop_tables(
+        self, table_names: Sequence[str], seen_data_only: bool = False
+    ) -> List[TTableSchema]:
+        """Drops tables from the schema and returns the dropped tables"""
+        result = []
+        for table_name in table_names:
+            table = self.tables.get(table_name)
+            if table and (not seen_data_only or utils.has_table_seen_data(table)):
+                result.append(self._schema_tables.pop(table_name))
+        return result
+
     def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: StrAny) -> StrAny:
         rv_row: DictStrAny = {}
         column_prop: TColumnProp = utils.hint_to_column_prop(hint_type)
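A quick illustration (not part of the diff) of how the new Schema.drop_tables helper behaves; it assumes utils.new_table and Schema.update_table work as elsewhere in dlt and uses a made-up table name:

    from dlt.common.schema import Schema, utils

    # build a throwaway schema with one table, then drop it again
    schema = Schema("demo")
    schema.update_table(utils.new_table("my_table"))
    dropped = schema.drop_tables(["my_table"], seen_data_only=False)
    assert [t["name"] for t in dropped] == ["my_table"]
    # with seen_data_only=True, tables that never received data are left in place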
33 changes: 3 additions & 30 deletions dlt/extract/extract.py
@@ -16,8 +16,6 @@
     SupportsPipeline,
     WithStepInfo,
     reset_resource_state,
-    TRefreshMode,
-    pipeline_state,
 )
 from dlt.common.runtime import signals
 from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
@@ -46,7 +44,6 @@
 from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor
 from dlt.extract.utils import get_data_item_format
 from dlt.pipeline.drop import drop_resources
-from dlt.common.pipeline import TRefreshMode
 
 
 def data_to_sources(
@@ -177,14 +174,12 @@ def __init__(
         normalize_storage_config: NormalizeStorageConfiguration,
         collector: Collector = NULL_COLLECTOR,
         original_data: Any = None,
-        refresh: Optional[TRefreshMode] = None,
     ) -> None:
         """optionally saves originally extracted `original_data` to generate extract info"""
         self.collector = collector
         self.schema_storage = schema_storage
         self.extract_storage = ExtractStorage(normalize_storage_config)
         self.original_data: Any = original_data
-        self.refresh = refresh
         super().__init__()
 
     def _compute_metrics(self, load_id: str, source: DltSource) -> ExtractMetrics:
@@ -372,6 +367,7 @@ def extract(
         source: DltSource,
         max_parallel_items: int,
         workers: int,
+        load_package_state_update: Optional[Dict[str, Any]] = None,
     ) -> str:
         # generate load package to be able to commit all the sources together later
         load_id = self.extract_storage.create_load_package(source.discover_schema())
@@ -391,31 +387,8 @@
                     source_state_key=source.name,
                 )
             ):
-                if self.refresh is not None:
-                    _resources_to_drop = (
-                        list(source.resources.extracted) if self.refresh != "drop_dataset" else []
-                    )
-                    _state, _ = pipeline_state(Container())
-                    new_schema, new_state, drop_info = drop_resources(
-                        source.schema,
-                        _state,
-                        resources=_resources_to_drop,
-                        drop_all=self.refresh == "drop_dataset",
-                        state_paths="*" if self.refresh == "drop_dataset" else [],
-                    )
-                    _state.update(new_state)
-                    if drop_info["tables"]:
-                        drop_tables = [
-                            table
-                            for table in source.schema.tables.values()
-                            if table["name"] in drop_info["tables"]
-                        ]
-                        if self.refresh == "drop_data":
-                            load_package.state["truncated_tables"] = drop_tables
-                        else:
-                            source.schema.tables.clear()
-                            source.schema.tables.update(new_schema.tables)
-                            load_package.state["dropped_tables"] = drop_tables
+                if load_package_state_update:
+                    load_package.state.update(load_package_state_update)  # type: ignore[typeddict-item]
 
                 # reset resource states, the `extracted` list contains all the explicit resources and all their parents
                 for resource in source.resources.extracted.values():
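For illustration only: with the refresh handling moved out of Extract, a caller is expected to pre-compute any drop and pass the affected tables through load_package_state_update. A rough sketch, where source, pipeline_state and extract_step (an Extract instance) are assumed to already exist:

    from dlt.pipeline.drop import drop_resources

    # compute the drop outside of Extract, then hand the result to the load package state
    drop_result = drop_resources(source.schema, pipeline_state, resources=["my_resource"])
    load_id = extract_step.extract(
        source,
        max_parallel_items=20,
        workers=5,
        load_package_state_update={"dropped_tables": drop_result.dropped_tables},
    )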
1 change: 0 additions & 1 deletion dlt/load/load.py
@@ -359,7 +359,6 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
 
         dropped_tables = current_load_package()["state"].get("dropped_tables", [])
         truncated_tables = current_load_package()["state"].get("truncated_tables", [])
-
         # initialize analytical storage ie. create dataset required by passed schema
         with self.get_destination_client(schema) as job_client:
             if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:
1 change: 0 additions & 1 deletion dlt/load/utils.py
@@ -15,7 +15,6 @@
     JobClientBase,
     WithStagingDataset,
 )
-from dlt.common.pipeline import TRefreshMode
 
 
 def get_completed_table_chain(
20 changes: 13 additions & 7 deletions dlt/pipeline/drop.py
@@ -1,6 +1,7 @@
 from typing import Union, Iterable, Optional, List, Dict, Any, Tuple, TypedDict
 from copy import deepcopy
 from itertools import chain
+from dataclasses import dataclass
 
 from dlt.common.schema import Schema
 from dlt.common.pipeline import (
@@ -10,7 +11,7 @@
     reset_resource_state,
     _delete_source_state_keys,
 )
-from dlt.common.schema.typing import TSimpleRegex
+from dlt.common.schema.typing import TSimpleRegex, TTableSchema
 from dlt.common.schema.utils import (
     group_tables_by_resource,
     compile_simple_regexes,
@@ -32,6 +33,14 @@ class _DropInfo(TypedDict):
     warnings: List[str]
 
 
+@dataclass
+class _DropResult:
+    schema: Schema
+    state: TPipelineState
+    info: _DropInfo
+    dropped_tables: List[TTableSchema]
+
+
 def _create_modified_state(
     state: TPipelineState,
     resource_pattern: Optional[REPattern],
@@ -68,7 +77,7 @@ def drop_resources(
     state_paths: jsonpath.TAnyJsonPath = (),
     drop_all: bool = False,
     state_only: bool = False,
-) -> Tuple[Schema, TPipelineState, _DropInfo]:
+) -> _DropResult:
     """Generate a new schema and pipeline state with the requested resources removed.
     Args:
@@ -140,8 +149,5 @@
             f" {list(group_tables_by_resource(data_tables).keys())}"
         )
 
-    for tbl in tables_to_drop:
-        del schema.tables[tbl["name"]]
-    schema._bump_version()  # TODO: needed?
-
-    return schema, new_state, info
+    dropped_tables = schema.drop_tables([t["name"] for t in tables_to_drop], seen_data_only=True)
+    return _DropResult(schema, new_state, info, dropped_tables)
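A small consumption sketch (not in the commit) for the new _DropResult return value; schema and pipeline_state are placeholders for an existing schema/state pair:

    from dlt.pipeline.drop import drop_resources

    drop_result = drop_resources(schema, pipeline_state, resources=["my_resource"])
    new_schema = drop_result.schema       # schema with the resource's tables removed
    new_state = drop_result.state         # pipeline state with the resource state cleared
    table_names = drop_result.info["tables"]                   # names selected for dropping
    dropped = [t["name"] for t in drop_result.dropped_tables]  # tables actually removed from the schema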
74 changes: 22 additions & 52 deletions dlt/pipeline/helpers.py
@@ -1,4 +1,5 @@
 import contextlib
+from copy import deepcopy
 from typing import (
     Callable,
     Sequence,
@@ -102,11 +103,10 @@ def __init__(
 
         if not pipeline.default_schema_name:
             raise PipelineNeverRan(pipeline.pipeline_name, pipeline.pipelines_dir)
-        self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name]
+        self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name].clone()
 
-        self.drop_tables = not state_only
-
-        self._drop_schema, self._new_state, self.info = drop_resources(
+        drop_result = drop_resources(
+            # self._drop_schema, self._new_state, self.info = drop_resources(
             self.schema,
             pipeline.state,
             resources,
@@ -115,6 +115,12 @@
             state_only,
         )
 
+        self._new_state = drop_result.state
+        self.info = drop_result.info
+        self._new_schema = drop_result.schema
+        self._dropped_tables = drop_result.dropped_tables
+        self.drop_tables = not state_only and bool(self._dropped_tables)
+
         self.drop_state = bool(drop_all or resources or state_paths)
 
     @property
@@ -125,46 +131,6 @@ def is_empty(self) -> bool:
             and len(self.info["resource_states"]) == 0
         )
 
-    def _drop_destination_tables(self, allow_schema_tables: bool = False) -> None:
-        table_names = self.info["tables"]
-        if not allow_schema_tables:
-            for table_name in table_names:
-                assert table_name not in self.schema._schema_tables, (
-                    f"You are dropping table {table_name} in {self.schema.name} but it is still"
-                    " present in the schema"
-                )
-        with self.pipeline._sql_job_client(self.schema) as client:
-            client.drop_tables(*table_names, replace_schema=True)
-            # also delete staging but ignore if staging does not exist
-            if isinstance(client, WithStagingDataset):
-                with contextlib.suppress(DatabaseUndefinedRelation):
-                    with client.with_staging_dataset():
-                        client.drop_tables(*table_names, replace_schema=True)
-
-    def _delete_schema_tables(self) -> None:
-        for tbl in self.info["tables"]:
-            del self.schema.tables[tbl]
-        # bump schema, we'll save later
-        self.schema._bump_version()
-
-    def _extract_state(self) -> None:
-        state: Dict[str, Any]
-        with self.pipeline.managed_state(extract_state=True) as state:  # type: ignore[assignment]
-            state.clear()
-            state.update(self._new_state)
-        try:
-            # Also update the state in current context if one is active
-            # so that we can run the pipeline directly after drop in the same process
-            ctx = Container()[StateInjectableContext]
-            state = ctx.state  # type: ignore[assignment]
-            state.clear()
-            state.update(self._new_state)
-        except ContextDefaultCannotBeCreated:
-            pass
-
-    def _save_local_schema(self) -> None:
-        self.pipeline.schemas.save_schema(self.schema)
-
     def __call__(self) -> None:
         if (
             self.pipeline.has_pending_data
@@ -177,14 +143,16 @@ def __call__(self) -> None:
         if not self.drop_state and not self.drop_tables:
             return  # Nothing to drop
 
-        if self.drop_tables:
-            self._delete_schema_tables()
-            self._drop_destination_tables()
-        if self.drop_tables:
-            self._save_local_schema()
-        if self.drop_state:
-            self._extract_state()
+        # Send updated state to destination
+        self._new_schema._bump_version()
+        new_state = deepcopy(self._new_state)
+        force_state_extract(new_state)
+
+        self.pipeline._save_and_extract_state_and_schema(
+            new_state,
+            schema=self._new_schema,
+            load_package_state_update={"dropped_tables": self._dropped_tables},
+        )
 
         self.pipeline.normalize()
         try:
             self.pipeline.load(raise_on_failed_jobs=True)
@@ -193,6 +161,8 @@ def __call__(self) -> None:
             self.pipeline.drop_pending_packages()
             with self.pipeline.managed_state() as state:
                 force_state_extract(state)
+            # Restore original schema file so all tables are known on next run
+            self.pipeline.schemas.save_schema(self.schema)
             raise
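End to end, the reworked DropCommand now extracts a regular load package carrying the pruned schema, state and dropped_tables list instead of dropping destination tables directly. A hedged usage sketch (the keyword form of the constructor is assumed from the parameters visible above):

    import dlt
    from dlt.pipeline.helpers import DropCommand

    @dlt.resource
    def my_resource():
        yield [{"id": 1}, {"id": 2}]

    pipeline = dlt.pipeline(pipeline_name="demo", destination="duckdb")
    pipeline.run(my_resource())
    # drop the resource's tables and state; this normalizes and loads the drop package
    DropCommand(pipeline, resources=["my_resource"])()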

