Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add "isclose" method for Vector, Box3D and LabeledBox3D #806

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tensorbay/geometry/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,25 @@ def _line_intersect(length1: float, length2: float, midpoint_distance: float) ->
intersect_length = min(line1_max, line2_max) - max(line1_min, line2_min)
return intersect_length if intersect_length > 0 else 0

def _allclose(self, other: _B3, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Determine whether this 3D box is close to another in value.

Arguments:
other: The other object to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values

Returns:
A bool value indicating whether this vector is close to another.

"""
# pylint: disable=protected-access
return self._size._allclose(
other.size, rel_tol=rel_tol, abs_tol=abs_tol
) and self._transform._allclose(other.transform, rel_tol=rel_tol, abs_tol=abs_tol)

def _loads(self, contents: Dict[str, Dict[str, float]]) -> None:
self._size = Vector3D.loads(contents["size"])
self._transform = Transform3D.loads(contents)
Expand Down
13 changes: 8 additions & 5 deletions tensorbay/geometry/tests/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from quaternion import quaternion

from ...utility import UserSequence
from ...utility import UserSequence, allclose
from .. import Box2D, Box3D, Transform3D, Vector2D, Vector3D

_DATA_2D = {"xmin": 1.0, "ymin": 2.0, "xmax": 3.0, "ymax": 4.0}
Expand Down Expand Up @@ -116,10 +116,13 @@ def test_rmul(self):
assert box3d.__rmul__(transform) == Box3D(
size=(1, 1, 1), translation=[2, 0, 0], rotation=quaternion(-1, 0, 0, 0)
)
assert box3d.__rmul__(quaternion_1) == Box3D(
size=(1, 1, 1),
translation=[1.7999999999999996, 2, 2.6],
rotation=quaternion(-2, 1, 4, -3),
assert allclose(
box3d.__rmul__(quaternion_1),
Box3D(
size=(1, 1, 1),
translation=[1.7999999999999996, 2, 2.6],
rotation=quaternion(-2, 1, 4, -3),
),
)
assert box3d.__rmul__(1) == NotImplemented

Expand Down
5 changes: 5 additions & 0 deletions tensorbay/geometry/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def test_abs(self):
assert abs(Vector(1, 1)) == 1.4142135623730951
assert abs(Vector(1, 1, 1)) == 1.7320508075688772

def test__allclose(self):
assert Vector(1, 2)._allclose(Vector2D(1.000000000001, 2))
assert Vector(1, 2, 3)._allclose(Vector3D(1.000000000001, 2, 2.999999999996))
assert not Vector(1, 2, 3)._allclose(Vector3D(1.100000000001, 2, 2.999999999996))

def test_repr_head(self):
vector = Vector(1, 2)
assert vector._repr_head() == "Vector2D(1, 2)"
Expand Down
25 changes: 24 additions & 1 deletion tensorbay/geometry/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

with warnings.catch_warnings():
warnings.simplefilter("ignore")
from quaternion import as_rotation_matrix, from_rotation_matrix, quaternion, rotate_vectors
from quaternion import as_rotation_matrix, from_rotation_matrix
from quaternion import isclose as quaternion_isclose
from quaternion import quaternion, rotate_vectors

_T = TypeVar("_T", bound="Transform3D")

Expand Down Expand Up @@ -154,6 +156,27 @@ def _mul_vector(self, other: Iterable[float]) -> Vector3D:
# __radd__ is used to ensure the shape of the input object.
return self._translation.__radd__(rotate_vectors(self._rotation, other))

def _allclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Determine whether this 3D transform is close to another in value.

Arguments:
other: The other object to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values

Returns:
A bool value indicating whether this vector is close to another.

"""
if not isinstance(other, self.__class__):
return False

return self._translation._allclose( # pylint: disable=protected-access
other.translation, rel_tol=rel_tol, abs_tol=abs_tol
) and all(quaternion_isclose(self._rotation, other.rotation))

def _loads(self, contents: Dict[str, Dict[str, float]]) -> None:
self._translation = Vector3D.loads(contents["translation"])
rotation_contents = contents["rotation"]
Expand Down
33 changes: 32 additions & 1 deletion tensorbay/geometry/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""

from itertools import zip_longest
from math import hypot, sqrt
from math import hypot
from math import isclose as math_isclose
from math import sqrt
from sys import version_info
from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, TypeVar, Union

Expand Down Expand Up @@ -180,6 +182,35 @@ def __abs__(self) -> float:
def _repr_head(self) -> str:
return f"{self.__class__.__name__}{self._data}"

def _allclose(
self, other: Iterable[float], *, rel_tol: float = 1e-09, abs_tol: float = 0.0
) -> bool:
"""Determine whether this vector is close to another in value.

Arguments:
other: The other object to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values

Raises:
TypeError: When other have inconsistent dimension.

Returns:
A bool value indicating whether this vector is close to another.

"""
try:
return all(
math_isclose(i, j, rel_tol=rel_tol, abs_tol=abs_tol)
for i, j in zip_longest(self._data, other)
)
except TypeError as error:
raise TypeError(
f"The other object must have the dimension of {self._DIMENSION}"
) from error

@staticmethod
def loads(contents: Dict[str, float]) -> _T:
"""Loads a :class:`Vector` from a dict containing coordinates of the vector.
Expand Down
1 change: 1 addition & 0 deletions tensorbay/label/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class _LabelBase(AttrsMixin, TypeMixin[LabelType], ReprMixin):
_label_attrs: Tuple[str, ...] = ("category", "attributes", "instance")

_repr_attrs = _label_attrs
_support_allclose = True

_AttributeType = Dict[str, Union[str, int, float, bool, List[Union[str, int, float, bool]]]]

Expand Down
18 changes: 11 additions & 7 deletions tensorbay/label/tests/test_label_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from quaternion import quaternion

from ...geometry import Transform3D, Vector3D
from ...utility import allclose
from .. import Box2DSubcatalog, Box3DSubcatalog, LabeledBox2D, LabeledBox3D


Expand Down Expand Up @@ -156,13 +157,16 @@ def test_rmul(self):
instance="12345",
)

assert labeledbox3d.__rmul__(quaternion_1) == LabeledBox3D(
size=size,
translation=[1.7999999999999996, 2, 2.6],
rotation=[-2, 1, 4, -3],
category="cat",
attributes={"gender": "male"},
instance="12345",
assert allclose(
labeledbox3d.__rmul__(quaternion_1),
LabeledBox3D(
size=size,
translation=[1.7999999999999996, 2, 2.6],
rotation=[-2, 1, 4, -3],
category="cat",
attributes={"gender": "male"},
instance="12345",
),
)

assert labeledbox3d.__rmul__(1) == NotImplemented
Expand Down
2 changes: 2 additions & 0 deletions tensorbay/utility/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
EqMixin,
KwargsDeprecated,
MatrixType,
allclose,
common_loads,
locked,
)
Expand Down Expand Up @@ -42,6 +43,7 @@
"UserMutableMapping",
"UserMutableSequence",
"UserSequence",
"allclose",
"attr",
"attr_base",
"camel",
Expand Down
69 changes: 69 additions & 0 deletions tensorbay/utility/attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
:class:`Field` is a class describing the attr related fields.

"""
from math import isclose as math_isclose
from sys import version_info
from typing import (
Any,
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
) -> None:
self.loader: _Callable
self.dumper: _Callable
self.allclose: Callable[[Any, Any, float, float], bool]

self.is_dynamic = is_dynamic
self.default = default
Expand Down Expand Up @@ -94,6 +96,7 @@ def __init__(self, key: Optional[str]) -> None:
self.loader: _Callable
self.dumper: _Callable
self.key = key
self.allclose: Callable[[Any, Any, float, float], bool]


class AttrsMixin:
Expand All @@ -109,6 +112,7 @@ class AttrsMixin:

def __init_subclass__(cls) -> None:
type_ = cls.__annotations__.pop(_ATTRS_BASE, None)
support_allclose = getattr(cls, "_support_allclose", False)
if type_:
cls._attrs_base.loader = type_._loads # pylint: disable=protected-access
cls._attrs_base.dumper = type_.dumps
Expand All @@ -122,6 +126,8 @@ def __init_subclass__(cls) -> None:
field = getattr(cls, name, None)
if isinstance(field, Field):
field.loader, field.dumper = _get_operators(type_)
if support_allclose:
field.allclose = _get_allclose(type_)
if hasattr(field, "key_converter"):
field.key = field.key_converter(name)
attrs_fields[name] = field
Expand Down Expand Up @@ -200,6 +206,32 @@ def _dumps(self) -> Dict[str, Any]:
_key_dumper(field.key, contents, field.dumper(value))
return contents

def _allclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Determine whether this instance is close to another in value.

Arguments:
other: The other instance to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values

Returns:
A bool value indicating whether this instance is close to another.

"""
result = True

for name, field in self._attrs_fields.items():
if not hasattr(self, name):
continue

result = result and field.allclose( # type: ignore[call-arg]
getattr(self, name), getattr(other, name), rel_tol=rel_tol, abs_tol=abs_tol
)

return result


def attr(
*,
Expand Down Expand Up @@ -292,6 +324,43 @@ def _get_origin_in_3_6(annotation: Any) -> Any:
_get_origin = _get_origin_in_3_6 if version_info < (3, 7) else _get_origin_in_3_7


def _get_allclose(annotation: Any) -> Callable[[Any, Any, float, float], bool]:
"""Get attr allclose methods by annotations.

AttrsMixin has three operating types which are classified by attr annotation.
1. builtin types, like str, int, None
2. tensorbay custom class, like tensorbay.label.Classification
3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D]

Arguments:
annotation: Type of the attr.

Returns:
The ``_allclose`` methods of the annotation.

"""
origin = _get_origin(annotation)
if isinstance(origin, type) and issubclass(origin, Sequence):
type_ = annotation.__args__[0]
return lambda self, other, rel_tol=1e-09, abs_tol=0.0: all( # type: ignore[misc]
_get_allclose(type_)(i, j, rel_tol=rel_tol, abs_tol=abs_tol) # type: ignore[call-arg]
for i, j in zip(self, other)
)

type_ = annotation

mod = getattr(type_, "__module__", None)
if mod in _BUILTINS:
if type_ in (int, float):
return math_isclose # type: ignore[return-value]
return _eq_allclose # type: ignore[return-value]
return type_._allclose # type: ignore[no-any-return] # pylint: disable=protected-access


def _eq_allclose(object_1: object, object_2: object, **_: float) -> bool:
return object_1 == object_2


def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
"""Get attr operating methods by annotations.

Expand Down
48 changes: 48 additions & 0 deletions tensorbay/utility/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Callable, DefaultDict, Optional, Sequence, Tuple, Type, TypeVar, Union

import numpy as np
from typing_extensions import Protocol

_T = TypeVar("_T")
_Callable = TypeVar("_Callable", bound=Callable[..., Any])
Expand All @@ -45,6 +46,53 @@ def common_loads(object_class: Type[_T], contents: Any) -> _T:
return obj


class _A(Protocol): # pylint: disable=too-few-public-methods
def _allclose(self, other: "_A", *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Tell if all the data is close to the other object.

Arguments:
other: The other object.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values
"""


def allclose(object_1: _A, object_2: _A, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Determine whether object_1 is close to object_2 in value.

Arguments:
object_1: The first object to compare.
object_2: The second object to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values

Returns:
A bool value indicating whether object_1 is close to object_2.

"""
try:
# pylint: disable=protected-access
if issubclass(object_1.__class__, object_2.__class__) and hasattr(object_1, "_allclose"):
print(1)
return object_1._allclose(object_2, rel_tol=rel_tol, abs_tol=abs_tol)
if issubclass(object_2.__class__, object_1.__class__) and hasattr(object_2, "_allclose"):
print(2)
return object_2._allclose(object_1, rel_tol=rel_tol, abs_tol=abs_tol)

if hasattr(object_1, "_allclose"):
print(3)
return object_1._allclose(object_2, rel_tol=rel_tol, abs_tol=abs_tol)
print(4)
return object_2._allclose(object_1, rel_tol=rel_tol, abs_tol=abs_tol)
except Exception: # pylint: disable=broad-except
print(5)
return False


class EqMixin: # pylint: disable=too-few-public-methods
"""A mixin class to support __eq__() method.

Expand Down