Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Sep 3, 2024
1 parent 7fa2d54 commit 0d8a85a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion dlt/cli/_dlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def main() -> int:
)
init_cmd.add_argument(
"--location",
default=None,
default=DEFAULT_VERIFIED_SOURCES_REPO,
help="Advanced. Uses a specific url or local path to verified sources repository.",
)
init_cmd.add_argument(
Expand Down
22 changes: 12 additions & 10 deletions dlt/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def _welcome_message(
% (fmt.bold(destination_type), fmt.bold(make_dlt_settings_path(SECRETS_TOML)))
)

if destination_type == "destination":
fmt.echo(
"* You have selected the custom destination as your pipelines destination. Please refer"
" to our docs at https://dlthub.com/docs/dlt-ecosystem/destinations/destination on how"
" to add a destination function that will consume your data."
)

if dependency_system:
fmt.echo("* Add the required dependencies to %s:" % fmt.bold(dependency_system))
compiled_requirements = source_configuration.requirements.compiled()
Expand Down Expand Up @@ -236,30 +243,24 @@ def init_command(
source_name: str,
destination_type: str,
use_generic_template: bool,
repo_location: str = None,
repo_location: str,
branch: str = None,
) -> None:
# try to import the destination and get config spec
destination_reference = Destination.from_reference(destination_type)
destination_spec = destination_reference.spec

# set default repo
explicit_repo_location_provided = repo_location is not None
repo_location = repo_location or DEFAULT_VERIFIED_SOURCES_REPO

# lookup core sources
local_path = Path(os.path.dirname(os.path.realpath(__file__))).parent / SOURCES_MODULE_NAME
local_sources_storage = FileStorage(str(local_path))

# discover type of source
source_type: files_ops.SOURCE_TYPE = "generic"
if (
(
local_sources_storage.has_folder(source_name)
and source_name not in SKIP_CORE_SOURCES_FOLDERS
)
local_sources_storage.has_folder(source_name)
and source_name not in SKIP_CORE_SOURCES_FOLDERS
# NOTE: if explicit repo was passed, we do not use any core sources
and not explicit_repo_location_provided
# and not explicit_repo_location_provided
):
source_type = "core"
else:
Expand Down Expand Up @@ -335,6 +336,7 @@ def init_command(

else:
pipeline_dest_script = source_name + "_pipeline.py"

if source_type == "core":
source_configuration = SourceConfiguration(
source_type,
Expand Down
16 changes: 11 additions & 5 deletions tests/cli/test_init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_init_all_sources_isolated(cloned_init_repo: FileStorage) -> None:
for candidate in source_candidates:
clean_test_storage()
repo_dir = get_repo_dir(cloned_init_repo)
files = get_project_files()
files = get_project_files(clear_all_sources=False)
with set_working_dir(files.storage_path):
init_command.init_command(candidate, "bigquery", False, repo_dir)
assert_source_files(files, candidate, "bigquery")
Expand All @@ -194,13 +194,19 @@ def test_init_all_sources_isolated(cloned_init_repo: FileStorage) -> None:
def test_init_all_destinations(
destination_name: str, project_files: FileStorage, repo_dir: str
) -> None:
if destination_name == "destination":
pytest.skip("Init for generic destination not implemented yet")
source_name = f"generic_{destination_name}"
source_name = "generic"
init_command.init_command(source_name, destination_name, True, repo_dir)
assert_init_files(project_files, source_name + "_pipeline", destination_name)


def test_custom_destination_note(repo_dir: str, project_files: FileStorage):
source_name = "generic"
with io.StringIO() as buf, contextlib.redirect_stdout(buf):
init_command.init_command(source_name, "destination", True, repo_dir)
_out = buf.getvalue()
assert "to add a destination function that will consume your data" in _out


def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) -> None:
sources_storage = FileStorage(os.path.join(repo_dir, SOURCES_MODULE_NAME))
new_content = '"""New docstrings"""'
Expand Down Expand Up @@ -535,7 +541,7 @@ def assert_source_files(
visitor, secrets = assert_common_files(
project_files, source_name + "_pipeline.py", destination_name
)
assert project_files.has_folder(source_name) # == (source_name not in CORE_SOURCES)
assert project_files.has_folder(source_name) == (source_name not in CORE_SOURCES)
source_secrets = secrets.get_value(source_name, type, None, source_name)
if has_source_section:
assert source_secrets is not None
Expand Down
5 changes: 4 additions & 1 deletion tests/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ def get_repo_dir(cloned_init_repo: FileStorage) -> str:
return repo_dir


def get_project_files() -> FileStorage:
def get_project_files(clear_all_sources: bool = True) -> FileStorage:
# we only remove sources registered outside of dlt core
for name, source in _SOURCES.copy().items():
if not source.module.__name__.startswith("dlt.sources"):
_SOURCES.pop(name)

if clear_all_sources:
_SOURCES.clear()

# project dir
return FileStorage(PROJECT_DIR, makedirs=True)

0 comments on commit 0d8a85a

Please sign in to comment.