diff --git a/deeplake/api/tests/test_api.py b/deeplake/api/tests/test_api.py index 62ab22504d..02072ef620 100644 --- a/deeplake/api/tests/test_api.py +++ b/deeplake/api/tests/test_api.py @@ -17,6 +17,7 @@ from deeplake.core.storage import GCSProvider from deeplake.util.exceptions import ( GroupInfoNotSupportedError, + IncompatibleHtypeError, InvalidOperationError, SampleAppendError, TensorDoesNotExistError, @@ -2908,3 +2909,130 @@ def test_tensor_extend_ignore(local_ds, lfpw_links, compression_args): # Commit should work ds.commit() + + +def test_change_htype(local_ds_generator): + with local_ds_generator() as ds: + ds.create_tensor("images", sample_compression="jpg") + ds.images.extend(np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)) + + ds.create_tensor("labels") + ds.labels.extend([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + ds.create_tensor("boxes") + ds.boxes.extend(np.random.randn(10, 5, 4)) + + ds.create_tensor("boxes_3d") + ds.boxes_3d.extend(np.random.randn(10, 5, 8)) + + ds.create_tensor("embeddings") + ds.embeddings.extend(np.random.randn(10, 1536)) + + mask = np.zeros((10, 100, 100, 5), dtype=bool) + mask[:, :, :512, 1] = 1 + ds.create_tensor("masks") + ds.masks.extend(mask) + ds.create_tensor("image_masks", htype="image", sample_compression=None) + ds.image_masks.extend(mask) + + ds.create_tensor("keypoints") + ds.keypoints.extend(np.zeros((10, 9, 5))) + + ds.create_tensor("points") + ds.points.extend(np.zeros((10, 5, 3))) + + ds.images.htype = "image" + ds.labels.htype = "class_label" + ds.boxes.htype = "bbox" + ds.boxes_3d.htype = "bbox.3d" + ds.embeddings.htype = "embedding" + ds.masks.htype = "binary_mask" + ds.image_masks.htype = "binary_mask" + ds.keypoints.htype = "keypoints_coco" + ds.points.htype = "point" + + with local_ds_generator() as ds: + assert ds.images.htype == "image" + assert ds.labels.htype == "class_label" + assert ds.boxes.htype == "bbox" + assert ds.boxes_3d.htype == "bbox.3d" + assert ds.embeddings.htype == "embedding" + assert ds.masks.htype == "binary_mask" + assert ds.image_masks.htype == "binary_mask" + assert ds.keypoints.htype == "keypoints_coco" + assert ds.points.htype == "point" + + +def test_change_htype_fail(local_ds_generator): + with local_ds_generator() as ds: + ds.create_tensor("images") + ds.images.extend(np.zeros((10, 5, 5, 5, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.images.htype = "image" + + ds.create_tensor("images2") + ds.images2.extend(np.zeros((10, 5, 5, 6))) + with pytest.raises(IncompatibleHtypeError): + ds.images2.htype = "image" + + ds.create_tensor("labels") + ds.labels.extend(np.ones((10, 5, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.labels.htype = "class_label" + + ds.create_tensor("boxes") + ds.boxes.extend(np.zeros((10, 5, 5, 2))) + with pytest.raises(IncompatibleHtypeError): + ds.boxes.htype = "bbox" + with pytest.raises(IncompatibleHtypeError): + ds.boxes.htype = "bbox.3d" + + ds.create_tensor("boxes2") + ds.boxes2.extend(np.zeros((10, 5, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.boxes2.htype = "bbox" + with pytest.raises(IncompatibleHtypeError): + ds.boxes2.htype = "bbox.3d" + + ds.create_tensor("masks") + ds.masks.extend(np.zeros((10, 5, 5, 5, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.masks.htype = "binary_mask" + + ds.create_tensor("keypoints") + ds.keypoints.extend(np.zeros((10, 5, 5, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.keypoints.htype = "keypoints_coco" + + ds.create_tensor("keypoints2") + ds.keypoints2.extend(np.zeros((10, 10, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.keypoints2.htype = "keypoints_coco" + + ds.create_tensor("points") + ds.points.extend(np.zeros((10, 5, 5, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.points.htype = "point" + + ds.create_tensor("points2") + ds.points2.extend(np.zeros((10, 5, 5))) + with pytest.raises(IncompatibleHtypeError): + ds.points2.htype = "point" + + with pytest.raises(ValueError): + ds.images.htype = "link[image]" + + with pytest.raises(ValueError): + ds.images.htype = "sequence[image]" + + ds.create_tensor("boxes3", htype="bbox") + ds.boxes3.extend(np.zeros((10, 5, 4), dtype=np.float32)) + with pytest.raises(NotImplementedError): + ds.boxes3.htype = "embedding" + + with pytest.raises(NotImplementedError): + ds.images.htype = "text" + + ds.create_tensor("images3", htype="image", sample_compression="jpg") + with pytest.raises(UnsupportedCompressionError): + ds.images3.htype = "embedding" diff --git a/deeplake/core/dataset/deeplake_query_tensor.py b/deeplake/core/dataset/deeplake_query_tensor.py index 422e8c57bf..2ffa695df3 100644 --- a/deeplake/core/dataset/deeplake_query_tensor.py +++ b/deeplake/core/dataset/deeplake_query_tensor.py @@ -106,6 +106,10 @@ def htype(self): htype = f"link[{htype}]" return htype + @htype.setter + def htype(self, value): + raise NotImplementedError("htype of a query tensor cannot be set.") + @property def sample_compression(self): return self.indra_tensor.sample_compression diff --git a/deeplake/core/tensor.py b/deeplake/core/tensor.py index 1c40a93280..3733da505a 100644 --- a/deeplake/core/tensor.py +++ b/deeplake/core/tensor.py @@ -10,7 +10,7 @@ from typing import Dict, List, Sequence, Union, Optional, Tuple, Any, Callable from functools import reduce, partial from deeplake.core.index import Index, IndexEntry, replace_ellipsis_with_slices -from deeplake.core.meta.tensor_meta import TensorMeta +from deeplake.core.meta.tensor_meta import TensorMeta, _validate_htype_exists from deeplake.core.storage import StorageProvider from deeplake.core.chunk_engine import ChunkEngine from deeplake.core.compression import _read_timestamps @@ -42,6 +42,7 @@ TensorDoesNotExistError, InvalidKeyTypeError, TensorAlreadyExistsError, + UnsupportedCompressionError, ) from deeplake.util.iteration_warning import check_if_iteration from deeplake.hooks import dataset_read, dataset_written @@ -63,6 +64,12 @@ parse_mesh_to_dict, get_mesh_vertices, ) +from deeplake.util.htype import parse_complex_htype +from deeplake.htype import ( + HTYPE_CONVERSION_LHS, + HTYPE_CONSTRAINTS, + HTYPE_SUPPORTED_COMPRESSIONS, +) import warnings import webbrowser @@ -555,6 +562,15 @@ def htype(self): htype = f"link[{htype}]" return htype + @htype.setter + def htype(self, value): + self._check_compatibility_with_htype(value) + self.meta.htype = value + if value == "class_label": + self.meta._disable_temp_transform = False + self.meta.is_dirty = True + self.dataset.maybe_flush() + @property def hidden(self) -> bool: """Whether this tensor is a hidden tensor.""" @@ -1382,3 +1398,27 @@ def creds_key(self): def invalidate_libdeeplake_dataset(self): """Invalidates the libdeeplake dataset object.""" self.dataset.libdeeplake_dataset = None + + def _check_compatibility_with_htype(self, htype): + """Checks if the tensor is compatible with the given htype. + Raises an error if not compatible. + """ + is_sequence, is_link, htype = parse_complex_htype(htype) + if is_sequence or is_link: + raise ValueError(f"Cannot change htype to a sequence or link.") + _validate_htype_exists(htype) + if self.htype not in HTYPE_CONVERSION_LHS: + raise NotImplementedError( + f"Changing the htype of a tensor of htype {self.htype} is not supported." + ) + if htype not in HTYPE_CONSTRAINTS: + raise NotImplementedError( + f"Changing the htype to {htype} is not supported." + ) + compression = self.meta.sample_compression or self.meta.chunk_compression + if compression: + supported_compressions = HTYPE_SUPPORTED_COMPRESSIONS.get(htype) + if supported_compressions and compression not in supported_compressions: + raise UnsupportedCompressionError(compression, htype) + constraints = HTYPE_CONSTRAINTS[htype] + constraints(self.shape, self.dtype) diff --git a/deeplake/htype.py b/deeplake/htype.py index 4c9abbc854..fa0fb490f2 100644 --- a/deeplake/htype.py +++ b/deeplake/htype.py @@ -1,4 +1,6 @@ -from typing import Dict +from typing import Callable, Dict + +import numpy as np from deeplake.compression import ( IMAGE_COMPRESSIONS, VIDEO_COMPRESSIONS, @@ -8,6 +10,7 @@ POINT_CLOUD_COMPRESSIONS, MESH_COMPRESSIONS, ) +from deeplake.util.exceptions import IncompatibleHtypeError class htype: @@ -66,7 +69,7 @@ class htype: htype.BBOX: {"dtype": "float32", "coords": {}, "_info": ["coords"]}, htype.BBOX_3D: {"dtype": "float32", "coords": {}, "_info": ["coords"]}, htype.AUDIO: {"dtype": "float64"}, - htype.EMBEDDING: {"dtype": "float32"}, + htype.EMBEDDING: {}, htype.VIDEO: {"dtype": "uint8"}, htype.BINARY_MASK: { "dtype": "bool" @@ -98,7 +101,94 @@ class htype: htype.INTRINSICS: {"dtype": "float32"}, htype.POLYGON: {"dtype": "float32"}, htype.MESH: {"sample_compression": "ply"}, - htype.EMBEDDING: {}, +} + +HTYPE_CONVERSION_LHS = {htype.DEFAULT, htype.IMAGE} + + +class constraints: + """Constraints for converting a tensor to a htype""" + + ndim_error = ( + lambda htype, ndim: f"Incompatible number of dimensions for htype {htype}: {ndim}" + ) + shape_error = ( + lambda htype, shape: f"Incompatible shape of tensor for htype {htype}: {shape}" + ) + dtype_error = ( + lambda htype, dtype: f"Incompatible dtype of tensor for htype {htype}: {dtype}" + ) + + EMBEDDING = lambda shape, dtype: True + INSTANCE_LABEL = lambda shape, dtype: True + + @staticmethod + def IMAGE(shape, dtype): + if len(shape) not in (3, 4): + raise IncompatibleHtypeError(constraints.ndim_error("image", len(shape))) + if len(shape) == 4 and shape[-1] not in (1, 3, 4): + raise IncompatibleHtypeError(constraints.shape_error("image", shape)) + + @staticmethod + def CLASS_LABEL(shape, dtype): + if len(shape) != 2: + raise IncompatibleHtypeError( + constraints.ndim_error("class_label", len(shape)) + ) + + @staticmethod + def BBOX(shape, dtype): + if len(shape) not in (2, 3): + raise IncompatibleHtypeError(constraints.ndim_error("bbox", len(shape))) + if shape[-1] != 4: + raise IncompatibleHtypeError(constraints.shape_error("bbox", shape)) + + @staticmethod + def BBOX_3D(shape, dtype): + if len(shape) not in (2, 3): + raise IncompatibleHtypeError(constraints.ndim_error("bbox.3d", len(shape))) + if shape[-1] != 8: + raise IncompatibleHtypeError(constraints.shape_error("bbox.3d", shape)) + + @staticmethod + def BINARY_MASK(shape, dtype): + if len(shape) not in (3, 4): + raise IncompatibleHtypeError( + constraints.ndim_error("binary_mask", len(shape)) + ) + + SEGMENT_MASK = BINARY_MASK + + @staticmethod + def KEYPOINTS_COCO(shape, dtype): + if len(shape) != 3: + raise IncompatibleHtypeError( + constraints.ndim_error("keypoints_coco", len(shape)) + ) + if shape[1] % 3 != 0: + raise IncompatibleHtypeError( + constraints.shape_error("keypoints_coco", shape) + ) + + @staticmethod + def POINT(shape, dtype): + if len(shape) != 3: + raise IncompatibleHtypeError(constraints.ndim_error("point", len(shape))) + if shape[-1] not in (2, 3): + raise IncompatibleHtypeError(constraints.shape_error("point", shape)) + + +HTYPE_CONSTRAINTS: Dict[str, Callable] = { + htype.IMAGE: constraints.IMAGE, + htype.CLASS_LABEL: constraints.CLASS_LABEL, + htype.BBOX: constraints.BBOX, + htype.BBOX_3D: constraints.BBOX_3D, + htype.EMBEDDING: constraints.EMBEDDING, + htype.BINARY_MASK: constraints.BINARY_MASK, + htype.SEGMENT_MASK: constraints.SEGMENT_MASK, + htype.INSTANCE_LABEL: constraints.INSTANCE_LABEL, + htype.KEYPOINTS_COCO: constraints.KEYPOINTS_COCO, + htype.POINT: constraints.POINT, } HTYPE_VERIFICATIONS: Dict[str, Dict] = { diff --git a/deeplake/util/exceptions.py b/deeplake/util/exceptions.py index 5fa8c92403..b501c9f658 100644 --- a/deeplake/util/exceptions.py +++ b/deeplake/util/exceptions.py @@ -1,6 +1,5 @@ import numpy as np import deeplake -from deeplake.htype import HTYPE_CONFIGURATIONS from typing import Any, List, Sequence, Tuple, Optional, Union @@ -483,7 +482,7 @@ def __init__(self, expected: Union[np.dtype, str], actual: str, htype: str): # TODO: we may want to raise this error at the API level to determine if the user explicitly overwrote the `dtype` or not. (to make this error message more precise) # TODO: because if the user uses `dtype=np.uint8`, but the `htype` the tensor is created with has it's default dtype set as `uint8` also, then this message is ambiguous - htype_dtype = HTYPE_CONFIGURATIONS[htype].get("dtype", None) + htype_dtype = deeplake.HTYPE_CONFIGURATIONS[htype].get("dtype", None) if htype_dtype is not None and htype_dtype == expected: msg += f" Htype '{htype}' expects samples to have dtype='{htype_dtype}'." super().__init__("") @@ -1088,3 +1087,8 @@ def __init__(self): "Please either use different embedding function or exclude invalid " "files that are not supported by the embedding function. " ) + + +class IncompatibleHtypeError(Exception): + def __init__(self, msg): + super().__init__(msg) diff --git a/deeplake/util/htype.py b/deeplake/util/htype.py index 5c9b14981b..3aba8f713b 100644 --- a/deeplake/util/htype.py +++ b/deeplake/util/htype.py @@ -1,11 +1,9 @@ -# type: ignore - from typing import Tuple, Optional from deeplake.htype import htype as HTYPE, HTYPE_CONFIGURATIONS from deeplake.util.exceptions import TensorMetaInvalidHtype -def parse_complex_htype(htype: Optional[str]) -> Tuple[bool, bool, str]: +def parse_complex_htype(htype: Optional[str]) -> Tuple[bool, bool, Optional[str]]: is_sequence = False is_link = False