Add utility to merge datasets together (#190)
tchaton authored Jun 27, 2024
1 parent 9309fed commit f2c5a7b
Showing 7 changed files with 168 additions and 8 deletions.
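At a glance, the new `merge_datasets` utility combines several previously optimized datasets into one streamable dataset. A minimal usage sketch, mirroring the test added in this commit (the folder paths are placeholders):

```python
from litdata import StreamingDataset, merge_datasets

# Merge two previously optimized datasets into a new, empty output directory.
# The output directory must not already contain an index file.
merge_datasets(
    input_dirs=["./folder_1", "./folder_2"],
    output_dir="./merged",
)

# The merged result streams like any other optimized dataset.
ds = StreamingDataset(input_dir="./merged")
print(len(ds))
```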
3 changes: 2 additions & 1 deletion src/litdata/__init__.py
@@ -13,7 +13,7 @@

from litdata.__about__ import * # noqa: F403
from litdata.imports import RequirementCache
from litdata.processing.functions import map, optimize, walk
from litdata.processing.functions import map, merge_datasets, optimize, walk
from litdata.streaming.combined import CombinedStreamingDataset
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
@@ -29,6 +29,7 @@
"optimize",
"walk",
"train_test_split",
"merge_datasets",
]
if RequirementCache("lightning_sdk"):
from lightning_sdk import Machine # noqa: F401
1 change: 1 addition & 0 deletions src/litdata/constants.py
@@ -32,6 +32,7 @@
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
_TQDM_AVAILABLE = RequirementCache("tqdm")


# DON'T CHANGE ORDER
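Moving `_TQDM_AVAILABLE` into `constants.py` lets every module share one guarded import instead of redefining its own `RequirementCache("tqdm")`. A minimal sketch of the fallback pattern this enables, mirroring the one added to `functions.py` below:

```python
from typing import Any

from litdata.constants import _TQDM_AVAILABLE

if _TQDM_AVAILABLE:
    from tqdm.auto import tqdm as _tqdm
else:
    # No-op progress wrapper used when tqdm isn't installed.
    def _tqdm(iterator: Any) -> Any:
        yield from iterator
```

This is also why the local `_TQDM_AVAILABLE = RequirementCache("tqdm")` definitions are removed from `data_processor.py` and `readers.py` below.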
4 changes: 1 addition & 3 deletions src/litdata/processing/data_processor.py
@@ -39,8 +39,8 @@
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_AVAILABLE,
_TQDM_AVAILABLE,
)
from litdata.imports import RequirementCache
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
from litdata.processing.utilities import _create_dataset
from litdata.streaming import Cache
@@ -52,8 +52,6 @@
from litdata.utilities.broadcast import broadcast_object
from litdata.utilities.packing import _pack_greedily

_TQDM_AVAILABLE = RequirementCache("tqdm")

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm

126 changes: 125 additions & 1 deletion src/litdata/processing/functions.py
@@ -13,7 +13,11 @@

import concurrent.futures
import inspect
import json
import os
import shutil
import tempfile
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path
@@ -23,14 +27,15 @@

import torch

from litdata.constants import _IS_IN_STUDIO
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _TQDM_AVAILABLE
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from litdata.processing.readers import BaseReader
from litdata.processing.utilities import (
extract_rank_and_index_from_filename,
optimize_dns_context,
read_index_file_content,
)
from litdata.streaming.client import S3Client
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.resolver import (
Dir,
@@ -41,6 +46,13 @@
)
from litdata.utilities._pytree import tree_flatten

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm
else:

def _tqdm(iterator: Any) -> Any:
yield from iterator


def _is_remote_file(path: str) -> bool:
obj = parse.urlparse(path)
@@ -470,3 +482,115 @@ def __iter__(self) -> Any:
future = executor.submit(_listdir, folder)
self.futures.append(future)
return


@dataclass
class CopyInfo:
input_dir: Dir
old_filename: str
new_filename: str


def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
"""The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized
dataset.
Arguments:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset would be stored.
"""
if len(input_dirs) == 0:
raise ValueError("The input directories needs to be defined.")

if len(input_dirs) == 1:
raise ValueError("There should be more than 1 input directory")

resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs]
resolved_output_dir = _resolve_dir(output_dir)

if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")

input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs]

if any(file_content is None for file_content in input_dirs_file_content):
raise ValueError("One of the provided input_dir doesn't have an index file.")

output_dir_file_content = read_index_file_content(resolved_output_dir)

if output_dir_file_content is not None:
raise ValueError("The output_dir already contains an optimized dataset")

assert input_dirs_file_content

for input_dir_file_content in input_dirs_file_content[1:]:
if input_dirs_file_content[0]["config"]["data_format"] != input_dir_file_content["config"]["data_format"]: # type: ignore
raise ValueError("Your are trying to merge datasets with different data formats")

if input_dirs_file_content[0]["config"]["compression"] != input_dir_file_content["config"]["compression"]: # type: ignore
raise ValueError("Your are trying to merge datasets with different compression configuration.")

chunks = []
copy_infos: List[CopyInfo] = []
counter = 0
for input_dir, input_dir_file_content in zip(resolved_input_dirs, input_dirs_file_content):
for chunk in input_dir_file_content["chunks"]: # type: ignore
assert isinstance(chunk, dict)
old_filename = chunk["filename"]
new_filename = f"chunk-0-{counter}.bin"
copy_infos.append(CopyInfo(input_dir=input_dir, old_filename=old_filename, new_filename=new_filename))
chunk["filename"] = new_filename
chunks.append(chunk)
counter += 1

index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} # type: ignore

for copy_info in _tqdm(copy_infos):
_apply_copy(copy_info, resolved_output_dir)

_save_index(index_json, resolved_output_dir)


def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None:
if output_dir.url is None and copy_info.input_dir.url is None:
assert copy_info.input_dir.path
assert output_dir.path
input_filepath = os.path.join(copy_info.input_dir.path, copy_info.old_filename)
output_filepath = os.path.join(output_dir.path, copy_info.new_filename)
os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
shutil.copyfile(input_filepath, output_filepath)

elif output_dir.url and copy_info.input_dir.url:
input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename))
output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename))

s3 = S3Client()
s3.client.copy(
{"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")},
output_obj.netloc,
output_obj.path.lstrip("/"),
)
else:
raise NotImplementedError


def _save_index(index_json: Dict, output_dir: Dir) -> None:
if output_dir.url is None:
assert output_dir.path
with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f:
json.dump(index_json, f)
else:
with tempfile.NamedTemporaryFile("w") as f:
json.dump(index_json, f)

f.flush()

obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME))

s3 = S3Client()
s3.client.upload_file(
f.name,
obj.netloc,
obj.path.lstrip("/"),
)
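After the copy loop, `_save_index` writes a single `index.json` that reuses the first dataset's config and lists every renamed chunk. Roughly, the merged index has the following shape (a sketch; the placeholder values are illustrative, and only `filename` is rewritten by `merge_datasets`):

```python
# Illustrative shape of the merged index.json written by _save_index.
index_json = {
    "config": {
        # Copied verbatim from the first input dataset's index file;
        # merge_datasets only checks that data_format and compression match.
        "data_format": "...",
        "compression": None,
    },
    "chunks": [
        # Chunks from every input dataset, renamed sequentially to avoid name collisions;
        # all other per-chunk fields are carried over unchanged.
        {"filename": "chunk-0-0.bin"},
        {"filename": "chunk-0-1.bin"},
    ],
}
```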
2 changes: 1 addition & 1 deletion src/litdata/processing/readers.py
@@ -16,11 +16,11 @@
from abc import ABC, abstractmethod
from typing import Any, List

from litdata.constants import _TQDM_AVAILABLE
from litdata.imports import RequirementCache
from litdata.streaming.dataloader import StreamingDataLoader

_PYARROW_AVAILABLE = RequirementCache("pyarrow")
_TQDM_AVAILABLE = RequirementCache("tqdm")

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm
36 changes: 35 additions & 1 deletion tests/processing/test_functions.py
@@ -3,8 +3,9 @@
from unittest import mock

import pytest
from litdata import StreamingDataset, optimize, walk
from litdata import StreamingDataset, merge_datasets, optimize, walk
from litdata.processing.functions import _get_input_dir, _resolve_dir
from litdata.streaming.cache import Cache


@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@@ -154,3 +155,36 @@ def test_optimize_append_overwrite(tmpdir):

assert len(ds) == 5
assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]


def test_merge_datasets(tmpdir):
folder_1 = os.path.join(tmpdir, "folder_1")
folder_2 = os.path.join(tmpdir, "folder_2")
folder_3 = os.path.join(tmpdir, "folder_3")

os.makedirs(folder_1, exist_ok=True)
os.makedirs(folder_2, exist_ok=True)

cache_1 = Cache(input_dir=folder_1, chunk_bytes="64MB")
for i in range(10):
cache_1[i] = i

cache_1.done()
cache_1.merge()

cache_2 = Cache(input_dir=folder_2, chunk_bytes="64MB")
for i in range(10, 20):
cache_2[i] = i

cache_2.done()
cache_2.merge()

merge_datasets(
input_dirs=[folder_1, folder_2],
output_dir=folder_3,
)

ds = StreamingDataset(input_dir=folder_3)

assert len(ds) == 20
assert ds[:] == list(range(20))
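For completeness, the new guard clauses could be exercised like this (a sketch not part of the commit; the folder paths are placeholders and the match strings assume the error messages above):

```python
import pytest
from litdata import merge_datasets

# Fewer than two input directories is rejected up front.
with pytest.raises(ValueError, match="more than one input directory"):
    merge_datasets(input_dirs=["./folder_1"], output_dir="./merged")

# The output directory must not be one of the inputs.
with pytest.raises(ValueError, match="found within the input_dirs"):
    merge_datasets(input_dirs=["./folder_1", "./folder_2"], output_dir="./folder_1")
```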
4 changes: 3 additions & 1 deletion tests/streaming/test_reader.py
@@ -102,5 +102,7 @@ def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch):
assert thread._pre_download_counter <= 2

assert len(os.listdir(cache_dir)) == 9
assert thread._has_exited

thread.join()
sleep(0.1)
assert thread._has_exited
