From bc8c0fd3dcd90f7c79917e965872dcef1d0f5390 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 3 Sep 2024 14:29:07 +0200 Subject: [PATCH] =?UTF-8?q?Fix:=20Prevent=20multiple=20processes=20from=20?= =?UTF-8?q?copying=20the=20same=20file=20when=20using=E2=80=A6=20(#353)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: tchaton --- src/litdata/streaming/downloader.py | 14 ++++++++++++-- tests/streaming/test_dataset.py | 6 ++++-- tests/streaming/test_downloader.py | 4 ++++ tests/streaming/test_reader.py | 8 +++++--- tests/streaming/utils.py | 2 ++ 5 files changed, 27 insertions(+), 7 deletions(-) create mode 100644 tests/streaming/utils.py diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index ffdbe193..fa4e5fe3 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import os import shutil import subprocess @@ -169,8 +170,17 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: if not os.path.exists(remote_filepath): raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}") - if remote_filepath != local_filepath and not os.path.exists(local_filepath): - shutil.copy(remote_filepath, local_filepath) + try: + with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0): + if remote_filepath != local_filepath and not os.path.exists(local_filepath): + # make an atomic operation to be safe + temp_file_path = local_filepath + ".tmp" + shutil.copy(remote_filepath, temp_file_path) + os.rename(temp_file_path, local_filepath) + with contextlib.suppress(Exception): + os.remove(local_filepath + ".lock") + except Timeout: + pass class LocalDownloaderWithCache(LocalDownloader): diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index d3d9d695..78cb0eaa 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -45,6 +45,8 @@ from litdata.utilities.shuffle import _associate_chunks_and_intervals_to_workers from torch.utils.data import DataLoader +from tests.streaming.utils import filter_lock_files + def seed_everything(random_seed): random.seed(random_seed) @@ -861,13 +863,13 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): for batch in dataloader: batches_epoch_1.append(batch) - assert len(os.listdir(cache_dir)) == 51 + assert len(filter_lock_files(os.listdir(cache_dir))) == 51 batches_epoch_2 = [] for batch in dataloader: batches_epoch_2.append(batch) - assert len(os.listdir(cache_dir)) == 51 + assert len(filter_lock_files(os.listdir(cache_dir))) == 51 assert not all(torch.equal(b1, b2) for b1, b2 in zip(batches_epoch_1, batches_epoch_2)) diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 396dc5d4..7c79afe5 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -80,8 +80,12 @@ def test_download_with_cache(tmpdir, monkeypatch): try: local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, []) shutil_mock = MagicMock() + os_mock = MagicMock() monkeypatch.setattr(shutil, "copy", shutil_mock) + monkeypatch.setattr(os, "rename", os_mock) + local_downloader.download_file("local:a.txt", os.path.join(tmpdir, "a.txt")) shutil_mock.assert_called() + os_mock.assert_called() finally: os.remove("a.txt") diff --git a/tests/streaming/test_reader.py b/tests/streaming/test_reader.py index 9df6aa51..34a16331 100644 --- a/tests/streaming/test_reader.py +++ b/tests/streaming/test_reader.py @@ -11,6 +11,8 @@ from litdata.streaming.resolver import Dir from litdata.utilities.env import _DistributedEnv +from tests.streaming.utils import filter_lock_files + def test_reader_chunk_removal(tmpdir): cache_dir = os.path.join(tmpdir, "cache_dir") @@ -32,7 +34,7 @@ def test_reader_chunk_removal(tmpdir): index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24) assert cache[index] == i - assert len(os.listdir(cache_dir)) == 14 + assert len(filter_lock_files(os.listdir(cache_dir))) == 14 cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=2800) @@ -40,11 +42,11 @@ def test_reader_chunk_removal(tmpdir): os.makedirs(cache_dir, exist_ok=True) for i in range(25): - assert len(os.listdir(cache_dir)) <= 3 + assert len(filter_lock_files(os.listdir(cache_dir))) <= 3 index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24) assert cache[index] == i - assert len(os.listdir(cache_dir)) in [2, 3] + assert len(filter_lock_files(os.listdir(cache_dir))) in [2, 3] def test_get_folder_size(tmpdir): diff --git a/tests/streaming/utils.py b/tests/streaming/utils.py new file mode 100644 index 00000000..daf713f6 --- /dev/null +++ b/tests/streaming/utils.py @@ -0,0 +1,2 @@ +def filter_lock_files(files): + return [f for f in files if not f.endswith(".lock")]