From 08e10de10a72f21e59b0b7c147fde29a6282166a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Jun 2024 17:15:59 +0200 Subject: [PATCH] Prefer `jaxsim.math` operations to `jaxlie` --- src/jaxsim/api/com.py | 7 ++--- src/jaxsim/api/common.py | 18 ++++++------ src/jaxsim/api/frame.py | 13 ++++----- src/jaxsim/api/kin_dyn_parameters.py | 9 ++---- src/jaxsim/api/link.py | 12 ++++---- src/jaxsim/api/model.py | 39 +++++++++++++------------ src/jaxsim/math/joint_model.py | 3 +- src/jaxsim/parsers/descriptions/link.py | 6 ++-- src/jaxsim/parsers/rod/utils.py | 7 ++--- 9 files changed, 56 insertions(+), 58 deletions(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index c9ec47c36..42383b5a9 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -1,6 +1,5 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.math @@ -28,7 +27,7 @@ def com_position( W_H_L = js.model.forward_kinematics(model=model, data=data) W_H_B = data.base_transform() - B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix() + B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) def B_p̃_LCoM(i) -> jtp.Vector: m = js.link.mass(model=model, link_index=i) @@ -179,9 +178,9 @@ def locked_centroidal_spatial_inertia( case _: raise ValueError(data.velocity_representation) - B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G) + B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G - B_Xv_G = B_H_G.adjoint() + B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G) G_Xf_B = B_Xv_G.transpose() return G_Xf_B @ B_Mbb_B @ B_Xv_G diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 0c9fcf4f5..9519b7f42 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -8,10 +8,10 @@ import jax import jax.numpy as jnp import jax_dataclasses -import jaxlie from jax_dataclasses import Static import jaxsim.typing as jtp +from jaxsim.math import Adjoint from jaxsim.utils import JaxsimDataclass, Mutability try: @@ -122,11 +122,11 @@ def inertial_to_other_representation( case VelRepr.Body: if not is_force: - O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint() + O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) O_array = O_Xv_W @ W_array else: - O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T + O_Xf_W = Adjoint.from_transform(transform=W_H_O).T O_array = O_Xf_W @ W_array return O_array @@ -136,11 +136,11 @@ def inertial_to_other_representation( W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: - OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint() + OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) OW_array = OW_Xv_W @ W_array else: - OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose() + OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T OW_array = OW_Xf_W @ W_array return OW_array @@ -190,11 +190,11 @@ def other_representation_to_inertial( O_array = array if not is_force: - W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint() + W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O) W_array = W_Xv_O @ O_array else: - W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T + W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T W_array = W_Xf_O @ O_array return W_array @@ -205,11 +205,11 @@ def other_representation_to_inertial( W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: - W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint() + W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW) W_array = W_Xv_BW @ BW_array else: - W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T + W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T W_array = W_Xf_BW @ BW_array return W_array diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index cc7943127..40259155a 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -3,12 +3,11 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js -import jaxsim.math import jaxsim.typing as jtp from jaxsim import exceptions +from jaxsim.math import Adjoint, Transform from .common import VelRepr @@ -232,25 +231,25 @@ def jacobian( match output_vel_repr: case VelRepr.Inertial: W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint() + W_X_L = Adjoint.from_transform(transform=W_H_L) W_J_WL = W_X_L @ L_J_WL O_J_WL_I = W_J_WL case VelRepr.Body: W_H_L = js.link.transform(model=model, data=data, link_index=L) W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L - F_X_L = jaxlie.SE3.from_matrix(F_H_L).adjoint() + F_H_L = Transform.inverse(W_H_F) @ W_H_L + F_X_L = Adjoint.from_transform(transform=F_H_L) F_J_WL = F_X_L @ L_J_WL O_J_WL_I = F_J_WL case VelRepr.Mixed: W_H_L = js.link.transform(model=model, data=data, link_index=L) W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L + F_H_L = Transform.inverse(W_H_F) @ W_H_L FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3)) FW_H_L = FW_H_F @ F_H_L - FW_X_L = jaxlie.SE3.from_matrix(FW_H_L).adjoint() + FW_X_L = Adjoint.from_transform(transform=FW_H_L) FW_J_WL = FW_X_L @ L_J_WL O_J_WL_I = FW_J_WL diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 9fa819405..3afa53ef0 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -5,11 +5,10 @@ import jax.lax import jax.numpy as jnp import jax_dataclasses -import jaxlie from jax_dataclasses import Static import jaxsim.typing as jtp -from jaxsim.math import Inertia, JointModel, supported_joint_motion +from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion from jaxsim.parsers.descriptions import JointDescription, ModelDescription from jaxsim.utils import HashedNumpyArray, JaxsimDataclass @@ -432,11 +431,9 @@ def joint_transforms_and_motion_subspaces( # Compute the overall transforms from the parent to the child of each joint by # composing all the components of our joint model. i_X_λ = jax.vmap( - lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: jaxlie.SE3.from_matrix( - λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i + lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform( + transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True ) - .inverse() - .adjoint() )(λ_H_pre, pre_H_suc, suc_H_i) return i_X_λ, S diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index c88970b36..ce38085e7 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -4,12 +4,12 @@ import jax import jax.numpy as jnp import jax.scipy.linalg -import jaxlie import jaxsim.api as js import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import exceptions +from jaxsim.math import Adjoint from .common import VelRepr @@ -287,7 +287,7 @@ def jacobian( match data.velocity_representation: case VelRepr.Inertial: W_H_B = data.base_transform() - B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( B_X_W, jnp.eye(model.dofs()) ) @@ -298,7 +298,7 @@ def jacobian( case VelRepr.Mixed: W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( B_X_BW, jnp.eye(model.dofs()) ) @@ -312,11 +312,11 @@ def jacobian( match output_vel_repr: case VelRepr.Inertial: W_H_B = data.base_transform() - W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint() + W_X_B = Adjoint.from_transform(transform=W_H_B) O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I case VelRepr.Body: - L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint() + L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) L_J_WL_I = L_X_B @ B_J_WL_I O_J_WL_I = L_J_WL_I @@ -325,7 +325,7 @@ def jacobian( W_H_L = W_H_B @ B_H_L LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) - LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint() + LW_X_B = Adjoint.from_transform(transform=LW_H_B) LW_J_WL_I = LW_X_B @ B_J_WL_I O_J_WL_I = LW_J_WL_I diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 0e398a966..0cdb6b916 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -9,14 +9,13 @@ import jax import jax.numpy as jnp import jax_dataclasses -import jaxlie import rod from jax_dataclasses import Static import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim.math import Cross +from jaxsim.math import Adjoint, Cross, Transform from jaxsim.parsers.descriptions import ModelDescription from jaxsim.utils import JaxsimDataclass, Mutability, wrappers @@ -494,7 +493,7 @@ def generalized_free_floating_jacobian( case VelRepr.Inertial: W_H_B = data.base_transform() - B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag( B_X_W, jnp.eye(model.dofs()) @@ -508,7 +507,7 @@ def generalized_free_floating_jacobian( W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_J_full_WX_I = B_J_full_WX_BW = ( B_J_full_WX_B @@ -716,7 +715,7 @@ def to_active( # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. - C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB) match data.velocity_representation: @@ -881,7 +880,9 @@ def free_floating_mass_matrix( case VelRepr.Inertial: - B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint() + B_X_W = Adjoint.from_transform( + transform=data.base_transform(), inverse=True + ) invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) return invT.T @ M_body @ invT @@ -889,7 +890,7 @@ def free_floating_mass_matrix( case VelRepr.Mixed: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) return invT.T @ M_body @ invT @@ -1078,8 +1079,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ - W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint() - C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() + W_X_C = Adjoint.from_transform(transform=W_H_C) + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) C_v_WC = C_X_W @ W_v_WC # In Mixed representation, we need to include a cross product in ℝ⁶. @@ -1362,12 +1363,14 @@ def total_momentum_jacobian( B_Jh = B_Jh_B case VelRepr.Inertial: - B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint() + B_X_W = Adjoint.from_transform( + transform=Transform.inverse(data.base_transform()) + ) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) case VelRepr.Mixed: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=Transform.inverse(BW_H_B)) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) case _: @@ -1379,14 +1382,14 @@ def total_momentum_jacobian( case VelRepr.Inertial: W_H_B = data.base_transform() - B_Xv_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint() + B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) W_Xf_B = B_Xv_W.T W_Jh = W_Xf_B @ B_Jh return W_Jh case VelRepr.Mixed: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_Xv_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) BW_Xf_B = B_Xv_BW.T BW_Jh = BW_Xf_B @ B_Jh return BW_Jh @@ -1450,7 +1453,7 @@ def average_velocity_jacobian( W_p_CoM = js.com.com_position(model=model, data=data) W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - W_X_GW = jaxlie.SE3.from_matrix(W_H_GW).adjoint() + W_X_GW = Adjoint.from_transform(transform=W_H_GW) return W_X_GW @ GW_J @@ -1462,7 +1465,7 @@ def average_velocity_jacobian( B_R_W = data.base_orientation(dcm=True).transpose() B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) - B_X_GB = jaxlie.SE3.from_matrix(B_H_GB).adjoint() + B_X_GB = Adjoint.from_transform(transform=B_H_GB) return B_X_GB @ GB_J @@ -1473,7 +1476,7 @@ def average_velocity_jacobian( W_p_CoM = js.com.com_position(model=model, data=data) BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) - BW_X_GW = jaxlie.SE3.from_matrix(BW_H_GW).adjoint() + BW_X_GW = Adjoint.from_transform(transform=BW_H_GW) return BW_X_GW @ GW_J @@ -1519,8 +1522,8 @@ def other_representation_to_inertial( expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ - W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint() - C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() + W_X_C = Adjoint.from_transform(transform=W_H_C) + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index e770cb42a..16f3914fe 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -11,6 +11,7 @@ from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms from .rotation import Rotation +from .transform import Transform @jax_dataclasses.pytree_dataclass @@ -162,7 +163,7 @@ def child_H_parent( joint_index=joint_index, joint_position=joint_position ) - i_Hi_λ = jaxlie.SE3.from_matrix(λ_Hi_i).inverse().as_matrix() + i_Hi_λ = Transform.inverse(λ_Hi_i) return i_Hi_λ, S diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 41f5399df..445ad7830 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -4,11 +4,11 @@ import jax.numpy as jnp import jax_dataclasses -import jaxlie import numpy as np from jax_dataclasses import Static import jaxsim.typing as jtp +from jaxsim.math import Adjoint, Transform from jaxsim.utils import JaxsimDataclass @@ -106,8 +106,8 @@ def lump_with( I_removed = link.inertia # Create the SE3 object. Note the inverse. - r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse() - r_X_l = r_H_l.adjoint() + r_H_l = Transform.inverse(lumped_H_removed) + r_X_l = Adjoint.from_transform(transform=r_H_l) # Move the inertia I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index d5ebbb15e..7ad1c26bb 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,12 +1,11 @@ import os -import jaxlie import numpy as np import numpy.typing as npt import rod import jaxsim.typing as jtp -from jaxsim.math import Inertia +from jaxsim.math import Adjoint, Inertia, Transform from jaxsim.parsers import descriptions @@ -50,8 +49,8 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4) # We need its inverse - CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse() - CoM_X_L = CoM_H_L.adjoint() + CoM_H_L = Transform.inverse(L_H_CoM) + CoM_X_L = Adjoint.from_transform(transform=CoM_H_L) # Express the CoM inertia matrix in the link frame L M_L = CoM_X_L.T @ M_CoM @ CoM_X_L