Skip to content

Commit

Permalink
Prefer jaxsim.math operations to jaxlie
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 27, 2024
1 parent 4d2fd87 commit 08e10de
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 58 deletions.
7 changes: 3 additions & 4 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.math
Expand Down Expand Up @@ -28,7 +27,7 @@ def com_position(

W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_B = data.base_transform()
B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)

def B_p̃_LCoM(i) -> jtp.Vector:
m = js.link.mass(model=model, link_index=i)
Expand Down Expand Up @@ -179,9 +178,9 @@ def locked_centroidal_spatial_inertia(
case _:
raise ValueError(data.velocity_representation)

B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G)
B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G

B_Xv_G = B_H_G.adjoint()
B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)
G_Xf_B = B_Xv_G.transpose()

return G_Xf_B @ B_Mbb_B @ B_Xv_G
Expand Down
18 changes: 9 additions & 9 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.utils import JaxsimDataclass, Mutability

try:
Expand Down Expand Up @@ -122,11 +122,11 @@ def inertial_to_other_representation(
case VelRepr.Body:

if not is_force:
O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint()
O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
O_array = O_Xv_W @ W_array

else:
O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T
O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
O_array = O_Xf_W @ W_array

return O_array
Expand All @@ -136,11 +136,11 @@ def inertial_to_other_representation(
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)

if not is_force:
OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint()
OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
OW_array = OW_Xv_W @ W_array

else:
OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose()
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
OW_array = OW_Xf_W @ W_array

return OW_array
Expand Down Expand Up @@ -190,11 +190,11 @@ def other_representation_to_inertial(
O_array = array

if not is_force:
W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint()
W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
W_array = W_Xv_O @ O_array

else:
W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T
W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
W_array = W_Xf_O @ O_array

return W_array
Expand All @@ -205,11 +205,11 @@ def other_representation_to_inertial(
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)

if not is_force:
W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint()
W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
W_array = W_Xv_BW @ BW_array

else:
W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T
W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
W_array = W_Xf_BW @ BW_array

return W_array
Expand Down
13 changes: 6 additions & 7 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.math import Adjoint, Transform

from .common import VelRepr

Expand Down Expand Up @@ -232,25 +231,25 @@ def jacobian(
match output_vel_repr:
case VelRepr.Inertial:
W_H_L = js.link.transform(model=model, data=data, link_index=L)
W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
W_X_L = Adjoint.from_transform(transform=W_H_L)
W_J_WL = W_X_L @ L_J_WL
O_J_WL_I = W_J_WL

case VelRepr.Body:
W_H_L = js.link.transform(model=model, data=data, link_index=L)
W_H_F = transform(model=model, data=data, frame_index=frame_index)
F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
F_X_L = jaxlie.SE3.from_matrix(F_H_L).adjoint()
F_H_L = Transform.inverse(W_H_F) @ W_H_L
F_X_L = Adjoint.from_transform(transform=F_H_L)
F_J_WL = F_X_L @ L_J_WL
O_J_WL_I = F_J_WL

case VelRepr.Mixed:
W_H_L = js.link.transform(model=model, data=data, link_index=L)
W_H_F = transform(model=model, data=data, frame_index=frame_index)
F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
F_H_L = Transform.inverse(W_H_F) @ W_H_L
FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3))
FW_H_L = FW_H_F @ F_H_L
FW_X_L = jaxlie.SE3.from_matrix(FW_H_L).adjoint()
FW_X_L = Adjoint.from_transform(transform=FW_H_L)
FW_J_WL = FW_X_L @ L_J_WL
O_J_WL_I = FW_J_WL

Expand Down
9 changes: 3 additions & 6 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import jax.lax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.math import Inertia, JointModel, supported_joint_motion
from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
from jaxsim.parsers.descriptions import JointDescription, ModelDescription
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass

Expand Down Expand Up @@ -432,11 +431,9 @@ def joint_transforms_and_motion_subspaces(
# Compute the overall transforms from the parent to the child of each joint by
# composing all the components of our joint model.
i_X_λ = jax.vmap(
lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: jaxlie.SE3.from_matrix(
λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i
lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform(
transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True
)
.inverse()
.adjoint()
)(λ_H_pre, pre_H_suc, suc_H_i)

return i_X_λ, S
Expand Down
12 changes: 6 additions & 6 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import jaxlie

import jaxsim.api as js
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.math import Adjoint

from .common import VelRepr

Expand Down Expand Up @@ -287,7 +287,7 @@ def jacobian(
match data.velocity_representation:
case VelRepr.Inertial:
W_H_B = data.base_transform()
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(
B_X_W, jnp.eye(model.dofs())
)
Expand All @@ -298,7 +298,7 @@ def jacobian(
case VelRepr.Mixed:
W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
B_X_BW, jnp.eye(model.dofs())
)
Expand All @@ -312,11 +312,11 @@ def jacobian(
match output_vel_repr:
case VelRepr.Inertial:
W_H_B = data.base_transform()
W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
W_X_B = Adjoint.from_transform(transform=W_H_B)
O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I

case VelRepr.Body:
L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint()
L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)
L_J_WL_I = L_X_B @ B_J_WL_I
O_J_WL_I = L_J_WL_I

Expand All @@ -325,7 +325,7 @@ def jacobian(
W_H_L = W_H_B @ B_H_L
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint()
LW_X_B = Adjoint.from_transform(transform=LW_H_B)
LW_J_WL_I = LW_X_B @ B_J_WL_I
O_J_WL_I = LW_J_WL_I

Expand Down
39 changes: 21 additions & 18 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
import rod
from jax_dataclasses import Static

import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.math import Cross
from jaxsim.math import Adjoint, Cross, Transform
from jaxsim.parsers.descriptions import ModelDescription
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers

Expand Down Expand Up @@ -494,7 +493,7 @@ def generalized_free_floating_jacobian(
case VelRepr.Inertial:

W_H_B = data.base_transform()
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)

B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
B_X_W, jnp.eye(model.dofs())
Expand All @@ -508,7 +507,7 @@ def generalized_free_floating_jacobian(

W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)

B_J_full_WX_I = B_J_full_WX_BW = (
B_J_full_WX_B
Expand Down Expand Up @@ -716,7 +715,7 @@ def to_active(

# In Mixed representation, we need to include a cross product in ℝ⁶.
# In Inertial and Body representations, the cross product is always zero.
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)

match data.velocity_representation:
Expand Down Expand Up @@ -881,15 +880,17 @@ def free_floating_mass_matrix(

case VelRepr.Inertial:

B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
B_X_W = Adjoint.from_transform(
transform=data.base_transform(), inverse=True
)
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))

return invT.T @ M_body @ invT

case VelRepr.Mixed:

BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))

return invT.T @ M_body @ invT
Expand Down Expand Up @@ -1078,8 +1079,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
W_X_C = Adjoint.from_transform(transform=W_H_C)
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
C_v_WC = C_X_W @ W_v_WC

# In Mixed representation, we need to include a cross product in ℝ⁶.
Expand Down Expand Up @@ -1362,12 +1363,14 @@ def total_momentum_jacobian(
B_Jh = B_Jh_B

case VelRepr.Inertial:
B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
B_X_W = Adjoint.from_transform(
transform=Transform.inverse(data.base_transform())
)
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))

case VelRepr.Mixed:
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
B_X_BW = Adjoint.from_transform(transform=Transform.inverse(BW_H_B))
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))

case _:
Expand All @@ -1379,14 +1382,14 @@ def total_momentum_jacobian(

case VelRepr.Inertial:
W_H_B = data.base_transform()
B_Xv_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
W_Xf_B = B_Xv_W.T
W_Jh = W_Xf_B @ B_Jh
return W_Jh

case VelRepr.Mixed:
BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
B_Xv_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
BW_Xf_B = B_Xv_BW.T
BW_Jh = BW_Xf_B @ B_Jh
return BW_Jh
Expand Down Expand Up @@ -1450,7 +1453,7 @@ def average_velocity_jacobian(
W_p_CoM = js.com.com_position(model=model, data=data)

W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
W_X_GW = jaxlie.SE3.from_matrix(W_H_GW).adjoint()
W_X_GW = Adjoint.from_transform(transform=W_H_GW)

return W_X_GW @ GW_J

Expand All @@ -1462,7 +1465,7 @@ def average_velocity_jacobian(
B_R_W = data.base_orientation(dcm=True).transpose()

B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
B_X_GB = jaxlie.SE3.from_matrix(B_H_GB).adjoint()
B_X_GB = Adjoint.from_transform(transform=B_H_GB)

return B_X_GB @ GB_J

Expand All @@ -1473,7 +1476,7 @@ def average_velocity_jacobian(
W_p_CoM = js.com.com_position(model=model, data=data)

BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
BW_X_GW = jaxlie.SE3.from_matrix(BW_H_GW).adjoint()
BW_X_GW = Adjoint.from_transform(transform=BW_H_GW)

return BW_X_GW @ GW_J

Expand Down Expand Up @@ -1519,8 +1522,8 @@ def other_representation_to_inertial(
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
W_X_C = Adjoint.from_transform(transform=W_H_C)
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)

# In Mixed representation, we need to include a cross product in ℝ⁶.
# In Inertial and Body representations, the cross product is always zero.
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms

from .rotation import Rotation
from .transform import Transform


@jax_dataclasses.pytree_dataclass
Expand Down Expand Up @@ -162,7 +163,7 @@ def child_H_parent(
joint_index=joint_index, joint_position=joint_position
)

i_Hi_λ = jaxlie.SE3.from_matrix(λ_Hi_i).inverse().as_matrix()
i_Hi_λ = Transform.inverse(λ_Hi_i)

return i_Hi_λ, S

Expand Down
Loading

0 comments on commit 08e10de

Please sign in to comment.