diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index cfc1ac17d3..0281ce31ef 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -1,4 +1,4 @@ -from typing import Sequence, cast, overload +from typing import Sequence, cast, overload, Optional from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnSchema, TWriteDisposition, TSchemaContract @@ -14,6 +14,7 @@ from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR from dlt.pipeline.warnings import credentials_argument_deprecated, full_refresh_argument_deprecated +from dlt.pipeline.typing import TRefreshMode @overload @@ -28,6 +29,7 @@ def pipeline( export_schema_path: str = None, full_refresh: bool = False, dev_mode: bool = False, + refresh: Optional[TRefreshMode] = None, credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, ) -> Pipeline: @@ -97,6 +99,7 @@ def pipeline( export_schema_path: str = None, full_refresh: bool = False, dev_mode: bool = False, + refresh: Optional[TRefreshMode] = None, credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, **kwargs: Any, @@ -147,6 +150,7 @@ def pipeline( False, last_config(**kwargs), kwargs["runtime"], + refresh=refresh, ) # set it as current pipeline p.activate() diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 9b602ed9f4..361f627dfc 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -5,6 +5,7 @@ from dlt.common.typing import AnyFun, TSecretValue from dlt.common.utils import digest256 from dlt.common.data_writers import TLoaderFileFormat +from dlt.pipeline.typing import TRefreshMode @configspec @@ -30,6 +31,10 @@ class PipelineConfiguration(BaseConfiguration): """When set to True, each instance of the pipeline with the `pipeline_name` starts from scratch when run and loads the data to a separate dataset.""" progress: Optional[str] = None runtime: RunConfiguration + refresh: Optional[TRefreshMode] = None + """Refresh mode for the pipeline, use with care. `full` completely wipes pipeline state and data before each run. + `replace` wipes only state and data from the resources selected to run. Default is `None` which means no refresh. + """ def on_resolved(self) -> None: if not self.pipeline_name: diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py index 7bba5f84e7..e5c2e25cf8 100644 --- a/dlt/pipeline/helpers.py +++ b/dlt/pipeline/helpers.py @@ -1,5 +1,17 @@ import contextlib -from typing import Callable, Sequence, Iterable, Optional, Any, List, Dict, Tuple, Union, TypedDict +from typing import ( + Callable, + Sequence, + Iterable, + Optional, + Any, + List, + Dict, + Tuple, + Union, + TypedDict, + TYPE_CHECKING, +) from itertools import chain from dlt.common.jsonpath import resolve_paths, TAnyJsonPath, compile_paths @@ -17,6 +29,8 @@ _sources_state, _delete_source_state_keys, _get_matching_resources, + StateInjectableContext, + Container, ) from dlt.common.destination.reference import WithStagingDataset @@ -27,7 +41,10 @@ PipelineHasPendingDataException, ) from dlt.pipeline.typing import TPipelineStep -from dlt.pipeline import Pipeline +from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated + +if TYPE_CHECKING: + from dlt.pipeline import Pipeline def retry_load( @@ -77,13 +94,22 @@ class _DropInfo(TypedDict): class DropCommand: def __init__( self, - pipeline: Pipeline, + pipeline: "Pipeline", resources: Union[Iterable[Union[str, TSimpleRegex]], Union[str, TSimpleRegex]] = (), schema_name: Optional[str] = None, state_paths: TAnyJsonPath = (), drop_all: bool = False, state_only: bool = False, ) -> None: + """ + Args: + pipeline: Pipeline to drop tables and state from + resources: List of resources to drop. If empty, no resources are dropped unless `drop_all` is True + schema_name: Name of the schema to drop tables from. If not specified, the default schema is used + state_paths: JSON path(s) relative to the source state to drop + drop_all: Drop all resources and tables in the schema (supersedes `resources` list) + state_only: Drop only state, not tables + """ self.pipeline = pipeline if isinstance(resources, str): resources = [resources] @@ -187,7 +213,10 @@ def _create_modified_state(self) -> Dict[str, Any]: self.info["resource_states"].append(key) reset_resource_state(key, source_state) # drop additional state paths - resolved_paths = resolve_paths(self.state_paths_to_drop, source_state) + # Don't drop 'resources' key if jsonpath is wildcard + resolved_paths = [ + p for p in resolve_paths(self.state_paths_to_drop, source_state) if p != "resources" + ] if self.state_paths_to_drop and not resolved_paths: self.info["warnings"].append( f"State paths {self.state_paths_to_drop} did not select any paths in source" @@ -202,6 +231,15 @@ def _drop_state_keys(self) -> None: with self.pipeline.managed_state(extract_state=True) as state: # type: ignore[assignment] state.clear() state.update(self._new_state) + try: + # Also update the state in current context if one is active + # so that we can run the pipeline directly after drop in the same process + ctx = Container()[StateInjectableContext] + state = ctx.state # type: ignore[assignment] + state.clear() + state.update(self._new_state) + except ContextDefaultCannotBeCreated: + pass def __call__(self) -> None: if ( @@ -236,7 +274,7 @@ def __call__(self) -> None: def drop( - pipeline: Pipeline, + pipeline: "Pipeline", resources: Union[Iterable[str], str] = (), schema_name: str = None, state_paths: TAnyJsonPath = (), diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 46d1c590f7..87cdb33727 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -124,7 +124,7 @@ end_trace_step, end_trace, ) -from dlt.pipeline.typing import TPipelineStep +from dlt.pipeline.typing import TPipelineStep, TRefreshMode from dlt.pipeline.state_sync import ( STATE_ENGINE_VERSION, bump_version_if_modified, @@ -135,6 +135,7 @@ json_decode_state, ) from dlt.pipeline.warnings import credentials_argument_deprecated +from dlt.pipeline.helpers import drop as drop_command def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]: @@ -291,6 +292,7 @@ class Pipeline(SupportsPipeline): collector: _Collector config: PipelineConfiguration runtime_config: RunConfiguration + refresh: Optional[TRefreshMode] = None def __init__( self, @@ -308,6 +310,7 @@ def __init__( must_attach_to_local_pipeline: bool, config: PipelineConfiguration, runtime: RunConfiguration, + refresh: Optional[TRefreshMode] = None, ) -> None: """Initializes the Pipeline class which implements `dlt` pipeline. Please use `pipeline` function in `dlt` module to create a new Pipeline instance.""" self.pipeline_salt = pipeline_salt @@ -317,6 +320,7 @@ def __init__( self.collector = progress or _NULL_COLLECTOR self.destination = None self.staging = None + self.refresh = refresh self._container = Container() self._pipeline_instance_id = self._create_pipeline_instance_id() @@ -386,6 +390,9 @@ def extract( schema_contract: TSchemaContract = None, ) -> ExtractInfo: """Extracts the `data` and prepare it for the normalization. Does not require destination or credentials to be configured. See `run` method for the arguments' description.""" + if self.refresh == "full": + drop_command(self, drop_all=True, state_paths="*") + # create extract storage to which all the sources will be extracted extract_step = Extract( self._schema_storage, @@ -1101,7 +1108,7 @@ def _get_destination_client_initial_config( if issubclass(client_spec, DestinationClientDwhConfiguration): if not self.dataset_name and self.dev_mode: logger.warning( - "Full refresh may not work if dataset name is not set. Please set the" + "Dev mode may not work if dataset name is not set. Please set the" " dataset_name argument in dlt.pipeline or run method" ) # set default schema name to load all incoming data to a single dataset, no matter what is the current schema name diff --git a/dlt/pipeline/typing.py b/dlt/pipeline/typing.py index f0192a504d..ec0eca4685 100644 --- a/dlt/pipeline/typing.py +++ b/dlt/pipeline/typing.py @@ -1,3 +1,5 @@ from typing import Literal TPipelineStep = Literal["sync", "extract", "normalize", "load"] + +TRefreshMode = Literal["full", "replace"]