Skip to content

Commit

Permalink
Add support for snow:// stage paths (#1346)
Browse files Browse the repository at this point in the history
* Add support for snow:// stage prefix

* Fix file name extraction from single-quoted paths

* Apply snow:// support to put* path

* Update CHANGELOG

* Move path splitting to helper method

* Add unit tests for updated normalize_path and split_path utils
  • Loading branch information
sfc-gh-dhung authored Apr 4, 2024
1 parent 96da8e9 commit 78c11ac
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
- show
- snowflake.snowpark.DataFrameWriter:
- save_as_table
- Added support for snow:// URLs to `snowflake.snowpark.Session.file.get` and `snowflake.snowpark.Session.file.get_stream`

### Bug Fixes

- Fixed a bug in local testing that caused columns for constant functions to be filled with nulls.
- Fixed a bug causing `snowflake.snowpark.Session.file.get_stream` to fail for single-quoted stage locations.

## 1.14.0 (2024-03-20)

Expand Down
19 changes: 15 additions & 4 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@
ResultMetadataV2 = ResultMetadata

# Prefix that marks a path as a stage location.
STAGE_PREFIX = "@"
# Prefix that marks a path as a snow:// URL.
SNOWURL_PREFIX = "snow://"
# Every prefix accepted on a remote (non-local) path. Order matters:
# normalize_path prepends the first entry to a remote path that carries
# none of these prefixes, so STAGE_PREFIX must stay first.
SNOWFLAKE_PATH_PREFIXES = [STAGE_PREFIX, SNOWURL_PREFIX]

# Scala uses 3 but this can be larger. Consider allowing users to configure it.
QUERY_TAG_TRACEBACK_LIMIT = 3
Expand Down Expand Up @@ -260,14 +265,14 @@ def normalize_path(path: str, is_local: bool) -> str:
a directory named "load data". Therefore, if `path` is already wrapped by single quotes,
we do nothing.
"""
symbol = "file://" if is_local else STAGE_PREFIX
prefixes = ["file://"] if is_local else SNOWFLAKE_PATH_PREFIXES
if is_single_quoted(path):
return path
if is_local and OPERATING_SYSTEM == "Windows":
path = path.replace("\\", "/")
path = path.strip().replace("'", "\\'")
if not path.startswith(symbol):
path = f"{symbol}{path}"
if not any(path.startswith(prefix) for prefix in prefixes):
path = f"{prefixes[0]}{path}"
return f"'{path}'"


Expand All @@ -279,9 +284,15 @@ def normalize_local_file(file: str) -> str:
return normalize_path(file, is_local=True)


def split_path(path: str) -> Tuple[str, str]:
    """Split a file path into directory and file name.

    The path is first passed through ``unwrap_single_quote`` and then split
    on its last ``/``.

    Args:
        path: A path such as ``"@stage/dir/file.txt"`` or
            ``"'snow://domain/entity/versions/v1/file.txt'"``.

    Returns:
        A ``(directory, file_name)`` tuple. ``file_name`` is empty when the
        path ends with ``/``; ``directory`` is empty when the path contains
        no ``/`` at all (the same convention as ``os.path.split``).
    """
    path = unwrap_single_quote(path)
    # rpartition always yields three parts, so callers can safely unpack two
    # values even when no "/" is present (rsplit would return a one-element
    # list in that case, and a list rather than the annotated tuple always).
    directory, _, file_name = path.rpartition("/")
    return directory, file_name


# NOTE(review): this is a diff view with the +/- markers stripped. Of the two
# consecutive `if` lines below, the first is the pre-commit condition and the
# second is its replacement; only one of them exists in either file version.
def unwrap_stage_location_single_quote(name: str) -> str:
# Presumably strips the wrapping single quotes, if any — unwrap_single_quote
# is defined outside this view; confirm its exact behavior there.
new_name = unwrap_single_quote(name)
# Pre-commit condition (removed): only the "@" stage prefix was recognized.
if new_name.startswith(STAGE_PREFIX):
# Post-commit condition (added): any prefix in SNOWFLAKE_PATH_PREFIXES
# ("@" or "snow://") counts as already prefixed.
if any(new_name.startswith(prefix) for prefix in SNOWFLAKE_PATH_PREFIXES):
return new_name
# No recognized prefix: fall back to prepending the "@" stage prefix.
return f"{STAGE_PREFIX}{new_name}"

Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/file_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
normalize_local_file,
normalize_remote_file_or_dir,
result_set_to_rows,
split_path,
)


Expand Down Expand Up @@ -273,7 +274,7 @@ def put_stream(
)
raise ne.with_traceback(tb) from None
else:
stage_with_prefix, dest_filename = stage_location.rsplit("/", maxsplit=1)
stage_with_prefix, dest_filename = split_path(stage_location)
put_result = self._session._conn.upload_stream(
input_stream=input_stream,
stage_location=stage_with_prefix,
Expand Down Expand Up @@ -338,7 +339,7 @@ def get_stream(
else:
options = {"parallel": parallel}
tmp_dir = tempfile.gettempdir()
src_file_name = stage_location.rsplit("/", maxsplit=1)[1]
src_file_name = split_path(stage_location)[1]
local_file_name = os.path.join(tmp_dir, src_file_name)
plan = self._session._plan_builder.file_operation_plan(
"get",
Expand Down
67 changes: 67 additions & 0 deletions tests/unit/test_internal_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import pytest

from snowflake.snowpark._internal import utils


@pytest.mark.parametrize(
    "path, expected_dir, expected_file",
    [
        # Trailing slash: the file name part is empty.
        ("stage/", "stage", ""),
        ("stage/file.txt", "stage", "file.txt"),
        ("dir/subdir/file.txt", "dir/subdir", "file.txt"),
        # "@"-prefixed stage paths, bare and wrapped in single quotes.
        ("@stage/dir/subdir/file.txt", "@stage/dir/subdir", "file.txt"),
        ("'@stage/dir/subdir/file.txt'", "@stage/dir/subdir", "file.txt"),
        # snow:// URLs, bare and wrapped in single quotes.
        (
            "snow://domain/test_entity/versions/test_version/file.txt",
            "snow://domain/test_entity/versions/test_version",
            "file.txt",
        ),
        (
            "'snow://domain/test_entity/versions/test_version/file.txt'",
            "snow://domain/test_entity/versions/test_version",
            "file.txt",
        ),
    ],
)
def test_split_path(path: str, expected_dir: str, expected_file: str) -> None:
    """Check that split_path returns the expected directory and file name."""
    # Named actual_* rather than dir/file so the builtin `dir` is not shadowed.
    actual_dir, actual_file = utils.split_path(path)
    assert actual_dir == expected_dir
    assert actual_file == expected_file


@pytest.mark.parametrize(
    ("path", "is_local", "expected"),
    [
        # Local paths: "file://" is prepended unless already present; input
        # that is already single-quoted is returned untouched.
        ("dir/file.txt", True, "'file://dir/file.txt'"),
        ("dir/subdir/file.txt", True, "'file://dir/subdir/file.txt'"),
        ("'dir/subdir/file.txt'", True, "'dir/subdir/file.txt'"),
        ("file://dir/subdir/file.txt", True, "'file://dir/subdir/file.txt'"),
        # Remote paths without a prefix get "@" prepended.
        ("stage/", False, "'@stage/'"),
        ("stage/file.txt", False, "'@stage/file.txt'"),
        ("'stage/file.txt'", False, "'stage/file.txt'"),
        # Embedded single quotes are backslash-escaped.
        (
            "stage/'embedded_quote'/file.txt",
            False,
            "'@stage/\\'embedded_quote\\'/file.txt'",
        ),
        # Remote paths that already carry "@" or "snow://" keep their prefix.
        ("@stage/dir/subdir/file.txt", False, "'@stage/dir/subdir/file.txt'"),
        ("'@stage/dir/subdir/file.txt'", False, "'@stage/dir/subdir/file.txt'"),
        (
            "snow://domain/test_entity/versions/test_version/file.txt",
            False,
            "'snow://domain/test_entity/versions/test_version/file.txt'",
        ),
        (
            "'snow://domain/test_entity/versions/test_version/file.txt'",
            False,
            "'snow://domain/test_entity/versions/test_version/file.txt'",
        ),
    ],
)
def test_normalize_path(path: str, is_local: bool, expected: str) -> None:
    """Check normalize_path output for local, stage and snow:// inputs."""
    assert expected == utils.normalize_path(path, is_local)

0 comments on commit 78c11ac

Please sign in to comment.