Skip to content

Commit

Permalink
Merge pull request #2590 from activeloopai/fy_change_htype
Browse files Browse the repository at this point in the history
[AL-2405] Add ability to change htype
  • Loading branch information
FayazRahman authored Sep 14, 2023
2 parents 424c4e0 + 1eba9b1 commit 3401ae1
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 9 deletions.
128 changes: 128 additions & 0 deletions deeplake/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from deeplake.core.storage import GCSProvider
from deeplake.util.exceptions import (
GroupInfoNotSupportedError,
IncompatibleHtypeError,
InvalidOperationError,
SampleAppendError,
TensorDoesNotExistError,
Expand Down Expand Up @@ -2908,3 +2909,130 @@ def test_tensor_extend_ignore(local_ds, lfpw_links, compression_args):

# Commit should work
ds.commit()


def test_change_htype(local_ds_generator):
with local_ds_generator() as ds:
ds.create_tensor("images", sample_compression="jpg")
ds.images.extend(np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8))

ds.create_tensor("labels")
ds.labels.extend([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

ds.create_tensor("boxes")
ds.boxes.extend(np.random.randn(10, 5, 4))

ds.create_tensor("boxes_3d")
ds.boxes_3d.extend(np.random.randn(10, 5, 8))

ds.create_tensor("embeddings")
ds.embeddings.extend(np.random.randn(10, 1536))

mask = np.zeros((10, 100, 100, 5), dtype=bool)
mask[:, :, :512, 1] = 1
ds.create_tensor("masks")
ds.masks.extend(mask)
ds.create_tensor("image_masks", htype="image", sample_compression=None)
ds.image_masks.extend(mask)

ds.create_tensor("keypoints")
ds.keypoints.extend(np.zeros((10, 9, 5)))

ds.create_tensor("points")
ds.points.extend(np.zeros((10, 5, 3)))

ds.images.htype = "image"
ds.labels.htype = "class_label"
ds.boxes.htype = "bbox"
ds.boxes_3d.htype = "bbox.3d"
ds.embeddings.htype = "embedding"
ds.masks.htype = "binary_mask"
ds.image_masks.htype = "binary_mask"
ds.keypoints.htype = "keypoints_coco"
ds.points.htype = "point"

with local_ds_generator() as ds:
assert ds.images.htype == "image"
assert ds.labels.htype == "class_label"
assert ds.boxes.htype == "bbox"
assert ds.boxes_3d.htype == "bbox.3d"
assert ds.embeddings.htype == "embedding"
assert ds.masks.htype == "binary_mask"
assert ds.image_masks.htype == "binary_mask"
assert ds.keypoints.htype == "keypoints_coco"
assert ds.points.htype == "point"


def test_change_htype_fail(local_ds_generator):
with local_ds_generator() as ds:
ds.create_tensor("images")
ds.images.extend(np.zeros((10, 5, 5, 5, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.images.htype = "image"

ds.create_tensor("images2")
ds.images2.extend(np.zeros((10, 5, 5, 6)))
with pytest.raises(IncompatibleHtypeError):
ds.images2.htype = "image"

ds.create_tensor("labels")
ds.labels.extend(np.ones((10, 5, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.labels.htype = "class_label"

ds.create_tensor("boxes")
ds.boxes.extend(np.zeros((10, 5, 5, 2)))
with pytest.raises(IncompatibleHtypeError):
ds.boxes.htype = "bbox"
with pytest.raises(IncompatibleHtypeError):
ds.boxes.htype = "bbox.3d"

ds.create_tensor("boxes2")
ds.boxes2.extend(np.zeros((10, 5, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.boxes2.htype = "bbox"
with pytest.raises(IncompatibleHtypeError):
ds.boxes2.htype = "bbox.3d"

ds.create_tensor("masks")
ds.masks.extend(np.zeros((10, 5, 5, 5, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.masks.htype = "binary_mask"

ds.create_tensor("keypoints")
ds.keypoints.extend(np.zeros((10, 5, 5, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.keypoints.htype = "keypoints_coco"

ds.create_tensor("keypoints2")
ds.keypoints2.extend(np.zeros((10, 10, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.keypoints2.htype = "keypoints_coco"

ds.create_tensor("points")
ds.points.extend(np.zeros((10, 5, 5, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.points.htype = "point"

ds.create_tensor("points2")
ds.points2.extend(np.zeros((10, 5, 5)))
with pytest.raises(IncompatibleHtypeError):
ds.points2.htype = "point"

with pytest.raises(ValueError):
ds.images.htype = "link[image]"

with pytest.raises(ValueError):
ds.images.htype = "sequence[image]"

ds.create_tensor("boxes3", htype="bbox")
ds.boxes3.extend(np.zeros((10, 5, 4), dtype=np.float32))
with pytest.raises(NotImplementedError):
ds.boxes3.htype = "embedding"

with pytest.raises(NotImplementedError):
ds.images.htype = "text"

ds.create_tensor("images3", htype="image", sample_compression="jpg")
with pytest.raises(UnsupportedCompressionError):
ds.images3.htype = "embedding"
4 changes: 4 additions & 0 deletions deeplake/core/dataset/deeplake_query_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def htype(self):
htype = f"link[{htype}]"
return htype

@htype.setter
def htype(self, value):
raise NotImplementedError("htype of a query tensor cannot be set.")

@property
def sample_compression(self):
return self.indra_tensor.sample_compression
Expand Down
42 changes: 41 additions & 1 deletion deeplake/core/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Dict, List, Sequence, Union, Optional, Tuple, Any, Callable
from functools import reduce, partial
from deeplake.core.index import Index, IndexEntry, replace_ellipsis_with_slices
from deeplake.core.meta.tensor_meta import TensorMeta
from deeplake.core.meta.tensor_meta import TensorMeta, _validate_htype_exists
from deeplake.core.storage import StorageProvider
from deeplake.core.chunk_engine import ChunkEngine
from deeplake.core.compression import _read_timestamps
Expand Down Expand Up @@ -42,6 +42,7 @@
TensorDoesNotExistError,
InvalidKeyTypeError,
TensorAlreadyExistsError,
UnsupportedCompressionError,
)
from deeplake.util.iteration_warning import check_if_iteration
from deeplake.hooks import dataset_read, dataset_written
Expand All @@ -63,6 +64,12 @@
parse_mesh_to_dict,
get_mesh_vertices,
)
from deeplake.util.htype import parse_complex_htype
from deeplake.htype import (
HTYPE_CONVERSION_LHS,
HTYPE_CONSTRAINTS,
HTYPE_SUPPORTED_COMPRESSIONS,
)
import warnings
import webbrowser

Expand Down Expand Up @@ -555,6 +562,15 @@ def htype(self):
htype = f"link[{htype}]"
return htype

@htype.setter
def htype(self, value):
self._check_compatibility_with_htype(value)
self.meta.htype = value
if value == "class_label":
self.meta._disable_temp_transform = False
self.meta.is_dirty = True
self.dataset.maybe_flush()

@property
def hidden(self) -> bool:
"""Whether this tensor is a hidden tensor."""
Expand Down Expand Up @@ -1382,3 +1398,27 @@ def creds_key(self):
def invalidate_libdeeplake_dataset(self):
"""Invalidates the libdeeplake dataset object."""
self.dataset.libdeeplake_dataset = None

def _check_compatibility_with_htype(self, htype):
"""Checks if the tensor is compatible with the given htype.
Raises an error if not compatible.
"""
is_sequence, is_link, htype = parse_complex_htype(htype)
if is_sequence or is_link:
raise ValueError(f"Cannot change htype to a sequence or link.")
_validate_htype_exists(htype)
if self.htype not in HTYPE_CONVERSION_LHS:
raise NotImplementedError(
f"Changing the htype of a tensor of htype {self.htype} is not supported."
)
if htype not in HTYPE_CONSTRAINTS:
raise NotImplementedError(
f"Changing the htype to {htype} is not supported."
)
compression = self.meta.sample_compression or self.meta.chunk_compression
if compression:
supported_compressions = HTYPE_SUPPORTED_COMPRESSIONS.get(htype)
if supported_compressions and compression not in supported_compressions:
raise UnsupportedCompressionError(compression, htype)
constraints = HTYPE_CONSTRAINTS[htype]
constraints(self.shape, self.dtype)
96 changes: 93 additions & 3 deletions deeplake/htype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Dict
from typing import Callable, Dict

import numpy as np
from deeplake.compression import (
IMAGE_COMPRESSIONS,
VIDEO_COMPRESSIONS,
Expand All @@ -8,6 +10,7 @@
POINT_CLOUD_COMPRESSIONS,
MESH_COMPRESSIONS,
)
from deeplake.util.exceptions import IncompatibleHtypeError


class htype:
Expand Down Expand Up @@ -66,7 +69,7 @@ class htype:
htype.BBOX: {"dtype": "float32", "coords": {}, "_info": ["coords"]},
htype.BBOX_3D: {"dtype": "float32", "coords": {}, "_info": ["coords"]},
htype.AUDIO: {"dtype": "float64"},
htype.EMBEDDING: {"dtype": "float32"},
htype.EMBEDDING: {},
htype.VIDEO: {"dtype": "uint8"},
htype.BINARY_MASK: {
"dtype": "bool"
Expand Down Expand Up @@ -98,7 +101,94 @@ class htype:
htype.INTRINSICS: {"dtype": "float32"},
htype.POLYGON: {"dtype": "float32"},
htype.MESH: {"sample_compression": "ply"},
htype.EMBEDDING: {},
}

HTYPE_CONVERSION_LHS = {htype.DEFAULT, htype.IMAGE}


class constraints:
"""Constraints for converting a tensor to a htype"""

ndim_error = (
lambda htype, ndim: f"Incompatible number of dimensions for htype {htype}: {ndim}"
)
shape_error = (
lambda htype, shape: f"Incompatible shape of tensor for htype {htype}: {shape}"
)
dtype_error = (
lambda htype, dtype: f"Incompatible dtype of tensor for htype {htype}: {dtype}"
)

EMBEDDING = lambda shape, dtype: True
INSTANCE_LABEL = lambda shape, dtype: True

@staticmethod
def IMAGE(shape, dtype):
if len(shape) not in (3, 4):
raise IncompatibleHtypeError(constraints.ndim_error("image", len(shape)))
if len(shape) == 4 and shape[-1] not in (1, 3, 4):
raise IncompatibleHtypeError(constraints.shape_error("image", shape))

@staticmethod
def CLASS_LABEL(shape, dtype):
if len(shape) != 2:
raise IncompatibleHtypeError(
constraints.ndim_error("class_label", len(shape))
)

@staticmethod
def BBOX(shape, dtype):
if len(shape) not in (2, 3):
raise IncompatibleHtypeError(constraints.ndim_error("bbox", len(shape)))
if shape[-1] != 4:
raise IncompatibleHtypeError(constraints.shape_error("bbox", shape))

@staticmethod
def BBOX_3D(shape, dtype):
if len(shape) not in (2, 3):
raise IncompatibleHtypeError(constraints.ndim_error("bbox.3d", len(shape)))
if shape[-1] != 8:
raise IncompatibleHtypeError(constraints.shape_error("bbox.3d", shape))

@staticmethod
def BINARY_MASK(shape, dtype):
if len(shape) not in (3, 4):
raise IncompatibleHtypeError(
constraints.ndim_error("binary_mask", len(shape))
)

SEGMENT_MASK = BINARY_MASK

@staticmethod
def KEYPOINTS_COCO(shape, dtype):
if len(shape) != 3:
raise IncompatibleHtypeError(
constraints.ndim_error("keypoints_coco", len(shape))
)
if shape[1] % 3 != 0:
raise IncompatibleHtypeError(
constraints.shape_error("keypoints_coco", shape)
)

@staticmethod
def POINT(shape, dtype):
if len(shape) != 3:
raise IncompatibleHtypeError(constraints.ndim_error("point", len(shape)))
if shape[-1] not in (2, 3):
raise IncompatibleHtypeError(constraints.shape_error("point", shape))


HTYPE_CONSTRAINTS: Dict[str, Callable] = {
htype.IMAGE: constraints.IMAGE,
htype.CLASS_LABEL: constraints.CLASS_LABEL,
htype.BBOX: constraints.BBOX,
htype.BBOX_3D: constraints.BBOX_3D,
htype.EMBEDDING: constraints.EMBEDDING,
htype.BINARY_MASK: constraints.BINARY_MASK,
htype.SEGMENT_MASK: constraints.SEGMENT_MASK,
htype.INSTANCE_LABEL: constraints.INSTANCE_LABEL,
htype.KEYPOINTS_COCO: constraints.KEYPOINTS_COCO,
htype.POINT: constraints.POINT,
}

HTYPE_VERIFICATIONS: Dict[str, Dict] = {
Expand Down
8 changes: 6 additions & 2 deletions deeplake/util/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import deeplake
from deeplake.htype import HTYPE_CONFIGURATIONS
from typing import Any, List, Sequence, Tuple, Optional, Union


Expand Down Expand Up @@ -483,7 +482,7 @@ def __init__(self, expected: Union[np.dtype, str], actual: str, htype: str):

# TODO: we may want to raise this error at the API level to determine if the user explicitly overwrote the `dtype` or not. (to make this error message more precise)
# TODO: because if the user uses `dtype=np.uint8`, but the `htype` the tensor is created with has it's default dtype set as `uint8` also, then this message is ambiguous
htype_dtype = HTYPE_CONFIGURATIONS[htype].get("dtype", None)
htype_dtype = deeplake.HTYPE_CONFIGURATIONS[htype].get("dtype", None)
if htype_dtype is not None and htype_dtype == expected:
msg += f" Htype '{htype}' expects samples to have dtype='{htype_dtype}'."
super().__init__("")
Expand Down Expand Up @@ -1088,3 +1087,8 @@ def __init__(self):
"Please either use different embedding function or exclude invalid "
"files that are not supported by the embedding function. "
)


class IncompatibleHtypeError(Exception):
def __init__(self, msg):
super().__init__(msg)
4 changes: 1 addition & 3 deletions deeplake/util/htype.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# type: ignore

from typing import Tuple, Optional
from deeplake.htype import htype as HTYPE, HTYPE_CONFIGURATIONS
from deeplake.util.exceptions import TensorMetaInvalidHtype


def parse_complex_htype(htype: Optional[str]) -> Tuple[bool, bool, str]:
def parse_complex_htype(htype: Optional[str]) -> Tuple[bool, bool, Optional[str]]:
is_sequence = False
is_link = False

Expand Down

0 comments on commit 3401ae1

Please sign in to comment.