SNOW-1418523: concurrent file operations #2288

23 changes: 22 additions & 1 deletion tests/integ/conftest.py
@@ -13,7 +13,13 @@
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._connection import MockServerConnection
from tests.parameters import CONNECTION_PARAMETERS
-from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci
+from tests.utils import (
+    TEST_SCHEMA,
+    TestFiles,
+    Utils,
+    running_on_jenkins,
+    running_on_public_ci,
+)


def print_help() -> None:
@@ -235,3 +241,18 @@ def temp_schema(connection, session, local_testing_mode) -> None:
    )
    yield temp_schema_name
    cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}")
+
+
+@pytest.fixture(scope="module")
+def temp_stage(session, resources_path, local_testing_mode):
+    tmp_stage_name = Utils.random_stage_name()
+    test_files = TestFiles(resources_path)
+
+    if not local_testing_mode:
+        Utils.create_stage(session, tmp_stage_name, is_temporary=True)
+        Utils.upload_to_stage(
+            session, tmp_stage_name, test_files.test_file_parquet, compress=False
+        )
+    yield tmp_stage_name
+    if not local_testing_mode:
+        Utils.drop_stage(session, tmp_stage_name)
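
Note (reviewer context, not part of the diff): the fixture above is `scope="module"`, so each test module creates, populates, and drops a single shared stage, and in local testing mode the server round-trips are skipped. It previously lived in `test_file_operation_suite.py` (removed below); hoisting it into `conftest.py` lets `test_multithreading.py` reuse it. A minimal, hypothetical consumer sketch — tests stay isolated on the shared stage by writing under a random prefix, exactly as the new `test_file_io` does:

    # Hypothetical sketch, not in this PR: a test sharing the module-scoped
    # stage while isolating its own files under a random prefix.
    def test_uses_shared_stage(session, temp_stage):
        prefix = f"prefix_{Utils.random_alphanumeric_str(10)}"
        stage_with_prefix = f"@{temp_stage}/{prefix}/"
        # "local_file.csv" is a placeholder path for illustration only.
        results = session.file.put("local_file.csv", stage_with_prefix)
        assert results[0].status == "UPLOADED"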
17 changes: 1 addition & 16 deletions tests/integ/scala/test_file_operation_suite.py
@@ -14,7 +14,7 @@
    SnowparkSQLException,
    SnowparkUploadFileException,
)
-from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils
+from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils


def random_alphanumeric_name():
@@ -74,21 +74,6 @@ def path4(temp_source_directory):
    yield filename


-@pytest.fixture(scope="module")
-def temp_stage(session, resources_path, local_testing_mode):
-    tmp_stage_name = Utils.random_stage_name()
-    test_files = TestFiles(resources_path)
-
-    if not local_testing_mode:
-        Utils.create_stage(session, tmp_stage_name, is_temporary=True)
-        Utils.upload_to_stage(
-            session, tmp_stage_name, test_files.test_file_parquet, compress=False
-        )
-    yield tmp_stage_name
-    if not local_testing_mode:
-        Utils.drop_stage(session, tmp_stage_name)
-
-
def test_put_with_one_file(
    session, temp_stage, path1, path2, path3, local_testing_mode
):
75 changes: 74 additions & 1 deletion tests/integ/test_multithreading.py
@@ -2,14 +2,17 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

+import hashlib
import os
+import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest.mock import patch

import pytest

from snowflake.snowpark.functions import lit
from snowflake.snowpark.row import Row
-from tests.utils import IS_IN_STORED_PROC, Utils
+from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils


def test_concurrent_select_queries(session):
@@ -122,3 +125,73 @@ def test_action_ids_are_unique(session):
        action_ids.add(future.result())

    assert len(action_ids) == 10
+
+
+@pytest.mark.parametrize("use_stream", [True, False])
+def test_file_io(session, resources_path, temp_stage, use_stream):
+    stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}"
+    stage_with_prefix = f"@{temp_stage}/{stage_prefix}/"
+    test_files = TestFiles(resources_path)
+
+    resources_files = [
+        test_files.test_file_csv,
+        test_files.test_file2_csv,
+        test_files.test_file_json,
+        test_files.test_file_csv_header,
+        test_files.test_file_csv_colon,
+        test_files.test_file_csv_quotes,
+        test_files.test_file_csv_special_format,
+        test_files.test_file_json_special_format,
+        test_files.test_file_csv_quotes_special,
+        test_files.test_concat_file1_csv,
+        test_files.test_concat_file2_csv,
+    ]
+
+    def get_file_hash(fd):
+        return hashlib.md5(fd.read()).hexdigest()
+
+    def put_and_get_file(upload_file_path, download_dir):
+        if use_stream:
+            with open(upload_file_path, "rb") as fd:
+                results = session.file.put_stream(
+                    fd, stage_with_prefix, auto_compress=False, overwrite=False
+                )
+        else:
+            results = session.file.put(
+                upload_file_path,
+                stage_with_prefix,
+                auto_compress=False,
+                overwrite=False,
+            )
+        # assert file is uploaded successfully
+        assert len(results) == 1
+        assert results[0].status == "UPLOADED"
+
+        stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}"
+        if use_stream:
+            fd = session.file.get_stream(stage_file_name, download_dir)
+            with open(upload_file_path, "rb") as upload_fd:
+                assert get_file_hash(upload_fd) == get_file_hash(fd)
+
+        else:
+            results = session.file.get(stage_file_name, download_dir)
+            # assert file is downloaded successfully
+            assert len(results) == 1
+            assert results[0].status == "DOWNLOADED"
+            download_file_path = results[0].file
+            # assert two files are identical
+            with open(upload_file_path, "rb") as upload_fd, open(
+                download_file_path, "rb"
+            ) as download_fd:
+                assert get_file_hash(upload_fd) == get_file_hash(download_fd)
+
+    with tempfile.TemporaryDirectory() as download_dir:
+        with ThreadPoolExecutor(max_workers=10) as executor:
+            for file_path in resources_files:
+                executor.submit(put_and_get_file, file_path, download_dir)
+
+        if not use_stream:
+            # assert all files are downloaded
+            assert set(os.listdir(download_dir)) == {
+                os.path.basename(file_path) for file_path in resources_files
+            }
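
Note (not part of the diff): `executor.submit` returns a `concurrent.futures.Future`, and an exception raised inside `put_and_get_file` (such as a failed assertion) is stored on that future rather than propagated; `ThreadPoolExecutor.__exit__` waits for the workers to finish but does not re-raise their exceptions. A sketch of a variant that surfaces worker failures, using the `as_completed` helper this module already imports:

    # Sketch: collect the futures and consume their results so any
    # AssertionError raised in a worker thread fails the test directly.
    with tempfile.TemporaryDirectory() as download_dir:
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [
                executor.submit(put_and_get_file, file_path, download_dir)
                for file_path in resources_files
            ]
            for future in as_completed(futures):
                future.result()  # re-raises the worker's exception, if any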