From 34c872f384729615fbd5f6d879051c6a3618c211 Mon Sep 17 00:00:00 2001 From: Parya Jafari Date: Wed, 4 Dec 2024 11:31:14 -0500 Subject: [PATCH] Add schema getter to StagePathParts (#1915) --- src/snowflake/cli/_plugins/stage/manager.py | 16 +++- tests/stage/test_stage_path.py | 91 +++++++++++++++++++++ 2 files changed, 103 insertions(+), 4 deletions(-) diff --git a/src/snowflake/cli/_plugins/stage/manager.py b/src/snowflake/cli/_plugins/stage/manager.py index dbeab38d94..71b4e18cdd 100644 --- a/src/snowflake/cli/_plugins/stage/manager.py +++ b/src/snowflake/cli/_plugins/stage/manager.py @@ -41,7 +41,7 @@ from snowflake.cli.api.console import cli_console from snowflake.cli.api.constants import PYTHON_3_12 from snowflake.cli.api.identifiers import FQN -from snowflake.cli.api.project.util import to_string_literal +from snowflake.cli.api.project.util import extract_schema, to_string_literal from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.api.stage_path import StagePath @@ -86,6 +86,10 @@ def path(self) -> str: def full_path(self) -> str: raise NotImplementedError + @property + def schema(self) -> str | None: + raise NotImplementedError + def replace_stage_prefix(self, file_path: str) -> str: raise NotImplementedError @@ -139,11 +143,15 @@ def __init__(self, stage_path: str): @property def path(self) -> str: - return f"{self.stage_name.rstrip('/')}/{self.directory}" + return f"{self.stage_name.rstrip('/')}/{self.directory}".rstrip("/") @property def full_path(self) -> str: - return f"{self.stage.rstrip('/')}/{self.directory}" + return f"{self.stage.rstrip('/')}/{self.directory}".rstrip("/") + + @property + def schema(self) -> str | None: + return extract_schema(self.stage) def replace_stage_prefix(self, file_path: str) -> str: stage = Path(self.stage).parts[0] @@ -193,7 +201,7 @@ def path(self) -> str: @property def full_path(self) -> str: - return f"{self.stage}/{self.directory}" + return f"{self.stage}/{self.directory}".rstrip("/") def replace_stage_prefix(self, file_path: str) -> str: if Path(file_path).parts[0] == self.stage_name: diff --git a/tests/stage/test_stage_path.py b/tests/stage/test_stage_path.py index 9bc4fe42cc..21c5b2885c 100644 --- a/tests/stage/test_stage_path.py +++ b/tests/stage/test_stage_path.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest +from snowflake.cli._plugins.stage.manager import DefaultStagePathParts from snowflake.cli.api.stage_path import StagePath # (path, is_git_repo) @@ -168,3 +169,93 @@ def test_parent_path(path, is_git_repo): def test_root_path(stage_name, path): stage_path = StagePath.from_stage_str(path) assert stage_path.root_path() == StagePath.from_stage_str(f"@{stage_name}") + + +@pytest.mark.parametrize( + "input_path, path, full_path, schema, stage, stage_name", + [ + ( + "db.test_schema.test_stage", + "test_stage", + "db.test_schema.test_stage", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "db.test_schema.test_stage/subdir", + "test_stage/subdir", + "db.test_schema.test_stage/subdir", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "db.test_schema.test_stage/nested/dir", + "test_stage/nested/dir", + "db.test_schema.test_stage/nested/dir", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "test_schema.test_stage/nested/dir", + "test_stage/nested/dir", + "test_schema.test_stage/nested/dir", + "test_schema", + "test_schema.test_stage", + "test_stage", + ), + ( + "test_schema.test_stage/trailing/", + "test_stage/trailing", + "test_schema.test_stage/trailing", + "test_schema", + "test_schema.test_stage", + "test_stage", + ), + ( + "db.test_schema.test_stage/nested/trailing/", + "test_stage/nested/trailing", + "db.test_schema.test_stage/nested/trailing", + "test_schema", + "db.test_schema.test_stage", + "test_stage", + ), + ( + "test_stage/nested/trailing/", + "test_stage/nested/trailing", + "test_stage/nested/trailing", + None, + "test_stage", + "test_stage", + ), + ("test_stage/", "test_stage", "test_stage", None, "test_stage", "test_stage"), + ( + "test_stage/nested/dir", + "test_stage/nested/dir", + "test_stage/nested/dir", + None, + "test_stage", + "test_stage", + ), + ("test_stage", "test_stage", "test_stage", None, "test_stage", "test_stage"), + ( + "test_stage/dir/", + "test_stage/dir", + "test_stage/dir", + None, + "test_stage", + "test_stage", + ), + ], +) +def test_default_stage_path_parts( + input_path, path, full_path, schema, stage, stage_name +): + stage_path_parts = DefaultStagePathParts(input_path) + assert stage_path_parts.full_path == full_path + assert stage_path_parts.schema == schema + assert stage_path_parts.path == path + assert stage_path_parts.stage == stage + assert stage_path_parts.stage_name == stage_name