Skip to content

Commit

Permalink
Merge pull request #156 from lorycontixd/add_mesh_support
Browse files Browse the repository at this point in the history
Co-authored-by: Lorenzo Conti <[email protected]>

commit de410bf6102772f7078abac0c5ae77333362a0d3
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 12:17:22 2024 +0200

    Added string magics on wrapping methods

    - Added string magics on mesh wrapping methods for easier logging
    - Added log to indicate number of extracted points
    - Added extra check on objectmapping method

commit 1533f7016a3c562961b625bbf86d3550201d83e8
Author: Lorenzo Conti <[email protected]>
Date:   Fri Jul 5 14:53:48 2024 +0200

    Moved from jax.numpy to numpy in PlaneTerrain __eq__ magic to bypass TracerError

commit e3f167a
Merge: 8fa9adc fe2616c
Author: Filippo Luca Ferretti <[email protected]>
Date:   Fri Nov 15 17:35:46 2024 +0100

    Merge pull request #156 from lorycontixd/add_mesh_support

    Add mesh support

commit fe2616c
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 17:18:07 2024 +0100

    Removed whitespaces

commit 2d46a70
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 17:16:08 2024 +0100

    Fixed minor commenting format

commit eff915d
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 17:14:35 2024 +0100

    Removed unused function

commit 1738a87
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 17:12:19 2024 +0100

    Removed unused dependency from pyprojecj

commit bb21a16
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 17:10:09 2024 +0100

    Precommit fix

commit 495b799
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 16:43:42 2024 +0100

    Apply suggestions from code review

    Co-authored-by: Filippo Luca Ferretti <[email protected]>

commit 90451e8
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 16:41:36 2024 +0100

    Removed extra search paths in ergocub model building

commit e76b0b9
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 16:16:55 2024 +0100

    Added int casting on mesh_enabled flag

commit 8a29349
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 15:27:02 2024 +0100

    Updated variable names

    Co-authored-by: Filippo Luca Ferretti <[email protected]>

commit 8d0380e
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 16:14:26 2024 +0100

    Added experimental feature warning for mesh parsing

commit ad30c29
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 15:23:55 2024 +0100

    Addressed reviews

    - Removed unused dependencies in pyproject.toml
    - Moved from warning to exception if passed mesh is empty

    Co-authored-by: Filippo Luca Ferretti <[email protected]>

commit 9238af5
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 12:10:38 2024 +0100

    Fixed error on array sorting and relative test

commit fec016a
Author: Lorenzo Conti <[email protected]>
Date:   Fri Nov 15 11:41:24 2024 +0100

    Implemented reviews

commit d2a5e89
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 12:17:22 2024 +0200

    Added string magics on wrapping methods

    - Added string magics on mesh wrapping methods for easier logging
    - Added log to indicate number of extracted points
    - Added extra check on objectmapping method

commit 98012e5
Author: Lorenzo Conti <[email protected]>
Date:   Thu Nov 14 12:02:43 2024 +0100

    Added docstrings to mesh wrapping algorithms

commit 148fcfe
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 12:17:22 2024 +0200

    Added string magics on wrapping methods

    - Added string magics on mesh wrapping methods for easier logging
    - Added log to indicate number of extracted points
    - Added extra check on objectmapping method

commit 4f6cf0a
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 12:11:56 2024 +0200

    Removed wrong point selection & added logs

    - Removed a line that would always set the extracted points to the vertices
    - Added a few debug lines using logger

commit f8ecf24
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 12:01:34 2024 +0200

    Removed leftover parameters on create_mesh_collision

commit 6acb19f
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 11:55:01 2024 +0200

    Run pre-commit

commit 72ce440
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 11:50:37 2024 +0200

    Restructured mesh mapping methods to follow inheritance

    - Redefined methods using classes
    - Adapted rod parser to new structure
    - Reimplemented tests with new structure

commit 2d07347
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 10:45:09 2024 +0200

    Renamed some parameters

    - Renamed some function paramaeters
    - Added a few tests
    TODO: migrate MeshMapping static class to inheritance structure

commit 8058b7c
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jul 18 10:25:51 2024 +0200

    New mesh wrapping algorithms with relative tests

    - New mesh wrapping algorithms (mesh decimation, object mapping, aap, select points over axis)
    - Implemented tests of above except first algorithm
    - Updated manifold3d dependency (used in object mapping)
    - Restructured meshes module

commit 22da2cd
Author: Lorenzo Conti <[email protected]>
Date:   Wed Jul 17 23:00:18 2024 +0200

    New mesh wrapping algorithms

    - Implemented AAP algorithm
    - Restructured collision parsing to accept the new algorithms
    - Wrote tests for AAP algorithm
    - Updated JaxSim dependencies

commit d434e44
Author: Lorenzo Conti <[email protected]>
Date:   Wed Jul 17 17:56:46 2024 +0200

    Implemented structure for new mesh wrapping algorithms

commit f9475b0
Author: Lorenzo Conti <[email protected]>
Date:   Wed Jul 17 17:55:25 2024 +0200

    First draft of new mesh wrapping algorithms

commit 5af51da
Author: Lorenzo Conti <[email protected]>
Date:   Fri Jul 5 16:08:34 2024 +0200

    Implemented initial reviews from #156

    - Restructured mesh_collision creation method
    - Removed unnecessary hash inheritance

commit ed51e85
Author: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date:   Fri Jul 5 12:55:11 2024 +0000

    [pre-commit.ci] auto fixes from pre-commit.com hooks

    for more information, see https://pre-commit.ci

commit cf00335
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jun 20 10:39:42 2024 +0200

    Removed unused import in conftest

commit 94c2f81
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jun 20 10:37:23 2024 +0200

    Fixed typo on logging message

commit d7fd1b3
Author: Lorenzo Conti <[email protected]>
Date:   Thu Jun 20 10:34:52 2024 +0200

    Removed unused lines in conftest

commit 07b8402
Author: Filippo Luca Ferretti <[email protected]>
Date:   Thu Jun 13 12:34:26 2024 +0200

    Update `__eq__` magic and type hints

commit bcf5e48
Author: Lorenzo Conti <[email protected]>
Date:   Tue May 21 14:59:23 2024 +0200

    Pre-commit

commit b0ddec1
Author: Lorenzo Conti <[email protected]>
Date:   Tue May 21 14:57:30 2024 +0200

    Address reviews

    - Remove leftover comments
    - MeshMappingMethods inherit from IntEnum instead of Enum
    - Added center comparison for MeshCollision object

commit 644ce43
Author: Lorenzo Conti <[email protected]>
Date:   Tue May 21 14:51:08 2024 +0200

    Implemented UniformSurfaceSampling for mesh point wrapping

commit b15ea55
Author: Lorenzo Conti <[email protected]>
Date:   Tue May 21 10:11:55 2024 +0200

    Moved mesh parsing logic inside mesh collision function

commit 2f25c0e
Author: Lorenzo Conti <[email protected]>
Date:   Tue May 21 10:10:49 2024 +0200

    Added trimesh dependecy for conda-forge

commit 46f7164
Author: Lorenzo Conti <[email protected]>
Date:   Mon May 20 18:11:36 2024 +0200

    Address to reviews:

    - Moved mesh parsing logic inside method for creating mesh collisions
    - Removed vs code settings
    - Removed empty mesh parsing test

commit 990ea06
Author: Lorenzo Conti <[email protected]>
Date:   Thu May 16 14:34:48 2024 +0200

    Added `networkx` as testing dependency

commit 2c5f78f
Author: Filippo Luca Ferretti <[email protected]>
Date:   Mon May 20 16:16:41 2024 +0200

    Skip loading empty meshes

    Co-authored-by: Lorenzo Conti <[email protected]>

commit 6e968fc
Author: Filippo Luca Ferretti <[email protected]>
Date:   Fri May 17 10:22:23 2024 +0200

    Use already existing env var to solve mesh URIs

commit 1663a1f
Author: Filippo Luca Ferretti <[email protected]>
Date:   Tue May 14 15:35:49 2024 +0200

    Format and lint

commit 5d16532
Author: Lorenzo Conti <[email protected]>
Date:   Mon May 13 19:42:47 2024 +0200

    Initial version of mesh support

commit d07e4d1
Author: Filippo Luca Ferretti <[email protected]>
Date:   Mon May 13 19:40:27 2024 +0200

    Set env var when parsing from `robot-descriptions`

commit 3936638
Author: Lorenzo Conti <[email protected]>
Date:   Fri Jul 5 14:53:48 2024 +0200

    Moved from jax.numpy to numpy in PlaneTerrain __eq__ magic to bypass TracerError
  • Loading branch information
flferretti committed Nov 15, 2024
1 parent 0574569 commit 09e3701
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 2 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- pptree
- qpax
- rod >= 0.3.3
- trimesh
- typing_extensions # python<3.12
# ====================================
# Optional dependencies from setup.cfg
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dependencies = [
"qpax",
"rod >= 0.3.3",
"typing_extensions ; python_version < '3.12'",
"trimesh",
]

[project.optional-dependencies]
Expand All @@ -67,7 +68,7 @@ testing = [
"idyntree >= 12.2.1",
"pytest >=6.0",
"pytest-icdiff",
"robot-descriptions",
"robot-descriptions"
]
viz = [
"lxml",
Expand Down
8 changes: 7 additions & 1 deletion src/jaxsim/parsers/descriptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
from .collision import (
BoxCollision,
CollidablePoint,
CollisionShape,
MeshCollision,
SphereCollision,
)
from .joint import JointDescription, JointGenericAxis, JointType
from .link import LinkDescription
from .model import ModelDescription
19 changes: 19 additions & 0 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,22 @@ def __eq__(self, other: BoxCollision) -> bool:
return False

return hash(self) == hash(other)


@dataclasses.dataclass
class MeshCollision(CollisionShape):
center: jtp.VectorLike

def __hash__(self) -> int:
return hash(
(
hash(tuple(self.center.tolist())),
hash(self.collidable_points),
)
)

def __eq__(self, other: MeshCollision) -> bool:
if not isinstance(other, MeshCollision):
return False

return hash(self) == hash(other)
104 changes: 104 additions & 0 deletions src/jaxsim/parsers/rod/meshes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import numpy as np
import trimesh

VALID_AXIS = {"x": 0, "y": 1, "z": 2}


def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
"""
Extracts the vertices of a mesh as points.
"""
return mesh.vertices


def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray:
"""
Extracts N random points from the surface of a mesh.
Args:
mesh: The mesh from which to extract points.
n: The number of points to extract.
Returns:
The extracted points (N x 3 array).
"""

return mesh.sample(n)


def extract_points_uniform_surface_sampling(
mesh: trimesh.Trimesh, n: int
) -> np.ndarray:
"""
Extracts N uniformly sampled points from the surface of a mesh.
Args:
mesh: The mesh from which to extract points.
n: The number of points to extract.
Returns:
The extracted points (N x 3 array).
"""

return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0]


def extract_points_select_points_over_axis(
mesh: trimesh.Trimesh, axis: str, direction: str, n: int
) -> np.ndarray:
"""
Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis.
Args:
mesh: The mesh from which to extract points.
axis: The axis along which to extract points.
direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower".
n: The number of points to extract.
Returns:
The extracted points (N x 3 array).
"""

dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]}
arr = mesh.vertices

# Sort rows lexicographically first, then columnar.
arr.sort(axis=0)
sorted_arr = arr[dirs[direction]]
return sorted_arr


def extract_points_aap(
mesh: trimesh.Trimesh,
axis: str,
upper: float | None = None,
lower: float | None = None,
) -> np.ndarray:
"""
Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.
Args:
mesh: The mesh from which to extract points.
axis: The axis along which to extract points.
upper: The upper bound of the range.
lower: The lower bound of the range.
Returns:
The extracted points (N x 3 array).
Raises:
AssertionError: If the lower bound is greater than the upper bound.
"""

# Check bounds.
upper = upper if upper is not None else np.inf
lower = lower if lower is not None else -np.inf
assert lower < upper, "Invalid bounds for axis-aligned plane"

# Logic.
points = mesh.vertices[
(mesh.vertices[:, VALID_AXIS[axis]] >= lower)
& (mesh.vertices[:, VALID_AXIS[axis]] <= upper)
]

return points
12 changes: 12 additions & 0 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,18 @@ def extract_model_data(

collisions.append(sphere_collision)

if collision.geometry.mesh is not None and int(
os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0")
):
logging.warning("Mesh collision support is still experimental.")
mesh_collision = utils.create_mesh_collision(
collision=collision,
link_description=links_dict[link.name],
method=utils.meshes.extract_points_vertices,
)

collisions.append(mesh_collision)

return SDFData(
model_name=sdf_model.name,
link_descriptions=links,
Expand Down
53 changes: 53 additions & 0 deletions src/jaxsim/parsers/rod/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import os
import pathlib
from collections.abc import Callable
from typing import TypeVar

import numpy as np
import numpy.typing as npt
import rod
import trimesh
from rod.utils.resolve_uris import resolve_local_uri

import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Inertia
from jaxsim.parsers import descriptions
from jaxsim.parsers.rod import meshes

MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray])


def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
Expand Down Expand Up @@ -202,3 +211,47 @@ def fibonacci_sphere(samples: int) -> npt.NDArray:
return descriptions.SphereCollision(
collidable_points=collidable_points, center=center_wrt_link
)


def create_mesh_collision(
collision: rod.Collision,
link_description: descriptions.LinkDescription,
method: MeshMappingMethod = None,
) -> descriptions.MeshCollision:

file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
_file_type = file.suffix.replace(".", "")
mesh = trimesh.load_mesh(file, file_type=_file_type)

if mesh.is_empty:
raise RuntimeError(f"Failed to process '{file}' with trimesh")

mesh.apply_scale(collision.geometry.mesh.scale)
logging.info(
msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{_file_type}'"
)

if method is None:
method = meshes.VertexExtraction()
logging.debug("Using default Vertex Extraction method for mesh wrapping")
else:
logging.debug(f"Using method {method} for mesh wrapping")

points = method(mesh=mesh)
logging.debug(f"Extracted {len(points)} points from mesh")

W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4)

# Extract translation from transformation matrix
W_p_L = W_H_L[:3, 3]
mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L
collidable_points = [
descriptions.CollidablePoint(
parent_link=link_description,
position=point,
enabled=True,
)
for point in mesh_points_wrt_link
]

return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L)
100 changes: 100 additions & 0 deletions tests/test_meshes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import trimesh

from jaxsim.parsers.rod import meshes


def test_mesh_wrapping_vertex_extraction():
"""
Test the vertex extraction method on different meshes.
1. A simple box
2. A sphere
"""

# Test 1: A simple box.
# First, create a box with origin at (0,0,0) and extents (3,3,3),
# i.e. points span from -1.5 to 1.5 on the axis.
mesh = trimesh.creation.box(
extents=[3.0, 3.0, 3.0],
)
points = meshes.extract_points_vertices(mesh=mesh)
assert len(points) == len(mesh.vertices)

# Test 2: A sphere.
# The sphere is centered at the origin and has a radius of 1.0.
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
points = meshes.extract_points_vertices(mesh=mesh)
assert len(points) == len(mesh.vertices)


def test_mesh_wrapping_aap():
"""
Test the AAP wrapping method on different meshes.
1. A simple box
1.1: Remove all points above x=0.0
1.2: Remove all points below y=0.0
2. A sphere
"""

# Test 1.1: Remove all points above x=0.0.
# The expected result is that the number of points is halved.
# First, create a box with origin at (0,0,0) and extents (3,3,3),
# i.e. points span from -1.5 to 1.5 on the axis.
mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0])
points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0)
assert len(points) == len(mesh.vertices) // 2
assert all(points[:, 0] > 0.0)

# Test 1.2: Remove all points below y=0.0.
# The expected result is that the number of points is halved.
points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0)
assert len(points) == len(mesh.vertices) // 2
assert all(points[:, 1] < 0.0)

# Test 2: A sphere.
# The sphere is centered at the origin and has a radius of 1.0.
# Points are expected to be halved.
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)

# Remove all points above y=0.0.
points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0)
assert all(points[:, 1] >= 0.0)
assert len(points) < len(mesh.vertices)


def test_mesh_wrapping_points_over_axis():
"""
Test the points over axis method on different meshes.
1. A simple box
1.1: Select 10 points from the lower end of the x-axis
1.2: Select 10 points from the higher end of the y-axis
2. A sphere
"""

# Test 1.1: Remove 10 points from the lower end of the x-axis.
# First, create a box with origin at (0,0,0) and extents (3,3,3),
# i.e. points span from -1.5 to 1.5 on the axis.
mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0])
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="x", direction="lower", n=4
)
assert len(points) == 4
assert all(points[:, 0] < 0.0)

# Test 1.2: Select 10 points from the higher end of the y-axis.
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="y", direction="higher", n=4
)
assert len(points) == 4
assert all(points[:, 1] > 0.0)

# Test 2: A sphere.
# The sphere is centered at the origin and has a radius of 1.0.
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
sphere_n_vertices = len(mesh.vertices)

# Select 10 points from the higher end of the z-axis.
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2
)
assert len(points) == sphere_n_vertices // 2
assert all(points[:, 2] >= 0.0)

0 comments on commit 09e3701

Please sign in to comment.