Skip to content

Commit

Permalink
Merge branch 'main' into fix_uneven_number_of_batches2
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Jul 17, 2024
2 parents 92d52d3 + 5d6b8f9 commit ec9db93
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 5 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ lint.extend-select = [
"SIM", # see: https://pypi.org/project/flake8-simplify
"RET", # see: https://pypi.org/project/flake8-return
"PT", # see: https://pypi.org/project/flake8-pytest-style
"NPY201", # see: https://docs.astral.sh/ruff/rules/numpy2-deprecation
"RUF100" # yesqa
]
lint.ignore = [
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch
filelock
numpy < 2.0.0
numpy
boto3
requests
24 changes: 22 additions & 2 deletions src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import os
from pathlib import Path
from typing import Dict

import numpy as np
import torch
Expand Down Expand Up @@ -59,8 +60,27 @@
19: torch.bool,
}

_NUMPY_SCTYPES = [v for values in np.sctypes.values() for v in values]
_NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)}
# Explicit, hard-coded table of NumPy scalar types.
# The position of each entry in this list becomes its integer key in
# _NUMPY_DTYPES_MAPPING below, so reordering or removing an entry changes the
# index -> dtype mapping (e.g. index 10 must stay np.float64).
# NOTE(review): presumably this replaces a dynamic `np.sctypes` lookup because
# that attribute is unavailable in NumPy 2.0 -- confirm against the NumPy
# migration notes.
# NOTE(review): if `np.sctypes` listed extra platform-dependent entries on
# some systems (e.g. extended-precision float/complex), the indices after
# np.float64 in this table would differ from the old mapping there -- TODO
# confirm compatibility with data written by earlier versions.
_NUMPY_SCTYPES = [ # All NumPy scalar types from np.core.sctypes.values()
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.float16,
np.float32,
np.float64,
np.complex64,
np.complex128,
bool,
object,
bytes,
str,
np.void,
]
# Maps each list position in _NUMPY_SCTYPES to the corresponding np.dtype instance.
_NUMPY_DTYPES_MAPPING: Dict[int, np.dtype] = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)}

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
_IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None))
Expand Down
7 changes: 5 additions & 2 deletions tests/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,13 @@ def test_assert_no_header_tensor_serializer():

def test_assert_no_header_numpy_serializer():
serializer = NoHeaderNumpySerializer()
t = np.ones((10,))
t = np.ones((10,), dtype=np.float64)
assert serializer.can_serialize(t)
data, name = serializer.serialize(t)
assert name == "no_header_numpy:10"
try:
assert name == "no_header_numpy:10"
except AssertionError as e: # debug what np.core.sctypes looks like on Windows
raise ValueError(np.core.sctypes) from e
assert serializer._dtype is None
serializer.setup(name)
assert serializer._dtype == np.dtype("float64")
Expand Down

0 comments on commit ec9db93

Please sign in to comment.