Skip to content

Commit

Permalink
Add schema getter to StagePathParts (#1915)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pjafari authored Dec 4, 2024
1 parent fcd360d commit 34c872f
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from snowflake.cli.api.console import cli_console
from snowflake.cli.api.constants import PYTHON_3_12
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.project.util import to_string_literal
from snowflake.cli.api.project.util import extract_schema, to_string_literal
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.api.stage_path import StagePath
Expand Down Expand Up @@ -86,6 +86,10 @@ def path(self) -> str:
def full_path(self) -> str:
raise NotImplementedError

@property
def schema(self) -> str | None:
raise NotImplementedError

def replace_stage_prefix(self, file_path: str) -> str:
raise NotImplementedError

Expand Down Expand Up @@ -139,11 +143,15 @@ def __init__(self, stage_path: str):

@property
def path(self) -> str:
return f"{self.stage_name.rstrip('/')}/{self.directory}"
return f"{self.stage_name.rstrip('/')}/{self.directory}".rstrip("/")

@property
def full_path(self) -> str:
return f"{self.stage.rstrip('/')}/{self.directory}"
return f"{self.stage.rstrip('/')}/{self.directory}".rstrip("/")

@property
def schema(self) -> str | None:
return extract_schema(self.stage)

def replace_stage_prefix(self, file_path: str) -> str:
stage = Path(self.stage).parts[0]
Expand Down Expand Up @@ -193,7 +201,7 @@ def path(self) -> str:

@property
def full_path(self) -> str:
return f"{self.stage}/{self.directory}"
return f"{self.stage}/{self.directory}".rstrip("/")

def replace_stage_prefix(self, file_path: str) -> str:
if Path(file_path).parts[0] == self.stage_name:
Expand Down
91 changes: 91 additions & 0 deletions tests/stage/test_stage_path.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import pytest
from snowflake.cli._plugins.stage.manager import DefaultStagePathParts
from snowflake.cli.api.stage_path import StagePath

# (path, is_git_repo)
Expand Down Expand Up @@ -168,3 +169,93 @@ def test_parent_path(path, is_git_repo):
def test_root_path(stage_name, path):
stage_path = StagePath.from_stage_str(path)
assert stage_path.root_path() == StagePath.from_stage_str(f"@{stage_name}")


@pytest.mark.parametrize(
"input_path, path, full_path, schema, stage, stage_name",
[
(
"db.test_schema.test_stage",
"test_stage",
"db.test_schema.test_stage",
"test_schema",
"db.test_schema.test_stage",
"test_stage",
),
(
"db.test_schema.test_stage/subdir",
"test_stage/subdir",
"db.test_schema.test_stage/subdir",
"test_schema",
"db.test_schema.test_stage",
"test_stage",
),
(
"db.test_schema.test_stage/nested/dir",
"test_stage/nested/dir",
"db.test_schema.test_stage/nested/dir",
"test_schema",
"db.test_schema.test_stage",
"test_stage",
),
(
"test_schema.test_stage/nested/dir",
"test_stage/nested/dir",
"test_schema.test_stage/nested/dir",
"test_schema",
"test_schema.test_stage",
"test_stage",
),
(
"test_schema.test_stage/trailing/",
"test_stage/trailing",
"test_schema.test_stage/trailing",
"test_schema",
"test_schema.test_stage",
"test_stage",
),
(
"db.test_schema.test_stage/nested/trailing/",
"test_stage/nested/trailing",
"db.test_schema.test_stage/nested/trailing",
"test_schema",
"db.test_schema.test_stage",
"test_stage",
),
(
"test_stage/nested/trailing/",
"test_stage/nested/trailing",
"test_stage/nested/trailing",
None,
"test_stage",
"test_stage",
),
("test_stage/", "test_stage", "test_stage", None, "test_stage", "test_stage"),
(
"test_stage/nested/dir",
"test_stage/nested/dir",
"test_stage/nested/dir",
None,
"test_stage",
"test_stage",
),
("test_stage", "test_stage", "test_stage", None, "test_stage", "test_stage"),
(
"test_stage/dir/",
"test_stage/dir",
"test_stage/dir",
None,
"test_stage",
"test_stage",
),
],
)
def test_default_stage_path_parts(
input_path, path, full_path, schema, stage, stage_name
):
stage_path_parts = DefaultStagePathParts(input_path)
assert stage_path_parts.full_path == full_path
assert stage_path_parts.schema == schema
assert stage_path_parts.path == path
assert stage_path_parts.stage == stage
assert stage_path_parts.stage_name == stage_name

0 comments on commit 34c872f

Please sign in to comment.