diff --git a/tests/stage/test_stage_path.py b/tests/stage/test_stage_path.py index 9bc4fe42cc..21c5b2885c 100644 --- a/tests/stage/test_stage_path.py +++ b/tests/stage/test_stage_path.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest +from snowflake.cli._plugins.stage.manager import DefaultStagePathParts from snowflake.cli.api.stage_path import StagePath # (path, is_git_repo) @@ -168,3 +169,93 @@ def test_parent_path(path, is_git_repo): def test_root_path(stage_name, path): stage_path = StagePath.from_stage_str(path) assert stage_path.root_path() == StagePath.from_stage_str(f"@{stage_name}") + + +@pytest.mark.parametrize( + "input_path, path, full_path, schema, stage, stage_name", + [ + ( + "db.test_schema.test_stage", + "test_stage", + "db.test_schema.test_stage", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "db.test_schema.test_stage/subdir", + "test_stage/subdir", + "db.test_schema.test_stage/subdir", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "db.test_schema.test_stage/nested/dir", + "test_stage/nested/dir", + "db.test_schema.test_stage/nested/dir", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "test_schema.test_stage/nested/dir", + "test_stage/nested/dir", + "test_schema.test_stage/nested/dir", + "test_schema", + "test_schema.test_stage", + "test_stage", + ), + ( + "test_schema.test_stage/trailing/", + "test_stage/trailing", + "test_schema.test_stage/trailing", + "test_schema", + "test_schema.test_stage", + "test_stage", + ), + ( + "db.test_schema.test_stage/nested/trailing/", + "test_stage/nested/trailing", + "db.test_schema.test_stage/nested/trailing", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "test_stage/nested/trailing/", + "test_stage/nested/trailing", + "test_stage/nested/trailing", + None, + "test_stage", + "test_stage", + ), + ("test_stage/", "test_stage", "test_stage", None, "test_stage", "test_stage"), + ( + "test_stage/nested/dir", + "test_stage/nested/dir", + "test_stage/nested/dir", + None, + "test_stage", + "test_stage", + ), + ("test_stage", "test_stage", "test_stage", None, "test_stage", "test_stage"), + ( + "test_stage/dir/", + "test_stage/dir", + "test_stage/dir", + None, + "test_stage", + "test_stage", + ), + ], +) +def test_default_stage_path_parts( + input_path, path, full_path, schema, stage, stage_name +): + stage_path_parts = DefaultStagePathParts(input_path) + assert stage_path_parts.full_path == full_path + assert stage_path_parts.schema == schema + assert stage_path_parts.path == path + assert stage_path_parts.stage == stage + assert stage_path_parts.stage_name == stage_name