Skip to content

Commit

Permalink
SNOW-1418523: concurrent file operations (#2288)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Sep 25, 2024
1 parent 5f140ab commit 0624824
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 18 deletions.
23 changes: 22 additions & 1 deletion tests/integ/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._connection import MockServerConnection
from tests.parameters import CONNECTION_PARAMETERS
from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci
from tests.utils import (
TEST_SCHEMA,
TestFiles,
Utils,
running_on_jenkins,
running_on_public_ci,
)


def print_help() -> None:
Expand Down Expand Up @@ -235,3 +241,18 @@ def temp_schema(connection, session, local_testing_mode) -> None:
)
yield temp_schema_name
cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}")


@pytest.fixture(scope="module")
def temp_stage(session, resources_path, local_testing_mode):
    """Yield the name of a module-scoped temporary stage seeded with the
    parquet test file.

    In local testing mode no server-side stage is created (the mock
    connection has no real stages), so only the random name is yielded.
    Otherwise the stage is created and seeded before the tests run, and
    dropped again at fixture teardown.
    """
    stage_name = Utils.random_stage_name()
    parquet_path = TestFiles(resources_path).test_file_parquet

    needs_server_stage = not local_testing_mode
    if needs_server_stage:
        Utils.create_stage(session, stage_name, is_temporary=True)
        Utils.upload_to_stage(session, stage_name, parquet_path, compress=False)

    yield stage_name

    if needs_server_stage:
        Utils.drop_stage(session, stage_name)
17 changes: 1 addition & 16 deletions tests/integ/scala/test_file_operation_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SnowparkSQLException,
SnowparkUploadFileException,
)
from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils
from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils


def random_alphanumeric_name():
Expand Down Expand Up @@ -74,21 +74,6 @@ def path4(temp_source_directory):
yield filename


@pytest.fixture(scope="module")
def temp_stage(session, resources_path, local_testing_mode):
    """Provide a randomly named temporary stage preloaded with the
    parquet resource file; the stage is dropped on teardown.

    When running in local testing mode the stage is never created on
    the server — the fixture just hands out the generated name.
    """
    stage = Utils.random_stage_name()
    files = TestFiles(resources_path)

    if local_testing_mode:
        yield stage
        return

    Utils.create_stage(session, stage, is_temporary=True)
    Utils.upload_to_stage(session, stage, files.test_file_parquet, compress=False)
    yield stage
    Utils.drop_stage(session, stage)


def test_put_with_one_file(
session, temp_stage, path1, path2, path3, local_testing_mode
):
Expand Down
75 changes: 74 additions & 1 deletion tests/integ/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import hashlib
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest.mock import patch

import pytest

from snowflake.snowpark.functions import lit
from snowflake.snowpark.row import Row
from tests.utils import IS_IN_STORED_PROC, Utils
from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils


def test_concurrent_select_queries(session):
Expand Down Expand Up @@ -122,3 +125,73 @@ def test_action_ids_are_unique(session):
action_ids.add(future.result())

assert len(action_ids) == 10


@pytest.mark.parametrize("use_stream", [True, False])
def test_file_io(session, resources_path, temp_stage, use_stream):
    """Concurrently PUT each resource file to a random stage prefix and
    GET it back, asserting the round-tripped bytes are identical.

    Runs once with the stream-based API (``put_stream``/``get_stream``)
    and once with the path-based API (``put``/``get``).
    """
    stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}"
    stage_with_prefix = f"@{temp_stage}/{stage_prefix}/"
    test_files = TestFiles(resources_path)

    resources_files = [
        test_files.test_file_csv,
        test_files.test_file2_csv,
        test_files.test_file_json,
        test_files.test_file_csv_header,
        test_files.test_file_csv_colon,
        test_files.test_file_csv_quotes,
        test_files.test_file_csv_special_format,
        test_files.test_file_json_special_format,
        test_files.test_file_csv_quotes_special,
        test_files.test_concat_file1_csv,
        test_files.test_concat_file2_csv,
    ]

    def get_file_hash(fd):
        # md5 is fine here: we only compare digests for equality, there
        # is no security requirement.
        return hashlib.md5(fd.read()).hexdigest()

    def put_and_get_file(upload_file_path, download_dir):
        if use_stream:
            with open(upload_file_path, "rb") as fd:
                results = session.file.put_stream(
                    fd, stage_with_prefix, auto_compress=False, overwrite=False
                )
        else:
            results = session.file.put(
                upload_file_path,
                stage_with_prefix,
                auto_compress=False,
                overwrite=False,
            )
        # assert file is uploaded successfully
        assert len(results) == 1
        assert results[0].status == "UPLOADED"

        stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}"
        if use_stream:
            # get_stream returns a file-like object; close it explicitly
            # so we do not leak the underlying buffer/temp file.
            download_fd = session.file.get_stream(stage_file_name, download_dir)
            try:
                with open(upload_file_path, "rb") as upload_fd:
                    assert get_file_hash(upload_fd) == get_file_hash(download_fd)
            finally:
                download_fd.close()
        else:
            results = session.file.get(stage_file_name, download_dir)
            # assert file is downloaded successfully
            assert len(results) == 1
            assert results[0].status == "DOWNLOADED"
            download_file_path = results[0].file
            # assert two files are identical
            with open(upload_file_path, "rb") as upload_fd, open(
                download_file_path, "rb"
            ) as download_fd:
                assert get_file_hash(upload_fd) == get_file_hash(download_fd)

    with tempfile.TemporaryDirectory() as download_dir:
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [
                executor.submit(put_and_get_file, file_path, download_dir)
                for file_path in resources_files
            ]
            # Re-raise any exception from a worker.  Without this, an
            # AssertionError inside put_and_get_file would be swallowed
            # by the unawaited future and the test would pass spuriously.
            for future in as_completed(futures):
                future.result()

        if not use_stream:
            # assert all files are downloaded
            assert set(os.listdir(download_dir)) == {
                os.path.basename(file_path) for file_path in resources_files
            }

0 comments on commit 0624824

Please sign in to comment.