diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 13381a85d4..d0927ad250 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -24,6 +24,8 @@ ## Fixes and improvements * The `snow app run` command now allows upgrading to unversioned mode from a versioned or release mode application installation * The `snow app teardown` command now allows dropping a package with versions when the `--force` flag is provided +* Added support for user stages in stage execute command + # v2.6.0 ## Backward incompatibility diff --git a/src/snowflake/cli/plugins/git/manager.py b/src/snowflake/cli/plugins/git/manager.py index d1e4d1887f..5407a808c3 100644 --- a/src/snowflake/cli/plugins/git/manager.py +++ b/src/snowflake/cli/plugins/git/manager.py @@ -19,6 +19,30 @@ from snowflake.connector.cursor import SnowflakeCursor +class GitStagePathParts(StagePathParts): + def __init__(self, stage_path: str): + self.stage = GitManager.get_stage_from_path(stage_path) + stage_path_parts = Path(stage_path).parts + git_repo_name = stage_path_parts[0].split(".")[-1] + if git_repo_name.startswith("@"): + git_repo_name = git_repo_name[1:] + self.stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""]) + self.directory = "/".join(stage_path_parts[3:]) + + @property + def path(self) -> str: + return ( + f"{self.stage_name}{self.directory}".lower() + if self.stage_name.endswith("/") + else f"{self.stage_name}/{self.directory}".lower() + ) + + def add_stage_prefix(self, file_path: str) -> str: + stage = Path(self.stage).parts[0] + file_path_without_prefix = Path(file_path).parts[1:] + return f"{stage}/{'/'.join(file_path_without_prefix)}" + + class GitManager(StageManager): def show_branches(self, repo_name: str, like: str) -> SnowflakeCursor: return self._execute_query(f"show git branches like '{like}' in {repo_name}") @@ -51,22 +75,7 @@ def get_stage_from_path(path: str): """ return f"{'/'.join(Path(path).parts[0:3])}/" - def _split_stage_path(self, stage_path: str) -> StagePathParts: - """ - Splits Git repository path `@repo/branch/main/dir` - stage -> @repo/branch/main/ - stage_name -> repo/branch/main/ - directory -> dir - For Git repository with fully qualified name `@db.schema.repo/branch/main/dir` - stage -> @db.schema.repo/branch/main/ - stage_name -> repo/branch/main/ - directory -> dir - """ - stage = self.get_stage_from_path(stage_path) - stage_path_parts = Path(stage_path).parts - git_repo_name = stage_path_parts[0].split(".")[-1] - if git_repo_name.startswith("@"): - git_repo_name = git_repo_name[1:] - stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""]) - directory = "/".join(stage_path_parts[3:]) - return StagePathParts(stage, stage_name, directory) + @staticmethod + def _stage_path_part_factory(stage_path: str) -> StagePathParts: + stage_path = StageManager.get_standard_stage_prefix(stage_path) + return GitStagePathParts(stage_path) diff --git a/src/snowflake/cli/plugins/stage/manager.py b/src/snowflake/cli/plugins/stage/manager.py index 0f7a9054b1..25ffc33272 100644 --- a/src/snowflake/cli/plugins/stage/manager.py +++ b/src/snowflake/cli/plugins/stage/manager.py @@ -39,17 +39,47 @@ UNQUOTED_FILE_URI_REGEX = r"[\w/*?\-.=&{}$#[\]\"\\!@%^+:]+" EXECUTE_SUPPORTED_FILES_FORMATS = {".sql"} +USER_STAGE_PREFIX = "@~" @dataclass class StagePathParts: - # For path like @db.schema.stage/dir the values will be: - # stage = @db.schema.stage + directory: str stage: str - # stage_name = stage/dir stage_name: str - # directory = dir - directory: str + + @staticmethod + def get_directory(stage_path: str) -> str: + return "/".join(Path(stage_path).parts[1:]) + + @property + def path(self) -> str: + raise NotImplementedError + + def add_stage_prefix(self, file_path: str) -> str: + raise NotImplementedError + + +@dataclass +class DefaultStagePathParts(StagePathParts): + """ + For path like @db.schema.stage/dir the values will be: + directory = dir + stage = @db.schema.stage + stage_name = stage + For `@stage/dir` to + stage -> @stage + stage_name -> stage + directory -> dir + """ + + def __init__(self, stage_path: str): + self.directory = self.get_directory(stage_path) + self.stage = StageManager.get_stage_from_path(stage_path) + stage_name = self.stage.split(".")[-1] + if stage_name.startswith("@"): + stage_name = stage_name[1:] + self.stage_name = stage_name @property def path(self) -> str: @@ -59,6 +89,33 @@ def path(self) -> str: else f"{self.stage_name}/{self.directory}".lower() ) + def add_stage_prefix(self, file_path: str) -> str: + stage = Path(self.stage).parts[0] + file_path_without_prefix = Path(file_path).parts[1:] + return f"{stage}/{'/'.join(file_path_without_prefix)}" + + +@dataclass +class UserStagePathParts(StagePathParts): + """ + For path like @db.schema.stage/dir the values will be: + directory = dir + stage = @~ + stage_name = @~ + """ + + def __init__(self, stage_path: str): + self.directory = self.get_directory(stage_path) + self.stage = "@~" + self.stage_name = "@~" + + @property + def path(self) -> str: + return f"{self.directory}".lower() + + def add_stage_prefix(self, file_path: str) -> str: + return f"{self.stage}/{file_path}" + class StageManager(SqlExecutionMixin): @staticmethod @@ -96,12 +153,6 @@ def quote_stage_name(name: str) -> str: return standard_name - @staticmethod - def remove_stage_prefix(stage_path: str) -> str: - if stage_path.startswith("@"): - return stage_path[1:] - return stage_path - def _to_uri(self, local_path: str): uri = f"file://{local_path}" if re.fullmatch(UNQUOTED_FILE_URI_REGEX, uri): @@ -217,8 +268,7 @@ def execute( on_error: OnErrorType, variables: Optional[List[str]] = None, ): - stage_path_with_prefix = self.get_standard_stage_prefix(stage_path) - stage_path_parts = self._split_stage_path(stage_path_with_prefix) + stage_path_parts = self._stage_path_part_factory(stage_path) all_files_list = self._get_files_list_from_stage(stage_path_parts) # filter files from stage if match stage_path pattern @@ -228,17 +278,17 @@ def execute( raise ClickException(f"No files matched pattern '{stage_path}'") # sort filtered files in alphabetical order with directories at the end - sorted_file_list = sorted( + sorted_file_path_list = sorted( filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f)) ) sql_variables = self._parse_execute_variables(variables) results = [] - for file in sorted_file_list: + for file_path in sorted_file_path_list: results.append( self._call_execute_immediate( stage_path_parts=stage_path_parts, - file=file, + file_path=file_path, variables=sql_variables, on_error=on_error, ) @@ -246,24 +296,6 @@ def execute( return results - def _split_stage_path(self, stage_path: str) -> StagePathParts: - """ - Splits stage path `@stage/dir` to - stage -> @stage - stage_name -> stage - directory -> dir - For stage path with fully qualified name `@db.schema.stage/dir` - stage -> @db.schema.stage - stage_name -> stage - directory -> dir - """ - stage = self.get_stage_from_path(stage_path) - stage_name = stage.split(".")[-1] - if stage_name.startswith("@"): - stage_name = stage_name[1:] - directory = "/".join(Path(stage_path).parts[1:]) - return StagePathParts(stage, stage_name, directory) - def _get_files_list_from_stage(self, stage_path_parts: StagePathParts) -> List[str]: files_list_result = self.list_files(stage_path_parts.stage).fetchall() @@ -314,11 +346,11 @@ def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]: def _call_execute_immediate( self, stage_path_parts: StagePathParts, - file: str, + file_path: str, variables: Optional[str], on_error: OnErrorType, ) -> Dict: - file_stage_path = self._build_file_stage_path(stage_path_parts, file) + file_stage_path = stage_path_parts.add_stage_prefix(file_path) try: query = f"execute immediate from {file_stage_path}" if variables: @@ -332,9 +364,9 @@ def _call_execute_immediate( raise e return {"File": file_stage_path, "Status": "FAILURE", "Error": e.msg} - def _build_file_stage_path( - self, stage_path_parts: StagePathParts, file: str - ) -> str: - stage = Path(stage_path_parts.stage).parts[0] - file_path = Path(file).parts[1:] - return f"{stage}/{'/'.join(file_path)}" + @staticmethod + def _stage_path_part_factory(stage_path: str) -> StagePathParts: + stage_path = StageManager.get_standard_stage_prefix(stage_path) + if stage_path.startswith(USER_STAGE_PREFIX): + return UserStagePathParts(stage_path) + return DefaultStagePathParts(stage_path) diff --git a/tests/stage/__snapshots__/test_stage.ambr b/tests/stage/__snapshots__/test_stage.ambr index dc9a8d7cf4..a4b189c607 100644 --- a/tests/stage/__snapshots__/test_stage.ambr +++ b/tests/stage/__snapshots__/test_stage.ambr @@ -272,6 +272,67 @@ ''' # --- +# name: test_execute_from_user_stage[@~-expected_files0] + ''' + SUCCESS - @~/s1.sql + SUCCESS - @~/a/s3.sql + SUCCESS - @~/a/b/s4.sql + +---------------------------------+ + | File | Status | Error | + |---------------+---------+-------| + | @~/s1.sql | SUCCESS | None | + | @~/a/s3.sql | SUCCESS | None | + | @~/a/b/s4.sql | SUCCESS | None | + +---------------------------------+ + + ''' +# --- +# name: test_execute_from_user_stage[@~/a-expected_files2] + ''' + SUCCESS - @~/a/s3.sql + SUCCESS - @~/a/b/s4.sql + +---------------------------------+ + | File | Status | Error | + |---------------+---------+-------| + | @~/a/s3.sql | SUCCESS | None | + | @~/a/b/s4.sql | SUCCESS | None | + +---------------------------------+ + + ''' +# --- +# name: test_execute_from_user_stage[@~/a/b-expected_files4] + ''' + SUCCESS - @~/a/b/s4.sql + +---------------------------------+ + | File | Status | Error | + |---------------+---------+-------| + | @~/a/b/s4.sql | SUCCESS | None | + +---------------------------------+ + + ''' +# --- +# name: test_execute_from_user_stage[@~/a/s3.sql-expected_files3] + ''' + SUCCESS - @~/a/s3.sql + +-------------------------------+ + | File | Status | Error | + |-------------+---------+-------| + | @~/a/s3.sql | SUCCESS | None | + +-------------------------------+ + + ''' +# --- +# name: test_execute_from_user_stage[@~/s1.sql-expected_files1] + ''' + SUCCESS - @~/s1.sql + +-----------------------------+ + | File | Status | Error | + |-----------+---------+-------| + | @~/s1.sql | SUCCESS | None | + +-----------------------------+ + + ''' +# --- # name: test_execute_raise_invalid_file_extension_error ''' +- Error ----------------------------------------------------------------------+ diff --git a/tests/stage/test_stage.py b/tests/stage/test_stage.py index 047912f0fd..b9c1d97764 100644 --- a/tests/stage/test_stage.py +++ b/tests/stage/test_stage.py @@ -738,6 +738,46 @@ def test_execute( assert result.output == snapshot +@pytest.mark.parametrize( + "stage_path, expected_files", + [ + ("@~", ["@~/s1.sql", "@~/a/s3.sql", "@~/a/b/s4.sql"]), + ("@~/s1.sql", ["@~/s1.sql"]), + ("@~/a", ["@~/a/s3.sql", "@~/a/b/s4.sql"]), + ("@~/a/s3.sql", ["@~/a/s3.sql"]), + ("@~/a/b", ["@~/a/b/s4.sql"]), + ], +) +@mock.patch(f"{STAGE_MANAGER}._execute_query") +def test_execute_from_user_stage( + mock_execute, + mock_cursor, + runner, + stage_path, + expected_files, + snapshot, +): + mock_execute.return_value = mock_cursor( + [ + {"name": "a/s3.sql"}, + {"name": "a/b/s4.sql"}, + {"name": "s1.sql"}, + {"name": "s2"}, + ], + [], + ) + + result = runner.invoke(["stage", "execute", stage_path]) + + assert result.exit_code == 0, result.output + ls_call, *execute_calls = mock_execute.mock_calls + assert ls_call == mock.call(f"ls '@~'", cursor_class=DictCursor) + assert execute_calls == [ + mock.call(f"execute immediate from {p}") for p in expected_files + ] + assert result.output == snapshot + + @mock.patch(f"{STAGE_MANAGER}._execute_query") def test_execute_with_variables(mock_execute, mock_cursor, runner): mock_execute.return_value = mock_cursor([{"name": "exe/s1.sql"}], []) diff --git a/tests_integration/__snapshots__/test_stage.ambr b/tests_integration/__snapshots__/test_stage.ambr index 15c1b810f8..f9a443fddd 100644 --- a/tests_integration/__snapshots__/test_stage.ambr +++ b/tests_integration/__snapshots__/test_stage.ambr @@ -77,3 +77,31 @@ }), ]) # --- +# name: test_user_stage_execute + list([ + dict({ + 'Error': None, + 'File': '@~/execute/sql/script1.sql', + 'Status': 'SUCCESS', + }), + dict({ + 'Error': None, + 'File': '@~/execute/sql/directory/script2.sql', + 'Status': 'SUCCESS', + }), + dict({ + 'Error': None, + 'File': '@~/execute/sql/directory/subdirectory/script3.sql', + 'Status': 'SUCCESS', + }), + ]) +# --- +# name: test_user_stage_execute.1 + list([ + dict({ + 'Error': None, + 'File': '@~/execute/template/script_template.sql', + 'Status': 'SUCCESS', + }), + ]) +# --- diff --git a/tests_integration/test_stage.py b/tests_integration/test_stage.py index 1562b47bd4..c0f0288c17 100644 --- a/tests_integration/test_stage.py +++ b/tests_integration/test_stage.py @@ -229,6 +229,68 @@ def test_stage_execute(runner, test_database, test_root_path, snapshot): ] +@pytest.mark.integration +def test_user_stage_execute(runner, test_database, test_root_path, snapshot): + project_path = test_root_path / "test_data/projects/stage_execute" + user_stage_name = "@~" + + files = [ + ("script1.sql", "execute/sql"), + ("script2.sql", "execute/sql/directory"), + ("script3.sql", "execute/sql/directory/subdirectory"), + ] + for name, stage_path in files: + result = runner.invoke_with_connection_json( + [ + "stage", + "copy", + f"{project_path}/{name}", + f"{user_stage_name}/{stage_path}", + ] + ) + assert result.exit_code == 0, result.output + assert contains_row_with( + result.json, {"status": "SKIPPED"} + ) or contains_row_with(result.json, {"status": "UPLOADED"}) + + result = runner.invoke_with_connection_json( + ["stage", "execute", f"{user_stage_name}/execute/sql"] + ) + assert result.exit_code == 0 + assert result.json == snapshot + + result = runner.invoke_with_connection_json( + [ + "stage", + "copy", + f"{project_path}/script_template.sql", + f"{user_stage_name}/execute/template", + ] + ) + assert result.exit_code == 0, result.output + assert contains_row_with(result.json, {"status": "SKIPPED"}) or contains_row_with( + result.json, {"status": "UPLOADED"} + ) + + result = runner.invoke_with_connection_json( + [ + "stage", + "execute", + f"{user_stage_name}/execute/template/script_template.sql", + "-D", + " text = 'string' ", + "-D", + "value=1", + "-D", + "boolean=TRUE", + "-D", + "null_value= NULL", + ] + ) + assert result.exit_code == 0 + assert result.json == snapshot + + @pytest.mark.integration def test_stage_diff(runner, snowflake_session, test_database, tmp_path, snapshot): stage_name = "test_stage"