Skip to content

Commit

Permalink
all tests passed
Browse files Browse the repository at this point in the history
  • Loading branch information
deependujha committed Sep 4, 2024
1 parent e712327 commit e118ba9
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 107 deletions.
2 changes: 1 addition & 1 deletion src/litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool
object_list = list_directory(output_dir.url)
except FileNotFoundError:
return

print(f"{object_list=}")
# We don't allow adding more data to a non-empty output directory
if object_list is not None and len(object_list) > 0:
raise RuntimeError(
Expand Down
32 changes: 12 additions & 20 deletions tests/processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
import torch
from lightning_utilities.core.imports import RequirementCache

from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE
from litdata.processing import data_processor as data_processor_module
from litdata.processing import functions
Expand Down Expand Up @@ -109,8 +110,6 @@ def fn(*_, **__):

remove_queue = mock.MagicMock()

s3_client = mock.MagicMock()

called = False

def copy_file(local_filepath, *args):
Expand All @@ -120,9 +119,7 @@ def copy_file(local_filepath, *args):

copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath)))

s3_client.client.upload_file = copy_file

monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client))
monkeypatch.setattr(data_processor_module, "upload_file_or_directory", copy_file)

assert os.listdir(remote_output_dir) == []

Expand Down Expand Up @@ -217,32 +214,28 @@ def test_wait_for_disk_usage_higher_than_threshold():


@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
def test_wait_for_file_to_exist():
import botocore

s3 = mock.MagicMock()
obj = mock.MagicMock()
def test_wait_for_file_to_exist(monkeypatch):
raise_error = [True, True, False]

def fn(*_, **__):
value = raise_error.pop(0)
if value:
raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception
return

s3.client.head_object = fn
monkeypatch.setattr(data_processor_module, "does_file_exist", fn)

_wait_for_file_to_exist(s3, obj, sleep_time=0.01)
_wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01)

assert len(raise_error) == 0

def fn(*_, **__):
raise ValueError("HERE")

s3.client.head_object = fn
monkeypatch.setattr(data_processor_module, "does_file_exist", fn)

with pytest.raises(ValueError, match="HERE"):
_wait_for_file_to_exist(s3, obj, sleep_time=0.01)
_wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01)


def test_cache_dir_cleanup(tmpdir, monkeypatch):
Expand Down Expand Up @@ -1025,11 +1018,10 @@ def test_data_processing_map_non_absolute_path(monkeypatch, tmpdir):

@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
def test_map_error_when_not_empty(monkeypatch):
boto3 = mock.MagicMock()
client_s3_mock = mock.MagicMock()
client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []}
boto3.client.return_value = client_s3_mock
monkeypatch.setattr(resolver, "boto3", boto3)
def mock_list_directory(*args, **kwargs):
return ["a.txt", "b.txt"]

monkeypatch.setattr(resolver, "list_directory", mock_list_directory)

with pytest.raises(RuntimeError, match="data and datasets are meant to be immutable"):
map(
Expand Down
65 changes: 7 additions & 58 deletions tests/streaming/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,65 +11,14 @@
subprocess,
)

# def test_s3_downloader_fast(tmpdir, monkeypatch):
# monkeypatch.setattr(os, "system", MagicMock(return_value=0))
# popen_mock = MagicMock()
# monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock))
# downloader = S3Downloader(tmpdir, tmpdir, [])
# downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt"))
# popen_mock.wait.assert_called()

def test_s3_downloader_fast(tmpdir, monkeypatch):
    """The fast S3 path should spawn a subprocess and block on its completion."""
    # Stub os.system so any shell availability probe reports success.
    monkeypatch.setattr(os, "system", MagicMock(return_value=0))

    # Replace Popen so no real process is spawned; keep a handle on the stub.
    process_stub = MagicMock()
    monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=process_stub))

    s3_downloader = S3Downloader(tmpdir, tmpdir, [])
    s3_downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt"))

    # The downloader must wait for the spawned process to finish.
    process_stub.wait.assert_called()


@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True)
def test_gcp_downloader(tmpdir, monkeypatch, google_mock):
    """GCP download should go Client -> bucket -> blob -> download_to_filename."""
    # Build the stub chain from the innermost object outward.
    blob_stub = MagicMock()
    blob_stub.download_to_filename = MagicMock()

    bucket_stub = MagicMock()
    bucket_stub.blob = MagicMock(return_value=blob_stub)

    client_stub = MagicMock()
    client_stub.bucket = MagicMock(return_value=bucket_stub)

    # The storage client constructor hands back our stubbed client.
    google_mock.cloud.storage.Client = MagicMock(return_value=client_stub)

    storage_options = {"project": "DUMMY_PROJECT"}
    downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options)
    target_path = os.path.join(tmpdir, "a.txt")
    downloader.download_file("gs://random_bucket/a.txt", target_path)

    # Every hop of the chain must have been exercised with the expected arguments.
    google_mock.cloud.storage.Client.assert_called_with(**storage_options)
    client_stub.bucket.assert_called_with("random_bucket")
    bucket_stub.blob.assert_called_with("a.txt")
    blob_stub.download_to_filename.assert_called_with(target_path)


@mock.patch("litdata.streaming.downloader._AZURE_STORAGE_AVAILABLE", True)
def test_azure_downloader(tmpdir, monkeypatch, azure_mock):
    """Azure download should fetch a blob client and read the blob into the local file."""
    # Stub chain: service client -> blob client -> downloaded blob data.
    downloaded_stub = MagicMock()
    blob_client_stub = MagicMock()
    blob_client_stub.download_blob.return_value = downloaded_stub

    service_stub = MagicMock()
    service_stub.get_blob_client.return_value = blob_client_stub
    azure_mock.storage.blob.BlobServiceClient = MagicMock(return_value=service_stub)

    storage_options = {"project": "DUMMY_PROJECT"}
    downloader = AzureDownloader("azure://random_bucket", tmpdir, [], storage_options)
    target_path = os.path.join(tmpdir, "a.txt")
    downloader.download_file("azure://random_bucket/a.txt", target_path)

    # Verify the service client was built with the options and the right blob was read.
    azure_mock.storage.blob.BlobServiceClient.assert_called_with(**storage_options)
    service_stub.get_blob_client.assert_called_with(container="random_bucket", blob="a.txt")
    blob_client_stub.download_blob.assert_called()
    downloaded_stub.readinto.assert_called()


def test_download_with_cache(tmpdir, monkeypatch):
Expand Down
56 changes: 28 additions & 28 deletions tests/streaming/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,52 +301,52 @@ def print_fn(msg, file=None):


def test_assert_dir_is_empty(monkeypatch):
boto3 = mock.MagicMock()
client_s3_mock = mock.MagicMock()
client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []}
boto3.client.return_value = client_s3_mock
resolver.boto3 = boto3
def mock_list_directory(*args, **kwargs):
return ["a.txt", "b.txt"]
def mock_empty_list_directory(*args, **kwargs):
return []
monkeypatch.setattr(resolver, "list_directory", mock_list_directory)

with pytest.raises(RuntimeError, match="The provided output_dir"):
resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://"))

client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []}
boto3.client.return_value = client_s3_mock
resolver.boto3 = boto3
monkeypatch.setattr(resolver, "list_directory", mock_empty_list_directory)

resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://"))


def test_assert_dir_has_index_file(monkeypatch):
boto3 = mock.MagicMock()
client_s3_mock = mock.MagicMock()
client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []}
boto3.client.return_value = client_s3_mock
resolver.boto3 = boto3
def mock_list_directory_0(*args, **kwargs):
return []

with pytest.raises(RuntimeError, match="The provided output_dir"):
resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"))
def mock_list_directory_1(*args, **kwargs):
return ['a.txt', 'b.txt']

client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []}
boto3.client.return_value = client_s3_mock
resolver.boto3 = boto3
def mock_list_directory_2(*args, **kwargs):
return ["index.json"]

resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"))
def mock_does_file_exist_1(*args, **kwargs):
raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception

client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []}
def mock_does_file_exist_2(*args, **kwargs):
return True

def head_object(*args, **kwargs):
import botocore
def mock_remove_file_or_directory(*args, **kwargs):
return

raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")

client_s3_mock.head_object = head_object
boto3.client.return_value = client_s3_mock
resolver.boto3 = boto3
monkeypatch.setattr(resolver, "list_directory", mock_list_directory_0)
monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_1)
monkeypatch.setattr(resolver, "remove_file_or_directory", mock_remove_file_or_directory)

resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"))

boto3.resource.assert_called()
monkeypatch.setattr(resolver, "list_directory", mock_list_directory_2)
monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_2)

with pytest.raises(RuntimeError, match="The provided output_dir"):
resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"))

resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"), mode='overwrite')


def test_resolve_dir_absolute(tmp_path, monkeypatch):
Expand Down

0 comments on commit e118ba9

Please sign in to comment.