diff --git a/tensorbay/geometry/box.py b/tensorbay/geometry/box.py index 849e47e84..d89af1d47 100644 --- a/tensorbay/geometry/box.py +++ b/tensorbay/geometry/box.py @@ -407,6 +407,25 @@ def _line_intersect(length1: float, length2: float, midpoint_distance: float) -> intersect_length = min(line1_max, line2_max) - max(line1_min, line2_min) return intersect_length if intersect_length > 0 else 0 + def _allclose(self, other: _B3, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether this 3D box is close to another in value. + + Arguments: + other: The other object to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Returns: + A bool value indicating whether this vector is close to another. + + """ + # pylint: disable=protected-access + return self._size._allclose( + other.size, rel_tol=rel_tol, abs_tol=abs_tol + ) and self._transform._allclose(other.transform, rel_tol=rel_tol, abs_tol=abs_tol) + def _loads(self, contents: Dict[str, Dict[str, float]]) -> None: self._size = Vector3D.loads(contents["size"]) self._transform = Transform3D.loads(contents) diff --git a/tensorbay/geometry/tests/test_box.py b/tensorbay/geometry/tests/test_box.py index 947603bc2..e2281602b 100644 --- a/tensorbay/geometry/tests/test_box.py +++ b/tensorbay/geometry/tests/test_box.py @@ -6,7 +6,7 @@ import pytest from quaternion import quaternion -from ...utility import UserSequence +from ...utility import UserSequence, allclose from .. import Box2D, Box3D, Transform3D, Vector2D, Vector3D _DATA_2D = {"xmin": 1.0, "ymin": 2.0, "xmax": 3.0, "ymax": 4.0} @@ -116,10 +116,13 @@ def test_rmul(self): assert box3d.__rmul__(transform) == Box3D( size=(1, 1, 1), translation=[2, 0, 0], rotation=quaternion(-1, 0, 0, 0) ) - assert box3d.__rmul__(quaternion_1) == Box3D( - size=(1, 1, 1), - translation=[1.7999999999999996, 2, 2.6], - rotation=quaternion(-2, 1, 4, -3), + assert allclose( + box3d.__rmul__(quaternion_1), + Box3D( + size=(1, 1, 1), + translation=[1.7999999999999996, 2, 2.6], + rotation=quaternion(-2, 1, 4, -3), + ), ) assert box3d.__rmul__(1) == NotImplemented diff --git a/tensorbay/geometry/tests/test_vector.py b/tensorbay/geometry/tests/test_vector.py index 5f3480e49..b42407a3b 100644 --- a/tensorbay/geometry/tests/test_vector.py +++ b/tensorbay/geometry/tests/test_vector.py @@ -125,6 +125,11 @@ def test_abs(self): assert abs(Vector(1, 1)) == 1.4142135623730951 assert abs(Vector(1, 1, 1)) == 1.7320508075688772 + def test__allclose(self): + assert Vector(1, 2)._allclose(Vector2D(1.000000000001, 2)) + assert Vector(1, 2, 3)._allclose(Vector3D(1.000000000001, 2, 2.999999999996)) + assert not Vector(1, 2, 3)._allclose(Vector3D(1.100000000001, 2, 2.999999999996)) + def test_repr_head(self): vector = Vector(1, 2) assert vector._repr_head() == "Vector2D(1, 2)" diff --git a/tensorbay/geometry/transform.py b/tensorbay/geometry/transform.py index 1d08b5aa7..b088f2077 100644 --- a/tensorbay/geometry/transform.py +++ b/tensorbay/geometry/transform.py @@ -23,7 +23,9 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - from quaternion import as_rotation_matrix, from_rotation_matrix, quaternion, rotate_vectors + from quaternion import as_rotation_matrix, from_rotation_matrix + from quaternion import isclose as quaternion_isclose + from quaternion import quaternion, rotate_vectors _T = TypeVar("_T", bound="Transform3D") @@ -154,6 +156,27 @@ def _mul_vector(self, other: Iterable[float]) -> Vector3D: # __radd__ is used to ensure the shape of the input object. return self._translation.__radd__(rotate_vectors(self._rotation, other)) + def _allclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether this 3D transform is close to another in value. + + Arguments: + other: The other object to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Returns: + A bool value indicating whether this vector is close to another. + + """ + if not isinstance(other, self.__class__): + return False + + return self._translation._allclose( # pylint: disable=protected-access + other.translation, rel_tol=rel_tol, abs_tol=abs_tol + ) and all(quaternion_isclose(self._rotation, other.rotation)) + def _loads(self, contents: Dict[str, Dict[str, float]]) -> None: self._translation = Vector3D.loads(contents["translation"]) rotation_contents = contents["rotation"] diff --git a/tensorbay/geometry/vector.py b/tensorbay/geometry/vector.py index 756c63633..cb5b0fd5a 100644 --- a/tensorbay/geometry/vector.py +++ b/tensorbay/geometry/vector.py @@ -15,7 +15,9 @@ """ from itertools import zip_longest -from math import hypot, sqrt +from math import hypot +from math import isclose as math_isclose +from math import sqrt from sys import version_info from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, TypeVar, Union @@ -180,6 +182,35 @@ def __abs__(self) -> float: def _repr_head(self) -> str: return f"{self.__class__.__name__}{self._data}" + def _allclose( + self, other: Iterable[float], *, rel_tol: float = 1e-09, abs_tol: float = 0.0 + ) -> bool: + """Determine whether this vector is close to another in value. + + Arguments: + other: The other object to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Raises: + TypeError: When other have inconsistent dimension. + + Returns: + A bool value indicating whether this vector is close to another. + + """ + try: + return all( + math_isclose(i, j, rel_tol=rel_tol, abs_tol=abs_tol) + for i, j in zip_longest(self._data, other) + ) + except TypeError as error: + raise TypeError( + f"The other object must have the dimension of {self._DIMENSION}" + ) from error + @staticmethod def loads(contents: Dict[str, float]) -> _T: """Loads a :class:`Vector` from a dict containing coordinates of the vector. diff --git a/tensorbay/label/basic.py b/tensorbay/label/basic.py index 4b0c5e720..93dc71d9c 100644 --- a/tensorbay/label/basic.py +++ b/tensorbay/label/basic.py @@ -143,6 +143,7 @@ class _LabelBase(AttrsMixin, TypeMixin[LabelType], ReprMixin): _label_attrs: Tuple[str, ...] = ("category", "attributes", "instance") _repr_attrs = _label_attrs + _support_allclose = True _AttributeType = Dict[str, Union[str, int, float, bool, List[Union[str, int, float, bool]]]] diff --git a/tensorbay/label/tests/test_label_box.py b/tensorbay/label/tests/test_label_box.py index d36889a82..8291e2888 100644 --- a/tensorbay/label/tests/test_label_box.py +++ b/tensorbay/label/tests/test_label_box.py @@ -7,6 +7,7 @@ from quaternion import quaternion from ...geometry import Transform3D, Vector3D +from ...utility import allclose from .. import Box2DSubcatalog, Box3DSubcatalog, LabeledBox2D, LabeledBox3D @@ -156,13 +157,16 @@ def test_rmul(self): instance="12345", ) - assert labeledbox3d.__rmul__(quaternion_1) == LabeledBox3D( - size=size, - translation=[1.7999999999999996, 2, 2.6], - rotation=[-2, 1, 4, -3], - category="cat", - attributes={"gender": "male"}, - instance="12345", + assert allclose( + labeledbox3d.__rmul__(quaternion_1), + LabeledBox3D( + size=size, + translation=[1.7999999999999996, 2, 2.6], + rotation=[-2, 1, 4, -3], + category="cat", + attributes={"gender": "male"}, + instance="12345", + ), ) assert labeledbox3d.__rmul__(1) == NotImplemented diff --git a/tensorbay/utility/__init__.py b/tensorbay/utility/__init__.py index 63da720ab..7da65f6a0 100644 --- a/tensorbay/utility/__init__.py +++ b/tensorbay/utility/__init__.py @@ -13,6 +13,7 @@ EqMixin, KwargsDeprecated, MatrixType, + allclose, common_loads, locked, ) @@ -42,6 +43,7 @@ "UserMutableMapping", "UserMutableSequence", "UserSequence", + "allclose", "attr", "attr_base", "camel", diff --git a/tensorbay/utility/attr.py b/tensorbay/utility/attr.py index 771380788..2e369cf1f 100644 --- a/tensorbay/utility/attr.py +++ b/tensorbay/utility/attr.py @@ -10,6 +10,7 @@ :class:`Field` is a class describing the attr related fields. """ +from math import isclose as math_isclose from sys import version_info from typing import ( Any, @@ -64,6 +65,7 @@ def __init__( ) -> None: self.loader: _Callable self.dumper: _Callable + self.allclose: Callable[[Any, Any, float, float], bool] self.is_dynamic = is_dynamic self.default = default @@ -94,6 +96,7 @@ def __init__(self, key: Optional[str]) -> None: self.loader: _Callable self.dumper: _Callable self.key = key + self.allclose: Callable[[Any, Any, float, float], bool] class AttrsMixin: @@ -109,6 +112,7 @@ class AttrsMixin: def __init_subclass__(cls) -> None: type_ = cls.__annotations__.pop(_ATTRS_BASE, None) + support_allclose = getattr(cls, "_support_allclose", False) if type_: cls._attrs_base.loader = type_._loads # pylint: disable=protected-access cls._attrs_base.dumper = type_.dumps @@ -122,6 +126,8 @@ def __init_subclass__(cls) -> None: field = getattr(cls, name, None) if isinstance(field, Field): field.loader, field.dumper = _get_operators(type_) + if support_allclose: + field.allclose = _get_allclose(type_) if hasattr(field, "key_converter"): field.key = field.key_converter(name) attrs_fields[name] = field @@ -200,6 +206,32 @@ def _dumps(self) -> Dict[str, Any]: _key_dumper(field.key, contents, field.dumper(value)) return contents + def _allclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether this instance is close to another in value. + + Arguments: + other: The other instance to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Returns: + A bool value indicating whether this instance is close to another. + + """ + result = True + + for name, field in self._attrs_fields.items(): + if not hasattr(self, name): + continue + + result = result and field.allclose( # type: ignore[call-arg] + getattr(self, name), getattr(other, name), rel_tol=rel_tol, abs_tol=abs_tol + ) + + return result + def attr( *, @@ -292,6 +324,43 @@ def _get_origin_in_3_6(annotation: Any) -> Any: _get_origin = _get_origin_in_3_6 if version_info < (3, 7) else _get_origin_in_3_7 +def _get_allclose(annotation: Any) -> Callable[[Any, Any, float, float], bool]: + """Get attr allclose methods by annotations. + + AttrsMixin has three operating types which are classified by attr annotation. + 1. builtin types, like str, int, None + 2. tensorbay custom class, like tensorbay.label.Classification + 3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D] + + Arguments: + annotation: Type of the attr. + + Returns: + The ``_allclose`` methods of the annotation. + + """ + origin = _get_origin(annotation) + if isinstance(origin, type) and issubclass(origin, Sequence): + type_ = annotation.__args__[0] + return lambda self, other, rel_tol=1e-09, abs_tol=0.0: all( # type: ignore[misc] + _get_allclose(type_)(i, j, rel_tol=rel_tol, abs_tol=abs_tol) # type: ignore[call-arg] + for i, j in zip(self, other) + ) + + type_ = annotation + + mod = getattr(type_, "__module__", None) + if mod in _BUILTINS: + if type_ in (int, float): + return math_isclose # type: ignore[return-value] + return _eq_allclose # type: ignore[return-value] + return type_._allclose # type: ignore[no-any-return] # pylint: disable=protected-access + + +def _eq_allclose(object_1: object, object_2: object, **_: float) -> bool: + return object_1 == object_2 + + def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]: """Get attr operating methods by annotations. diff --git a/tensorbay/utility/common.py b/tensorbay/utility/common.py index f239faec8..42d71c8f3 100644 --- a/tensorbay/utility/common.py +++ b/tensorbay/utility/common.py @@ -20,6 +20,7 @@ from typing import Any, Callable, DefaultDict, Optional, Sequence, Tuple, Type, TypeVar, Union import numpy as np +from typing_extensions import Protocol _T = TypeVar("_T") _Callable = TypeVar("_Callable", bound=Callable[..., Any]) @@ -45,6 +46,53 @@ def common_loads(object_class: Type[_T], contents: Any) -> _T: return obj +class _A(Protocol): # pylint: disable=too-few-public-methods + def _allclose(self, other: "_A", *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Tell if all the data is close to the other object. + + Arguments: + other: The other object. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + """ + + +def allclose(object_1: _A, object_2: _A, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether object_1 is close to object_2 in value. + + Arguments: + object_1: The first object to compare. + object_2: The second object to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Returns: + A bool value indicating whether object_1 is close to object_2. + + """ + try: + # pylint: disable=protected-access + if issubclass(object_1.__class__, object_2.__class__) and hasattr(object_1, "_allclose"): + print(1) + return object_1._allclose(object_2, rel_tol=rel_tol, abs_tol=abs_tol) + if issubclass(object_2.__class__, object_1.__class__) and hasattr(object_2, "_allclose"): + print(2) + return object_2._allclose(object_1, rel_tol=rel_tol, abs_tol=abs_tol) + + if hasattr(object_1, "_allclose"): + print(3) + return object_1._allclose(object_2, rel_tol=rel_tol, abs_tol=abs_tol) + print(4) + return object_2._allclose(object_1, rel_tol=rel_tol, abs_tol=abs_tol) + except Exception: # pylint: disable=broad-except + print(5) + return False + + class EqMixin: # pylint: disable=too-few-public-methods """A mixin class to support __eq__() method.