Skip to content

Commit

Permalink
Merge pull request #105 from traversaro/fix103
Browse files Browse the repository at this point in the history
Avoid to use as Static attributes classes that do not have a __eq__ method that returns a scalar bool
  • Loading branch information
diegoferigo authored Mar 11, 2024
2 parents 839ecd9 + 90f861c commit 4fd2032
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 18 deletions.
8 changes: 6 additions & 2 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def system_velocity_dynamics(
# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
= jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)

if model.physics_model.gc.body.size > 0:
if len(model.physics_model.gc.body) > 0:
# Compute the position and linear velocities (mixed representation) of
# all collidable points belonging to the robot.
W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data)
Expand All @@ -140,7 +140,11 @@ def system_velocity_dynamics(
# we don't need any coordinate transformation.
W_f_Li_terrain = jax.vmap(
lambda nc: (
jnp.vstack(jnp.equal(model.physics_model.gc.body, nc).astype(int))
jnp.vstack(
jnp.equal(
np.array(model.physics_model.gc.body, dtype=int), nc
).astype(int)
)
* W_f_Ci
).sum(axis=0)
)(jnp.arange(model.number_of_links()))
Expand Down
14 changes: 14 additions & 0 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def change_link(
enabled=self.enabled,
)

def __eq__(self, other):
retval = (
self.parent_link == other.parent_link
and (self.position == other.position).all()
and self.enabled == other.enabled
)
return retval

def __str__(self):
return (
f"{self.__class__.__name__}("
Expand Down Expand Up @@ -93,6 +101,9 @@ class BoxCollision(CollisionShape):

center: npt.NDArray

def __eq__(self, other):
return (self.center == other.center).all() and super().__eq__(other)


@dataclasses.dataclass
class SphereCollision(CollisionShape):
Expand All @@ -105,3 +116,6 @@ class SphereCollision(CollisionShape):
"""

center: npt.NDArray

def __eq__(self, other):
return (self.center == other.center).all() and super().__eq__(other)
11 changes: 11 additions & 0 deletions src/jaxsim/parsers/descriptions/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ class LinkDescription(JaxsimDataclass):
def __hash__(self) -> int:
return hash(self.__repr__())

def __eq__(self, other) -> bool:
return (
self.name == other.name
and self.mass == other.mass
and (self.inertia == other.inertia).all()
and self.index == other.index
and self.parent == other.parent
and (self.pose == other.pose).all()
and self.children == other.children
)

@property
def name_and_index(self) -> str:
"""
Expand Down
5 changes: 5 additions & 0 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class RootPose(NamedTuple):
root_position: npt.NDArray = np.zeros(3)
root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])

def __eq__(self, other):
return (self.root_position == other.root_position).all() and (
self.root_quaternion == other.root_quaternion
).all()


@dataclasses.dataclass(frozen=True)
class KinematicGraph:
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/physics/algos/soft_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def build_from_physics_model(

return SoftContactsState.build(
tangential_deformation=tangential_deformation,
number_of_collidable_points=physics_model.gc.body.size,
number_of_collidable_points=len(physics_model.gc.body),
)

@staticmethod
Expand Down Expand Up @@ -95,7 +95,7 @@ def valid(
return check_valid_shape(
what="tangential_deformation",
shape=self.tangential_deformation.shape,
expected_shape=(3, physics_model.gc.body.size),
expected_shape=(3, len(physics_model.gc.body)),
valid=True,
)

Expand Down Expand Up @@ -237,7 +237,7 @@ def process_point_kinematics(

# Process all the collidable points in parallel
W_p_Ci, CW_v_WC = jax.vmap(process_point_kinematics)(
model.gc.point.T, model.gc.body
model.gc.point.T, np.array(model.gc.body, dtype=int)
)

return W_p_Ci.transpose(), CW_v_WC.transpose()
Expand Down
10 changes: 4 additions & 6 deletions src/jaxsim/physics/algos/terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,21 @@ def height(self, x: float, y: float) -> float:

@jax_dataclasses.pytree_dataclass
class PlaneTerrain(Terrain):
plane_normal: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.array([0, 0, 1.0])
)
plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0])

@staticmethod
def build(plane_normal: jtp.Vector) -> "PlaneTerrain":
def build(plane_normal: list) -> "PlaneTerrain":
"""
Create a PlaneTerrain instance with a specified plane normal vector.
Args:
plane_normal (jtp.Vector): The normal vector of the terrain plane.
plane_normal (list): The normal vector of the terrain plane.
Returns:
PlaneTerrain: A PlaneTerrain instance.
"""

return PlaneTerrain(plane_normal=jnp.array(plane_normal, dtype=float))
return PlaneTerrain(plane_normal=plane_normal)

def height(self, x: float, y: float) -> float:
"""
Expand Down
10 changes: 4 additions & 6 deletions src/jaxsim/physics/model/ground_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ class GroundContact:
"""

point: npt.NDArray = dataclasses.field(default_factory=lambda: jnp.array([]))
body: Static[npt.NDArray] = dataclasses.field(
default_factory=lambda: np.array([], dtype=int)
)
body: Static[list] = dataclasses.field(default_factory=lambda: [])

@staticmethod
def build_from(
Expand All @@ -42,9 +40,9 @@ def build_from(

# Build the GroundContact attributes
points = jnp.vstack([cp.position for cp in collidable_points]).T
link_index_of_points = np.array(
[links_dict[cp.parent_link.name].index for cp in collidable_points]
)
link_index_of_points = [
links_dict[cp.parent_link.name].index for cp in collidable_points
]

# Build the object
gc = GroundContact(point=points, body=link_index_of_points)
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/simulation/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def dx_dt(
ode_state.soft_contacts.tangential_deformation
)

if physics_model.gc.body.size > 0:
if len(physics_model.gc.body) > 0:
(
contact_forces_links,
tangential_deformation_dot,
Expand Down

0 comments on commit 4fd2032

Please sign in to comment.