From 43cb39c7bda912080d9b2dbc52bd2ef7e0e50ed5 Mon Sep 17 00:00:00 2001 From: Silvio Traversaro Date: Sun, 10 Mar 2024 15:03:15 +0100 Subject: [PATCH 1/2] Use list instead of npt.NDArray for body static attribute in GroundContact --- src/jaxsim/api/ode.py | 8 ++++++-- src/jaxsim/physics/algos/soft_contacts.py | 6 +++--- src/jaxsim/physics/algos/terrain.py | 10 ++++------ src/jaxsim/physics/model/ground_contact.py | 10 ++++------ src/jaxsim/simulation/ode.py | 2 +- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index f0cef275c..d3cada9c9 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -123,7 +123,7 @@ def system_velocity_dynamics( # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}. ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float) - if model.physics_model.gc.body.size > 0: + if len(model.physics_model.gc.body) > 0: # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data) @@ -140,7 +140,11 @@ def system_velocity_dynamics( # we don't need any coordinate transformation. W_f_Li_terrain = jax.vmap( lambda nc: ( - jnp.vstack(jnp.equal(model.physics_model.gc.body, nc).astype(int)) + jnp.vstack( + jnp.equal( + np.array(model.physics_model.gc.body, dtype=int), nc + ).astype(int) + ) * W_f_Ci ).sum(axis=0) )(jnp.arange(model.number_of_links())) diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index 10fcf55fb..02ebda317 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -58,7 +58,7 @@ def build_from_physics_model( return SoftContactsState.build( tangential_deformation=tangential_deformation, - number_of_collidable_points=physics_model.gc.body.size, + number_of_collidable_points=len(physics_model.gc.body), ) @staticmethod @@ -95,7 +95,7 @@ def valid( return check_valid_shape( what="tangential_deformation", shape=self.tangential_deformation.shape, - expected_shape=(3, physics_model.gc.body.size), + expected_shape=(3, len(physics_model.gc.body)), valid=True, ) @@ -237,7 +237,7 @@ def process_point_kinematics( # Process all the collidable points in parallel W_p_Ci, CW_v_WC = jax.vmap(process_point_kinematics)( - model.gc.point.T, model.gc.body + model.gc.point.T, np.array(model.gc.body, dtype=int) ) return W_p_Ci.transpose(), CW_v_WC.transpose() diff --git a/src/jaxsim/physics/algos/terrain.py b/src/jaxsim/physics/algos/terrain.py index 07460cc20..df0a0a402 100644 --- a/src/jaxsim/physics/algos/terrain.py +++ b/src/jaxsim/physics/algos/terrain.py @@ -46,23 +46,21 @@ def height(self, x: float, y: float) -> float: @jax_dataclasses.pytree_dataclass class PlaneTerrain(Terrain): - plane_normal: jtp.Vector = jax_dataclasses.field( - default_factory=lambda: jnp.array([0, 0, 1.0]) - ) + plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0]) @staticmethod - def build(plane_normal: jtp.Vector) -> "PlaneTerrain": + def build(plane_normal: list) -> "PlaneTerrain": """ Create a PlaneTerrain instance with a specified plane normal vector. Args: - plane_normal (jtp.Vector): The normal vector of the terrain plane. + plane_normal (list): The normal vector of the terrain plane. Returns: PlaneTerrain: A PlaneTerrain instance. """ - return PlaneTerrain(plane_normal=jnp.array(plane_normal, dtype=float)) + return PlaneTerrain(plane_normal=plane_normal) def height(self, x: float, y: float) -> float: """ diff --git a/src/jaxsim/physics/model/ground_contact.py b/src/jaxsim/physics/model/ground_contact.py index 4776cac71..beea198de 100644 --- a/src/jaxsim/physics/model/ground_contact.py +++ b/src/jaxsim/physics/model/ground_contact.py @@ -23,9 +23,7 @@ class GroundContact: """ point: npt.NDArray = dataclasses.field(default_factory=lambda: jnp.array([])) - body: Static[npt.NDArray] = dataclasses.field( - default_factory=lambda: np.array([], dtype=int) - ) + body: Static[list] = dataclasses.field(default_factory=lambda: []) @staticmethod def build_from( @@ -42,9 +40,9 @@ def build_from( # Build the GroundContact attributes points = jnp.vstack([cp.position for cp in collidable_points]).T - link_index_of_points = np.array( - [links_dict[cp.parent_link.name].index for cp in collidable_points] - ) + link_index_of_points = [ + links_dict[cp.parent_link.name].index for cp in collidable_points + ] # Build the object gc = GroundContact(point=points, body=link_index_of_points) diff --git a/src/jaxsim/simulation/ode.py b/src/jaxsim/simulation/ode.py index 06063f094..3dd9a429b 100644 --- a/src/jaxsim/simulation/ode.py +++ b/src/jaxsim/simulation/ode.py @@ -123,7 +123,7 @@ def dx_dt( ode_state.soft_contacts.tangential_deformation ) - if physics_model.gc.body.size > 0: + if len(physics_model.gc.body) > 0: ( contact_forces_links, tangential_deformation_dot, From 90f861c1bded8c1858662bed67962f7a442aa3a0 Mon Sep 17 00:00:00 2001 From: Silvio Traversaro Date: Sun, 10 Mar 2024 15:37:17 +0100 Subject: [PATCH 2/2] Add explicit __eq__ operators for classes that have ndarray attributes and could be used as Static jax_dataclasses attributes --- src/jaxsim/parsers/descriptions/collision.py | 14 ++++++++++++++ src/jaxsim/parsers/descriptions/link.py | 11 +++++++++++ src/jaxsim/parsers/kinematic_graph.py | 5 +++++ 3 files changed, 30 insertions(+) diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 7b4f8d184..289ab70be 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -50,6 +50,14 @@ def change_link( enabled=self.enabled, ) + def __eq__(self, other): + retval = ( + self.parent_link == other.parent_link + and (self.position == other.position).all() + and self.enabled == other.enabled + ) + return retval + def __str__(self): return ( f"{self.__class__.__name__}(" @@ -93,6 +101,9 @@ class BoxCollision(CollisionShape): center: npt.NDArray + def __eq__(self, other): + return (self.center == other.center).all() and super().__eq__(other) + @dataclasses.dataclass class SphereCollision(CollisionShape): @@ -105,3 +116,6 @@ class SphereCollision(CollisionShape): """ center: npt.NDArray + + def __eq__(self, other): + return (self.center == other.center).all() and super().__eq__(other) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index f5be265b2..d2912d7f0 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -38,6 +38,17 @@ class LinkDescription(JaxsimDataclass): def __hash__(self) -> int: return hash(self.__repr__()) + def __eq__(self, other) -> bool: + return ( + self.name == other.name + and self.mass == other.mass + and (self.inertia == other.inertia).all() + and self.index == other.index + and self.parent == other.parent + and (self.pose == other.pose).all() + and self.children == other.children + ) + @property def name_and_index(self) -> str: """ diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 9a4f8a43e..7b3021f47 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -34,6 +34,11 @@ class RootPose(NamedTuple): root_position: npt.NDArray = np.zeros(3) root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0]) + def __eq__(self, other): + return (self.root_position == other.root_position).all() and ( + self.root_quaternion == other.root_quaternion + ).all() + @dataclasses.dataclass(frozen=True) class KinematicGraph: