From 78c11ac7fc593b5bd8875ac75d6b61c914e0c7c0 Mon Sep 17 00:00:00 2001 From: David Hung Date: Thu, 4 Apr 2024 11:21:49 -0700 Subject: [PATCH] Add support for `snow://` stage paths (#1346) * Add support for snow:// stage prefix * Fix file name extraction from single quoted paths * Apply snow:// support to put* path * Update CHANGELOG * Move path splitting to helper method * Add unit tests for updated normalize_path and split_path utils --- CHANGELOG.md | 2 + src/snowflake/snowpark/_internal/utils.py | 19 +++++-- src/snowflake/snowpark/file_operation.py | 5 +- tests/unit/test_internal_utils.py | 67 +++++++++++++++++++++++ 4 files changed, 87 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_internal_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e57cec3c999..f4aedc26ec5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,10 +37,12 @@ - show - snowflake.snowpark.DataFrameWriter: - save_as_table +- Added support for snow:// URLs to `snowflake.snowpark.Session.file.get` and `snowflake.snowpark.Session.file.get_stream` ### Bug Fixes - Fixed a bug in local testing that null filled columns for constant functions. +- Fixed a bug causing `snowflake.snowpark.Session.file.get_stream` to fail for quoted stage locations ## 1.14.0 (2024-03-20) diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index c1e1f226dea..20efe074f30 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -53,6 +53,11 @@ ResultMetadataV2 = ResultMetadata STAGE_PREFIX = "@" +SNOWURL_PREFIX = "snow://" +SNOWFLAKE_PATH_PREFIXES = [ + STAGE_PREFIX, + SNOWURL_PREFIX, +] # Scala uses 3 but this can be larger. Consider allowing users to configure it. QUERY_TAG_TRACEBACK_LIMIT = 3 @@ -260,14 +265,14 @@ def normalize_path(path: str, is_local: bool) -> str: a directory named "load data". Therefore, if `path` is already wrapped by single quotes, we do nothing. """ - symbol = "file://" if is_local else STAGE_PREFIX + prefixes = ["file://"] if is_local else SNOWFLAKE_PATH_PREFIXES if is_single_quoted(path): return path if is_local and OPERATING_SYSTEM == "Windows": path = path.replace("\\", "/") path = path.strip().replace("'", "\\'") - if not path.startswith(symbol): - path = f"{symbol}{path}" + if not any(path.startswith(prefix) for prefix in prefixes): + path = f"{prefixes[0]}{path}" return f"'{path}'" @@ -279,9 +284,15 @@ def normalize_local_file(file: str) -> str: return normalize_path(file, is_local=True) +def split_path(path: str) -> Tuple[str, str]: + """Split a file path into directory and file name.""" + path = unwrap_single_quote(path) + return path.rsplit("/", maxsplit=1) + + def unwrap_stage_location_single_quote(name: str) -> str: new_name = unwrap_single_quote(name) - if new_name.startswith(STAGE_PREFIX): + if any(new_name.startswith(prefix) for prefix in SNOWFLAKE_PATH_PREFIXES): return new_name return f"{STAGE_PREFIX}{new_name}" diff --git a/src/snowflake/snowpark/file_operation.py b/src/snowflake/snowpark/file_operation.py index 784e77a678d..46617870913 100644 --- a/src/snowflake/snowpark/file_operation.py +++ b/src/snowflake/snowpark/file_operation.py @@ -18,6 +18,7 @@ normalize_local_file, normalize_remote_file_or_dir, result_set_to_rows, + split_path, ) @@ -273,7 +274,7 @@ def put_stream( ) raise ne.with_traceback(tb) from None else: - stage_with_prefix, dest_filename = stage_location.rsplit("/", maxsplit=1) + stage_with_prefix, dest_filename = split_path(stage_location) put_result = self._session._conn.upload_stream( input_stream=input_stream, stage_location=stage_with_prefix, @@ -338,7 +339,7 @@ def get_stream( else: options = {"parallel": parallel} tmp_dir = tempfile.gettempdir() - src_file_name = stage_location.rsplit("/", maxsplit=1)[1] + src_file_name = split_path(stage_location)[1] local_file_name = os.path.join(tmp_dir, src_file_name) plan = self._session._plan_builder.file_operation_plan( "get", diff --git a/tests/unit/test_internal_utils.py b/tests/unit/test_internal_utils.py new file mode 100644 index 00000000000..3c75a711888 --- /dev/null +++ b/tests/unit/test_internal_utils.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import pytest + +from snowflake.snowpark._internal import utils + + +@pytest.mark.parametrize( + "path, expected_dir, expected_file", + [ + ("stage/", "stage", ""), + ("stage/file.txt", "stage", "file.txt"), + ("dir/subdir/file.txt", "dir/subdir", "file.txt"), + ("@stage/dir/subdir/file.txt", "@stage/dir/subdir", "file.txt"), + ("'@stage/dir/subdir/file.txt'", "@stage/dir/subdir", "file.txt"), + ( + "snow://domain/test_entity/versions/test_version/file.txt", + "snow://domain/test_entity/versions/test_version", + "file.txt", + ), + ( + "'snow://domain/test_entity/versions/test_version/file.txt'", + "snow://domain/test_entity/versions/test_version", + "file.txt", + ), + ], +) +def test_split_path(path: str, expected_dir: str, expected_file: str) -> None: + dir, file = utils.split_path(path) + assert expected_dir == dir + assert expected_file == file + + +@pytest.mark.parametrize( + "path, is_local, expected", + [ + ("dir/file.txt", True, "'file://dir/file.txt'"), + ("dir/subdir/file.txt", True, "'file://dir/subdir/file.txt'"), + ("'dir/subdir/file.txt'", True, "'dir/subdir/file.txt'"), + ("file://dir/subdir/file.txt", True, "'file://dir/subdir/file.txt'"), + ("stage/", False, "'@stage/'"), + ("stage/file.txt", False, "'@stage/file.txt'"), + ("'stage/file.txt'", False, "'stage/file.txt'"), + ( + "stage/'embedded_quote'/file.txt", + False, + "'@stage/\\'embedded_quote\\'/file.txt'", + ), + ("@stage/dir/subdir/file.txt", False, "'@stage/dir/subdir/file.txt'"), + ("'@stage/dir/subdir/file.txt'", False, "'@stage/dir/subdir/file.txt'"), + ( + "snow://domain/test_entity/versions/test_version/file.txt", + False, + "'snow://domain/test_entity/versions/test_version/file.txt'", + ), + ( + "'snow://domain/test_entity/versions/test_version/file.txt'", + False, + "'snow://domain/test_entity/versions/test_version/file.txt'", + ), + ], +) +def test_normalize_path(path: str, is_local: bool, expected: str) -> None: + actual = utils.normalize_path(path, is_local) + assert expected == actual