Skip to content

Commit

Permalink
Added support for user stages in stage execute command (#1284)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astus authored Jul 9, 2024
1 parent 23ba54e commit 059d24b
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 61 deletions.
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
## Fixes and improvements
* The `snow app run` command now allows upgrading to unversioned mode from a versioned or release mode application installation
* The `snow app teardown` command now allows dropping a package with versions when the `--force` flag is provided
* Added support for user stages in stage execute command


# v2.6.0
## Backward incompatibility
Expand Down
47 changes: 28 additions & 19 deletions src/snowflake/cli/plugins/git/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@
from snowflake.connector.cursor import SnowflakeCursor


class GitStagePathParts(StagePathParts):
def __init__(self, stage_path: str):
self.stage = GitManager.get_stage_from_path(stage_path)
stage_path_parts = Path(stage_path).parts
git_repo_name = stage_path_parts[0].split(".")[-1]
if git_repo_name.startswith("@"):
git_repo_name = git_repo_name[1:]
self.stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""])
self.directory = "/".join(stage_path_parts[3:])

@property
def path(self) -> str:
return (
f"{self.stage_name}{self.directory}".lower()
if self.stage_name.endswith("/")
else f"{self.stage_name}/{self.directory}".lower()
)

def add_stage_prefix(self, file_path: str) -> str:
stage = Path(self.stage).parts[0]
file_path_without_prefix = Path(file_path).parts[1:]
return f"{stage}/{'/'.join(file_path_without_prefix)}"


class GitManager(StageManager):
def show_branches(self, repo_name: str, like: str) -> SnowflakeCursor:
return self._execute_query(f"show git branches like '{like}' in {repo_name}")
Expand Down Expand Up @@ -51,22 +75,7 @@ def get_stage_from_path(path: str):
"""
return f"{'/'.join(Path(path).parts[0:3])}/"

def _split_stage_path(self, stage_path: str) -> StagePathParts:
"""
Splits Git repository path `@repo/branch/main/dir`
stage -> @repo/branch/main/
stage_name -> repo/branch/main/
directory -> dir
For Git repository with fully qualified name `@db.schema.repo/branch/main/dir`
stage -> @db.schema.repo/branch/main/
stage_name -> repo/branch/main/
directory -> dir
"""
stage = self.get_stage_from_path(stage_path)
stage_path_parts = Path(stage_path).parts
git_repo_name = stage_path_parts[0].split(".")[-1]
if git_repo_name.startswith("@"):
git_repo_name = git_repo_name[1:]
stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""])
directory = "/".join(stage_path_parts[3:])
return StagePathParts(stage, stage_name, directory)
@staticmethod
def _stage_path_part_factory(stage_path: str) -> StagePathParts:
stage_path = StageManager.get_standard_stage_prefix(stage_path)
return GitStagePathParts(stage_path)
116 changes: 74 additions & 42 deletions src/snowflake/cli/plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,47 @@

UNQUOTED_FILE_URI_REGEX = r"[\w/*?\-.=&{}$#[\]\"\\!@%^+:]+"
EXECUTE_SUPPORTED_FILES_FORMATS = {".sql"}
USER_STAGE_PREFIX = "@~"


@dataclass
class StagePathParts:
# For path like @db.schema.stage/dir the values will be:
# stage = @db.schema.stage
directory: str
stage: str
# stage_name = stage/dir
stage_name: str
# directory = dir
directory: str

@staticmethod
def get_directory(stage_path: str) -> str:
return "/".join(Path(stage_path).parts[1:])

@property
def path(self) -> str:
raise NotImplementedError

def add_stage_prefix(self, file_path: str) -> str:
raise NotImplementedError


@dataclass
class DefaultStagePathParts(StagePathParts):
"""
For path like @db.schema.stage/dir the values will be:
directory = dir
stage = @db.schema.stage
stage_name = stage
For `@stage/dir` to
stage -> @stage
stage_name -> stage
directory -> dir
"""

def __init__(self, stage_path: str):
self.directory = self.get_directory(stage_path)
self.stage = StageManager.get_stage_from_path(stage_path)
stage_name = self.stage.split(".")[-1]
if stage_name.startswith("@"):
stage_name = stage_name[1:]
self.stage_name = stage_name

@property
def path(self) -> str:
Expand All @@ -59,6 +89,33 @@ def path(self) -> str:
else f"{self.stage_name}/{self.directory}".lower()
)

def add_stage_prefix(self, file_path: str) -> str:
stage = Path(self.stage).parts[0]
file_path_without_prefix = Path(file_path).parts[1:]
return f"{stage}/{'/'.join(file_path_without_prefix)}"


@dataclass
class UserStagePathParts(StagePathParts):
"""
For path like @db.schema.stage/dir the values will be:
directory = dir
stage = @~
stage_name = @~
"""

def __init__(self, stage_path: str):
self.directory = self.get_directory(stage_path)
self.stage = "@~"
self.stage_name = "@~"

@property
def path(self) -> str:
return f"{self.directory}".lower()

def add_stage_prefix(self, file_path: str) -> str:
return f"{self.stage}/{file_path}"


class StageManager(SqlExecutionMixin):
@staticmethod
Expand Down Expand Up @@ -96,12 +153,6 @@ def quote_stage_name(name: str) -> str:

return standard_name

@staticmethod
def remove_stage_prefix(stage_path: str) -> str:
if stage_path.startswith("@"):
return stage_path[1:]
return stage_path

def _to_uri(self, local_path: str):
uri = f"file://{local_path}"
if re.fullmatch(UNQUOTED_FILE_URI_REGEX, uri):
Expand Down Expand Up @@ -217,8 +268,7 @@ def execute(
on_error: OnErrorType,
variables: Optional[List[str]] = None,
):
stage_path_with_prefix = self.get_standard_stage_prefix(stage_path)
stage_path_parts = self._split_stage_path(stage_path_with_prefix)
stage_path_parts = self._stage_path_part_factory(stage_path)
all_files_list = self._get_files_list_from_stage(stage_path_parts)

# filter files from stage if match stage_path pattern
Expand All @@ -228,42 +278,24 @@ def execute(
raise ClickException(f"No files matched pattern '{stage_path}'")

# sort filtered files in alphabetical order with directories at the end
sorted_file_list = sorted(
sorted_file_path_list = sorted(
filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f))
)

sql_variables = self._parse_execute_variables(variables)
results = []
for file in sorted_file_list:
for file_path in sorted_file_path_list:
results.append(
self._call_execute_immediate(
stage_path_parts=stage_path_parts,
file=file,
file_path=file_path,
variables=sql_variables,
on_error=on_error,
)
)

return results

def _split_stage_path(self, stage_path: str) -> StagePathParts:
"""
Splits stage path `@stage/dir` to
stage -> @stage
stage_name -> stage
directory -> dir
For stage path with fully qualified name `@db.schema.stage/dir`
stage -> @db.schema.stage
stage_name -> stage
directory -> dir
"""
stage = self.get_stage_from_path(stage_path)
stage_name = stage.split(".")[-1]
if stage_name.startswith("@"):
stage_name = stage_name[1:]
directory = "/".join(Path(stage_path).parts[1:])
return StagePathParts(stage, stage_name, directory)

def _get_files_list_from_stage(self, stage_path_parts: StagePathParts) -> List[str]:
files_list_result = self.list_files(stage_path_parts.stage).fetchall()

Expand Down Expand Up @@ -314,11 +346,11 @@ def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]:
def _call_execute_immediate(
self,
stage_path_parts: StagePathParts,
file: str,
file_path: str,
variables: Optional[str],
on_error: OnErrorType,
) -> Dict:
file_stage_path = self._build_file_stage_path(stage_path_parts, file)
file_stage_path = stage_path_parts.add_stage_prefix(file_path)
try:
query = f"execute immediate from {file_stage_path}"
if variables:
Expand All @@ -332,9 +364,9 @@ def _call_execute_immediate(
raise e
return {"File": file_stage_path, "Status": "FAILURE", "Error": e.msg}

def _build_file_stage_path(
self, stage_path_parts: StagePathParts, file: str
) -> str:
stage = Path(stage_path_parts.stage).parts[0]
file_path = Path(file).parts[1:]
return f"{stage}/{'/'.join(file_path)}"
@staticmethod
def _stage_path_part_factory(stage_path: str) -> StagePathParts:
stage_path = StageManager.get_standard_stage_prefix(stage_path)
if stage_path.startswith(USER_STAGE_PREFIX):
return UserStagePathParts(stage_path)
return DefaultStagePathParts(stage_path)
61 changes: 61 additions & 0 deletions tests/stage/__snapshots__/test_stage.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,67 @@

'''
# ---
# name: test_execute_from_user_stage[@~-expected_files0]
'''
SUCCESS - @~/s1.sql
SUCCESS - @~/a/s3.sql
SUCCESS - @~/a/b/s4.sql
+---------------------------------+
| File | Status | Error |
|---------------+---------+-------|
| @~/s1.sql | SUCCESS | None |
| @~/a/s3.sql | SUCCESS | None |
| @~/a/b/s4.sql | SUCCESS | None |
+---------------------------------+

'''
# ---
# name: test_execute_from_user_stage[@~/a-expected_files2]
'''
SUCCESS - @~/a/s3.sql
SUCCESS - @~/a/b/s4.sql
+---------------------------------+
| File | Status | Error |
|---------------+---------+-------|
| @~/a/s3.sql | SUCCESS | None |
| @~/a/b/s4.sql | SUCCESS | None |
+---------------------------------+

'''
# ---
# name: test_execute_from_user_stage[@~/a/b-expected_files4]
'''
SUCCESS - @~/a/b/s4.sql
+---------------------------------+
| File | Status | Error |
|---------------+---------+-------|
| @~/a/b/s4.sql | SUCCESS | None |
+---------------------------------+

'''
# ---
# name: test_execute_from_user_stage[@~/a/s3.sql-expected_files3]
'''
SUCCESS - @~/a/s3.sql
+-------------------------------+
| File | Status | Error |
|-------------+---------+-------|
| @~/a/s3.sql | SUCCESS | None |
+-------------------------------+

'''
# ---
# name: test_execute_from_user_stage[@~/s1.sql-expected_files1]
'''
SUCCESS - @~/s1.sql
+-----------------------------+
| File | Status | Error |
|-----------+---------+-------|
| @~/s1.sql | SUCCESS | None |
+-----------------------------+

'''
# ---
# name: test_execute_raise_invalid_file_extension_error
'''
+- Error ----------------------------------------------------------------------+
Expand Down
40 changes: 40 additions & 0 deletions tests/stage/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,46 @@ def test_execute(
assert result.output == snapshot


@pytest.mark.parametrize(
"stage_path, expected_files",
[
("@~", ["@~/s1.sql", "@~/a/s3.sql", "@~/a/b/s4.sql"]),
("@~/s1.sql", ["@~/s1.sql"]),
("@~/a", ["@~/a/s3.sql", "@~/a/b/s4.sql"]),
("@~/a/s3.sql", ["@~/a/s3.sql"]),
("@~/a/b", ["@~/a/b/s4.sql"]),
],
)
@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute_from_user_stage(
mock_execute,
mock_cursor,
runner,
stage_path,
expected_files,
snapshot,
):
mock_execute.return_value = mock_cursor(
[
{"name": "a/s3.sql"},
{"name": "a/b/s4.sql"},
{"name": "s1.sql"},
{"name": "s2"},
],
[],
)

result = runner.invoke(["stage", "execute", stage_path])

assert result.exit_code == 0, result.output
ls_call, *execute_calls = mock_execute.mock_calls
assert ls_call == mock.call(f"ls '@~'", cursor_class=DictCursor)
assert execute_calls == [
mock.call(f"execute immediate from {p}") for p in expected_files
]
assert result.output == snapshot


@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute_with_variables(mock_execute, mock_cursor, runner):
mock_execute.return_value = mock_cursor([{"name": "exe/s1.sql"}], [])
Expand Down
28 changes: 28 additions & 0 deletions tests_integration/__snapshots__/test_stage.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,31 @@
}),
])
# ---
# name: test_user_stage_execute
list([
dict({
'Error': None,
'File': '@~/execute/sql/script1.sql',
'Status': 'SUCCESS',
}),
dict({
'Error': None,
'File': '@~/execute/sql/directory/script2.sql',
'Status': 'SUCCESS',
}),
dict({
'Error': None,
'File': '@~/execute/sql/directory/subdirectory/script3.sql',
'Status': 'SUCCESS',
}),
])
# ---
# name: test_user_stage_execute.1
list([
dict({
'Error': None,
'File': '@~/execute/template/script_template.sql',
'Status': 'SUCCESS',
}),
])
# ---
Loading

0 comments on commit 059d24b

Please sign in to comment.