Fix: Prevent multiple processes from copying the same file when using… (#353)

Co-authored-by: tchaton <[email protected]>
dallmann-uniwue and tchaton authored Sep 3, 2024
1 parent 8382067 commit bc8c0fd
Showing 5 changed files with 27 additions and 7 deletions.
14 changes: 12 additions & 2 deletions src/litdata/streaming/downloader.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import os
 import shutil
 import subprocess
@@ -169,8 +170,17 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if not os.path.exists(remote_filepath):
             raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")

-        if remote_filepath != local_filepath and not os.path.exists(local_filepath):
-            shutil.copy(remote_filepath, local_filepath)
+        try:
+            with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0):
+                if remote_filepath != local_filepath and not os.path.exists(local_filepath):
+                    # copy to a temporary file, then rename, so the write is atomic
+                    temp_file_path = local_filepath + ".tmp"
+                    shutil.copy(remote_filepath, temp_file_path)
+                    os.rename(temp_file_path, local_filepath)
+            with contextlib.suppress(Exception):
+                os.remove(local_filepath + ".lock")
+        except Timeout:
+            pass


 class LocalDownloaderWithCache(LocalDownloader):
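
Aside: the new logic in download_file boils down to "take a per-file lock, copy through a temporary name, rename into place". Below is a minimal, standalone sketch of that pattern, assuming the `filelock` package (which provides `FileLock` and `Timeout`); the function name `copy_once` and its `timeout` parameter are illustrative, not part of litdata's API:

import contextlib
import os
import shutil

from filelock import FileLock, Timeout


def copy_once(remote_filepath: str, local_filepath: str, timeout: float = 0) -> None:
    """Copy remote_filepath to local_filepath at most once across processes."""
    try:
        # Only one process can hold the lock; with timeout=0 the others
        # raise Timeout immediately instead of copying the same file again.
        with FileLock(local_filepath + ".lock", timeout=timeout):
            if remote_filepath != local_filepath and not os.path.exists(local_filepath):
                # Copy to a temporary name, then rename: the rename is atomic
                # on POSIX, so readers never observe a half-written file.
                temp_file_path = local_filepath + ".tmp"
                shutil.copy(remote_filepath, temp_file_path)
                os.rename(temp_file_path, local_filepath)
        # Best-effort cleanup of the leftover ".lock" file.
        with contextlib.suppress(Exception):
            os.remove(local_filepath + ".lock")
    except Timeout:
        # Another process is already copying this file; let it finish.
        pass

In the actual change, index files (paths ending in _INDEX_FILENAME) wait up to 3 seconds for the lock while chunk files use timeout=0, presumably because every reader needs the index before it can do anything else.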
6 changes: 4 additions & 2 deletions tests/streaming/test_dataset.py
@@ -45,6 +45,8 @@
 from litdata.utilities.shuffle import _associate_chunks_and_intervals_to_workers
 from torch.utils.data import DataLoader

+from tests.streaming.utils import filter_lock_files
+

 def seed_everything(random_seed):
     random.seed(random_seed)
@@ -861,13 +863,13 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir):
     for batch in dataloader:
         batches_epoch_1.append(batch)

-    assert len(os.listdir(cache_dir)) == 51
+    assert len(filter_lock_files(os.listdir(cache_dir))) == 51

     batches_epoch_2 = []
     for batch in dataloader:
         batches_epoch_2.append(batch)

-    assert len(os.listdir(cache_dir)) == 51
+    assert len(filter_lock_files(os.listdir(cache_dir))) == 51
     assert not all(torch.equal(b1, b2) for b1, b2 in zip(batches_epoch_1, batches_epoch_2))

4 changes: 4 additions & 0 deletions tests/streaming/test_downloader.py
@@ -80,8 +80,12 @@ def test_download_with_cache(tmpdir, monkeypatch):
     try:
         local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, [])
         shutil_mock = MagicMock()
+        os_mock = MagicMock()
         monkeypatch.setattr(shutil, "copy", shutil_mock)
+        monkeypatch.setattr(os, "rename", os_mock)
+
         local_downloader.download_file("local:a.txt", os.path.join(tmpdir, "a.txt"))
         shutil_mock.assert_called()
+        os_mock.assert_called()
     finally:
         os.remove("a.txt")
8 changes: 5 additions & 3 deletions tests/streaming/test_reader.py
@@ -11,6 +11,8 @@
 from litdata.streaming.resolver import Dir
 from litdata.utilities.env import _DistributedEnv

+from tests.streaming.utils import filter_lock_files
+

 def test_reader_chunk_removal(tmpdir):
     cache_dir = os.path.join(tmpdir, "cache_dir")
@@ -32,19 +34,19 @@ def test_reader_chunk_removal(tmpdir):
         index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i

-    assert len(os.listdir(cache_dir)) == 14
+    assert len(filter_lock_files(os.listdir(cache_dir))) == 14

     cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=2800)

     shutil.rmtree(cache_dir)
     os.makedirs(cache_dir, exist_ok=True)

     for i in range(25):
-        assert len(os.listdir(cache_dir)) <= 3
+        assert len(filter_lock_files(os.listdir(cache_dir))) <= 3
         index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i

-    assert len(os.listdir(cache_dir)) in [2, 3]
+    assert len(filter_lock_files(os.listdir(cache_dir))) in [2, 3]


 def test_get_folder_size(tmpdir):
2 changes: 2 additions & 0 deletions tests/streaming/utils.py
@@ -0,0 +1,2 @@
+def filter_lock_files(files):
+    return [f for f in files if not f.endswith(".lock")]
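
For reference, a quick sketch of how this helper behaves on a typical cache listing (the file names here are hypothetical):

files = ["chunk-0.bin", "chunk-0.bin.lock", "index.json", "index.json.lock"]
assert filter_lock_files(files) == ["chunk-0.bin", "index.json"]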
