diff --git a/dlt/cli/_dlt.py b/dlt/cli/_dlt.py index 7c6526c0a2..a19b4ed4a9 100644 --- a/dlt/cli/_dlt.py +++ b/dlt/cli/_dlt.py @@ -57,9 +57,17 @@ def init_command_wrapper( use_generic_template: bool, repo_location: str, branch: str, + omit_core_sources: bool = False, ) -> int: try: - init_command(source_name, destination_type, use_generic_template, repo_location, branch) + init_command( + source_name, + destination_type, + use_generic_template, + repo_location, + branch, + omit_core_sources, + ) except Exception as ex: on_exception(ex, DLT_INIT_DOCS_URL) return -1 @@ -345,6 +353,16 @@ def main() -> int: ), ) + init_cmd.add_argument( + "--omit-core-sources", + default=False, + action="store_true", + help=( + "When present, will not create the new pipeline with a core source of the given name" + " but will take a source of this name from the default or provided location." + ), + ) + # deploy command requires additional dependencies try: # make sure the name is defined @@ -596,7 +614,12 @@ def main() -> int: return -1 else: return init_command_wrapper( - args.source, args.destination, args.generic, args.location, args.branch + args.source, + args.destination, + args.generic, + args.location, + args.branch, + args.omit_core_sources, ) elif args.command == "deploy": try: diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index 9762aed41c..7f4e223186 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -245,6 +245,7 @@ def init_command( use_generic_template: bool, repo_location: str, branch: str = None, + omit_core_sources: bool = False, ) -> None: # try to import the destination and get config spec destination_reference = Destination.from_reference(destination_type) @@ -259,11 +260,12 @@ def init_command( if ( local_sources_storage.has_folder(source_name) and source_name not in SKIP_CORE_SOURCES_FOLDERS - # NOTE: if explicit repo was passed, we do not use any core sources - # and not explicit_repo_location_provided + and not omit_core_sources ): source_type = "core" else: + if omit_core_sources: + fmt.echo("Omitting dlt core sources.") fmt.echo("Looking up verified sources at %s..." % fmt.bold(repo_location)) clone_storage = git.get_fresh_repo_files(repo_location, get_dlt_repos_dir(), branch=branch) # copy dlt source files from here @@ -502,7 +504,8 @@ def init_command( ) fmt.echo( "NOTE: Beginning with dlt 1.0.0, the source %s will no longer be copied from the" - " verified sources repo but imported from dlt.sources." % (fmt.bold(source_name)) + " verified sources repo but imported from dlt.sources. You can provide the" + " --omit-core-sources flag to revert to the old behavior." % (fmt.bold(source_name)) ) elif source_configuration.source_type == "verified": fmt.echo( diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index e1bcd1a8f2..a69c9885c8 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -207,6 +207,25 @@ def test_custom_destination_note(repo_dir: str, project_files: FileStorage): assert "to add a destination function that will consume your data" in _out +@pytest.mark.parametrize("omit", [True, False]) +# this will break if we have new core sources that are not in verified sources anymore +@pytest.mark.parametrize("source", CORE_SOURCES) +def test_omit_core_sources( + source: str, omit: bool, project_files: FileStorage, repo_dir: str +) -> None: + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + init_command.init_command(source, "destination", True, repo_dir, omit_core_sources=omit) + _out = buf.getvalue() + + # check messaging + assert ("Omitting dlt core sources" in _out) == omit + assert ("will no longer be copied from the" in _out) == (not omit) + + # if we omit core sources, there will be a folder with the name of the source from the verified sources repo + assert project_files.has_folder(source) == omit + assert (f"dlt.sources.{source}" in project_files.load(f"{source}_pipeline.py")) == (not omit) + + def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) -> None: sources_storage = FileStorage(os.path.join(repo_dir, SOURCES_MODULE_NAME)) new_content = '"""New docstrings"""'