Skip to content

Commit

Permalink
add mock tests to test_s3_downloader_with_s5cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Oct 14, 2024
1 parent 1783215 commit 8deb5d0
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion tests/streaming/test_downloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

from litdata.streaming.downloader import (
AzureDownloader,
Expand All @@ -21,6 +21,62 @@ def test_s3_downloader_fast(tmpdir, monkeypatch):
popen_mock.wait.assert_called()


@patch("os.system")
@patch("subprocess.Popen")
def test_s3_downloader_with_s5cmd_no_storage_options(popen_mock, system_mock, tmpdir):
    """Check the s5cmd download command when no storage options are given.

    ``os.system`` and ``subprocess.Popen`` are both mocked, so no real
    process is spawned. The test asserts that ``download_file`` issues exactly
    one ``s5cmd cp`` command with ``env=None`` and waits on the process.

    Args:
        popen_mock: Mock replacing ``subprocess.Popen`` (innermost decorator,
            therefore the first injected argument).
        system_mock: Mock replacing ``os.system``.
        tmpdir: pytest fixture providing a temporary directory.
    """
    # Imported locally: the assertion below needs ``subprocess.PIPE`` and the
    # module-level imports visible for this file do not include ``subprocess``.
    import subprocess

    system_mock.return_value = 0  # Simulates s5cmd being available
    process_mock = MagicMock()
    popen_mock.return_value = process_mock

    # Initialize the S3Downloader without storage options
    downloader = S3Downloader("s3://test_bucket", str(tmpdir), [])

    # Action: Call the download_file method
    remote_filepath = "s3://test_bucket/sample_file.txt"
    local_filepath = os.path.join(tmpdir, "sample_file.txt")
    downloader.download_file(remote_filepath, local_filepath)

    # Assertion: Verify subprocess.Popen was called with correct arguments and no env variables
    popen_mock.assert_called_once_with(  # noqa: S604
        f"s5cmd cp {remote_filepath} {local_filepath}",
        shell=True,
        stdout=subprocess.PIPE,
        env=None,
    )
    process_mock.wait.assert_called_once()


@patch("os.system")
@patch("subprocess.Popen")
def test_s3_downloader_with_s5cmd_with_storage_options(popen_mock, system_mock, tmpdir):
    """Check the s5cmd download command when storage options are supplied.

    ``os.system`` and ``subprocess.Popen`` are both mocked, so no real
    process is spawned. The test asserts that ``download_file`` issues exactly
    one ``s5cmd cp`` command whose ``env`` equals the current environment
    updated with the storage options, and that it waits on the process.

    Args:
        popen_mock: Mock replacing ``subprocess.Popen`` (innermost decorator,
            therefore the first injected argument).
        system_mock: Mock replacing ``os.system``.
        tmpdir: pytest fixture providing a temporary directory.
    """
    # Imported locally: the assertion below needs ``subprocess.PIPE`` and the
    # module-level imports visible for this file do not include ``subprocess``.
    import subprocess

    system_mock.return_value = 0  # Simulates s5cmd being available
    process_mock = MagicMock()
    popen_mock.return_value = process_mock

    storage_options = {"AWS_ACCESS_KEY_ID": "dummy_key", "AWS_SECRET_ACCESS_KEY": "dummy_secret"}

    # Initialize the S3Downloader with storage options
    downloader = S3Downloader("s3://test_bucket", str(tmpdir), [], storage_options)

    # Action: Call the download_file method
    remote_filepath = "s3://test_bucket/sample_file.txt"
    local_filepath = os.path.join(tmpdir, "sample_file.txt")
    downloader.download_file(remote_filepath, local_filepath)

    # Create expected environment variables by merging the current env with storage_options
    expected_env = os.environ.copy()
    expected_env.update(storage_options)

    # Assertion: Verify subprocess.Popen was called with the correct arguments and environment variables
    popen_mock.assert_called_once_with(  # noqa: S604
        f"s5cmd cp {remote_filepath} {local_filepath}",
        shell=True,
        stdout=subprocess.PIPE,
        env=expected_env,
    )
    process_mock.wait.assert_called_once()


@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True)
def test_gcp_downloader(tmpdir, monkeypatch, google_mock):
# Create mock objects
Expand Down

0 comments on commit 8deb5d0

Please sign in to comment.