reduce unnecessary pass (#373)
Borda authored Sep 17, 2024
1 parent 6efd4d0 commit 8eb516a
Showing 6 changed files with 62 additions and 80 deletions.
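The theme of the diff: handlers of the form try/except SomeError: pass are rewritten with contextlib.suppress, and pass statements that only follow a docstring are dropped. A minimal before/after sketch of the suppress pattern (the stale.lock path is purely illustrative, not from litdata):

import os
from contextlib import suppress

# Before: an explicit handler whose only job is to ignore the error.
try:
    os.remove("stale.lock")
except FileNotFoundError:
    pass

# After: the same behaviour, with the intent stated up front.
with suppress(FileNotFoundError):
    os.remove("stale.lock")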
5 changes: 2 additions & 3 deletions src/litdata/processing/data_processor.py
@@ -1136,12 +1136,11 @@ def run(self, data_recipe: DataRecipe) -> None:
         if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO:
             from lightning_sdk.lightning_cloud.openapi import V1DatasetType

+            data_type = V1DatasetType.CHUNKED if isinstance(data_recipe, DataChunkRecipe) else V1DatasetType.TRANSFORMED
             _create_dataset(
                 input_dir=self.input_dir.path,
                 storage_dir=self.output_dir.path,
-                dataset_type=V1DatasetType.CHUNKED
-                if isinstance(data_recipe, DataChunkRecipe)
-                else V1DatasetType.TRANSFORMED,
+                dataset_type=data_type,
                 empty=False,
                 size=result.size,
                 num_bytes=result.num_bytes,
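The only change above is hoisting the conditional expression into a named local before the call, so the argument list stays on one line per argument. A generic sketch of the same move (names are illustrative, not from litdata):

def greet(count: int) -> str:
    # Hoist the conditional expression into a named variable instead of
    # spreading it across several lines inside the call.
    suffix = "s" if count != 1 else ""
    return "received {} file{}".format(count, suffix)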
2 changes: 0 additions & 2 deletions src/litdata/processing/readers.py
@@ -34,12 +34,10 @@ def get_node_rank(self) -> int:
     @abstractmethod
     def remap_items(self, items: Any, num_workers: int) -> List[Any]:
         """This method is meant to remap the items provided by the users into items more adapted to be distributed."""
-        pass

     @abstractmethod
     def read(self, item: Any) -> Any:
         """Read the data associated to an item."""
-        pass


 class ParquetReader(BaseReader):
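Background on why the two pass statements can go: a docstring is already a valid function body, so an abstract method that only documents itself needs nothing else. A minimal sketch (this Reader class is a reduced illustration, not the litdata BaseReader):

from abc import ABC, abstractmethod
from typing import Any


class Reader(ABC):
    @abstractmethod
    def read(self, item: Any) -> Any:
        """Read the data associated to an item."""
        # The docstring is the body; adding `pass` after it is redundant.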
120 changes: 57 additions & 63 deletions src/litdata/streaming/downloader.py
@@ -16,6 +16,7 @@
 import shutil
 import subprocess
 from abc import ABC
+from contextlib import suppress
 from typing import Any, Dict, List, Optional
 from urllib import parse

@@ -63,34 +64,32 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if os.path.exists(local_filepath):
             return

-        try:
-            with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
-                if self._s5cmd_available:
-                    proc = subprocess.Popen(
-                        f"s5cmd cp {remote_filepath} {local_filepath}",
-                        shell=True,
-                        stdout=subprocess.PIPE,
-                    )
-                    proc.wait()
-                else:
-                    from boto3.s3.transfer import TransferConfig
-
-                    extra_args: Dict[str, Any] = {}
-
-                    # try:
-                    #     with FileLock(local_filepath + ".lock", timeout=1):
-                    if not os.path.exists(local_filepath):
-                        # Issue: https://github.com/boto/boto3/issues/3113
-                        self._client.client.download_file(
-                            obj.netloc,
-                            obj.path.lstrip("/"),
-                            local_filepath,
-                            ExtraArgs=extra_args,
-                            Config=TransferConfig(use_threads=False),
-                        )
-        except Timeout:
-            # another process is responsible to download that file, continue
-            pass
+        with suppress(Timeout), FileLock(
+            local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0
+        ):
+            if self._s5cmd_available:
+                proc = subprocess.Popen(
+                    f"s5cmd cp {remote_filepath} {local_filepath}",
+                    shell=True,
+                    stdout=subprocess.PIPE,
+                )
+                proc.wait()
+            else:
+                from boto3.s3.transfer import TransferConfig
+
+                extra_args: Dict[str, Any] = {}
+
+                # try:
+                #     with FileLock(local_filepath + ".lock", timeout=1):
+                if not os.path.exists(local_filepath):
+                    # Issue: https://github.com/boto/boto3/issues/3113
+                    self._client.client.download_file(
+                        obj.netloc,
+                        obj.path.lstrip("/"),
+                        local_filepath,
+                        ExtraArgs=extra_args,
+                        Config=TransferConfig(use_threads=False),
+                    )


 class GCPDownloader(Downloader):
@@ -113,21 +112,19 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if os.path.exists(local_filepath):
             return

-        try:
-            with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
-                bucket_name = obj.netloc
-                key = obj.path
-                # Remove the leading "/":
-                if key[0] == "/":
-                    key = key[1:]
+        with suppress(Timeout), FileLock(
+            local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0
+        ):
+            bucket_name = obj.netloc
+            key = obj.path
+            # Remove the leading "/":
+            if key[0] == "/":
+                key = key[1:]

-                client = storage.Client(**self._storage_options)
-                bucket = client.bucket(bucket_name)
-                blob = bucket.blob(key)
-                blob.download_to_filename(local_filepath)
-        except Timeout:
-            # another process is responsible to download that file, continue
-            pass
+            client = storage.Client(**self._storage_options)
+            bucket = client.bucket(bucket_name)
+            blob = bucket.blob(key)
+            blob.download_to_filename(local_filepath)


 class AzureDownloader(Downloader):
@@ -152,35 +149,32 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if os.path.exists(local_filepath):
             return

-        try:
-            with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
-                service = BlobServiceClient(**self._storage_options)
-                blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/"))
-                with open(local_filepath, "wb") as download_file:
-                    blob_data = blob_client.download_blob()
-                    blob_data.readinto(download_file)
-
-        except Timeout:
-            # another process is responsible to download that file, continue
-            pass
+        with suppress(Timeout), FileLock(
+            local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0
+        ):
+            service = BlobServiceClient(**self._storage_options)
+            blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/"))
+            with open(local_filepath, "wb") as download_file:
+                blob_data = blob_client.download_blob()
+                blob_data.readinto(download_file)


 class LocalDownloader(Downloader):
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if not os.path.exists(remote_filepath):
             raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")

-        try:
-            with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0):
-                if remote_filepath != local_filepath and not os.path.exists(local_filepath):
-                    # make an atomic operation to be safe
-                    temp_file_path = local_filepath + ".tmp"
-                    shutil.copy(remote_filepath, temp_file_path)
-                    os.rename(temp_file_path, local_filepath)
-                    with contextlib.suppress(Exception):
-                        os.remove(local_filepath + ".lock")
-        except Timeout:
-            pass
+        with suppress(Timeout), FileLock(
+            local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0
+        ):
+            if remote_filepath == local_filepath or os.path.exists(local_filepath):
+                return
+            # make an atomic operation to be safe
+            temp_file_path = local_filepath + ".tmp"
+            shutil.copy(remote_filepath, temp_file_path)
+            os.rename(temp_file_path, local_filepath)
+            with contextlib.suppress(Exception):
+                os.remove(local_filepath + ".lock")


 class LocalDownloaderWithCache(LocalDownloader):
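One subtlety in the downloader rewrite: suppress(Timeout) and FileLock(...) are combined in a single with statement, with suppress listed first so it is the outer context manager. If acquiring the lock raises Timeout, the exception is swallowed and the whole body is skipped, matching the old try/except Timeout: pass wrapping. A minimal sketch under those assumptions (fetch_with_lock and the lock path are illustrative):

from contextlib import suppress

from filelock import FileLock, Timeout


def fetch_with_lock(lock_path: str) -> None:
    # suppress(Timeout) is entered first, so a Timeout raised while acquiring
    # FileLock is caught here and the body below is simply not executed.
    with suppress(Timeout), FileLock(lock_path, timeout=0):
        print("lock acquired, safe to download")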
4 changes: 0 additions & 4 deletions src/litdata/streaming/item_loader.py
@@ -76,12 +76,10 @@ def generate_intervals(self) -> List[Interval]:
         region_of_interest: indicates the indexes a chunk our StreamingDataset is allowed to read.
         """
-        pass

     @abstractmethod
     def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         """Logic to load the chunk in background to gain some time."""
-        pass

     @abstractmethod
     def load_item_from_chunk(
@@ -93,12 +91,10 @@ def load_item_from_chunk(
         chunk_bytes: int,
     ) -> Any:
         """Returns an item loaded from a chunk."""
-        pass

     @abstractmethod
     def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         """Delete a chunk from the local filesystem."""
-        pass

     @abstractmethod
     def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any:
4 changes: 0 additions & 4 deletions src/litdata/streaming/reader.py
@@ -19,7 +19,6 @@
 from threading import Event, Thread
 from typing import Any, Dict, List, Optional, Tuple, Union

-from litdata.constants import _TORCH_GREATER_EQUAL_2_1_0
 from litdata.streaming.config import ChunksConfig, Interval
 from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader
 from litdata.streaming.sampler import ChunkedIndex
@@ -29,9 +28,6 @@

 warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*")

-if _TORCH_GREATER_EQUAL_2_1_0:
-    pass
-

 logger = Logger(__name__)

7 changes: 3 additions & 4 deletions src/litdata/streaming/serializers.py
@@ -17,6 +17,7 @@
 import tempfile
 from abc import ABC, abstractmethod
 from collections import OrderedDict
+from contextlib import suppress
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

@@ -128,11 +129,9 @@ def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]:
         from torchvision.transforms.functional import pil_to_tensor

         array = torch.frombuffer(data, dtype=torch.uint8)
-        try:
+        # Note: Some datasets like Imagenet contains some PNG images with JPEG extension, so we fallback to PIL
+        with suppress(RuntimeError):
             return decode_jpeg(array)
-        except RuntimeError:
-            # Note: Some datasets like Imagenet contains some PNG images with JPEG extension, so we fallback to PIL
-            pass

         img = PILSerializer.deserialize(data)
         if _TORCH_VISION_AVAILABLE:
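Here suppress is used for a fast path with a fallback: if decode_jpeg raises RuntimeError, the return inside the with block never happens and execution continues to the PIL path below it. A generic sketch of that control flow (parse_number is illustrative, not from litdata):

from contextlib import suppress


def parse_number(text: str) -> float:
    # Fast path: on success the return exits here; on ValueError the error is
    # suppressed and execution falls through to the fallback below.
    with suppress(ValueError):
        return float(text)
    return float(text.replace(",", "."))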
