diff --git a/examples/PD_controller.ipynb b/examples/PD_controller.ipynb index 5da316c89..38dabee4d 100644 --- a/examples/PD_controller.ipynb +++ b/examples/PD_controller.ipynb @@ -20,7 +20,7 @@ "outputs": [], "source": [ "# @title Imports and setup\n", - "from IPython.display import clear_output, HTML, display\n", + "from IPython.display import clear_output\n", "import sys\n", "\n", "IS_COLAB = \"google.colab\" in sys.modules\n", @@ -100,6 +100,8 @@ "from jaxsim import integrators\n", "\n", "dt = 0.01\n", + "integration_time = 5.0\n", + "num_steps = int(integration_time / dt)\n", "\n", "model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_urdf_string, is_urdf=True\n", @@ -150,7 +152,7 @@ "source": [ "# @title Set up MuJoCo renderer\n", "\n", - "from jaxsim.mujoco import RodModelToMjcf, MujocoModelHelper, MujocoVideoRecorder\n", + "from jaxsim.mujoco import MujocoModelHelper, MujocoVideoRecorder\n", "from jaxsim.mujoco.loaders import UrdfToMjcf\n", "\n", "import os\n", @@ -227,7 +229,7 @@ "source": [ "import mediapy as media\n", "\n", - "for _ in range(500):\n", + "for _ in range(num_steps):\n", " data, integrator_state = js.model.step(\n", " dt=dt,\n", " model=model,\n", @@ -299,7 +301,7 @@ "metadata": {}, "outputs": [], "source": [ - "for _ in range(500):\n", + "for _ in range(num_steps):\n", " control_torques = pd_controller(\n", " data=data,\n", " q_d=jnp.array([0.0, 0.0]),\n", diff --git a/examples/Parallel_computing.ipynb b/examples/Parallel_computing.ipynb index 4bb8fa9d9..390c7b408 100644 --- a/examples/Parallel_computing.ipynb +++ b/examples/Parallel_computing.ipynb @@ -23,8 +23,6 @@ "# @title Imports and setup\n", "import sys\n", "\n", - "from IPython.display import HTML, clear_output, display\n", - "\n", "IS_COLAB = \"google.colab\" in sys.modules\n", "\n", "# Install JAX and Gazebo\n", @@ -40,15 +38,12 @@ "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", "\n", "import time\n", - "from typing import Dict, Tuple\n", "\n", "import jax\n", "import jax.numpy as jnp\n", - "import jax_dataclasses\n", "import rod\n", "from rod.builder.primitives import SphereBuilder\n", "\n", - "import jaxsim.typing as jtp\n", "from jaxsim import logging\n", "\n", "logging.set_logging_level(logging.LoggingLevel.INFO)\n", @@ -105,7 +100,7 @@ "from jaxsim import integrators\n", "\n", "dt = 0.001\n", - "integration_time = 1500\n", + "integration_time = 1.5 # seconds\n", "\n", "model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_sdf_string\n", @@ -129,7 +124,9 @@ "\n", "By default, in JaxSim a sphere primitive has 250 collision points. This can be modified by setting the `JAXSIM_COLLISION_SPHERE_POINTS` environment variable.\n", "\n", - "Given that at its steady-state the sphere will act on two or three points, we can estimate the ground parameters by explicitly setting the number of active points to these values." + "Given that at its steady-state the sphere will act on two or three points, we can estimate the ground parameters by explicitly setting the number of active points to these values. \n", + "\n", + "Eventually, you can specify the maximum penetration depth of the sphere into the terrain by passing `max_penetraion` to the `estimate_good_soft_contacts_parameters` function." ] }, { @@ -139,8 +136,10 @@ "outputs": [], "source": [ "data = data.replace(\n", - " soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n", - " model, number_of_active_collidable_points_steady_state=3\n", + " contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n", + " model=model,\n", + " number_of_active_collidable_points_steady_state=3,\n", + " max_penetration=None,\n", " )\n", ")" ] @@ -149,7 +148,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's create a position vector for a 3x3 grid. Every sphere will be placed at a different height." + "Let's create a position vector for a 4x4 grid. Every sphere will be placed at a different height." ] }, { @@ -169,14 +168,13 @@ "def grid(edge_len, envs_per_row):\n", " edge = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", " xx, yy = jnp.meshgrid(edge, edge)\n", - "\n", - " poses = [\n", - " [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]\n", - " for i in range(xx.shape[0])\n", - " for j in range(yy.shape[0])\n", - " ]\n", - "\n", - " return jnp.array(poses)\n", + " zz = 0.2 + 0.1 * (\n", + " jnp.arange(envs_per_row**2) % envs_per_row\n", + " + jnp.arange(envs_per_row**2) // envs_per_row\n", + " )\n", + " zz = zz.reshape(envs_per_row, envs_per_row)\n", + " poses = jnp.stack([xx, yy, zz], axis=-1).reshape(envs_per_row**2, 3)\n", + " return poses\n", "\n", "\n", "logging.info(f\"Simulating {envs_per_row**2} environments\")\n", @@ -200,11 +198,13 @@ "def simulate(\n", " data: js.data.JaxSimModelData, integrator_state: dict, pose: jnp.array\n", ") -> tuple:\n", - "\n", + " # Set the base position to the initial pose\n", " data = data.reset_base_position(base_position=pose)\n", + "\n", + " # Create a list to store the base position over time\n", " x_t_i = []\n", "\n", - " for _ in range(integration_time):\n", + " for _ in range(int(integration_time // dt)):\n", " data, integrator_state = js.model.step(\n", " dt=dt,\n", " model=model,\n", @@ -243,7 +243,7 @@ "# Run and time the simulation\n", "now = time.perf_counter()\n", "\n", - "x_t = simulate_vectorized(data, integrator_state, poses[:, 0])\n", + "x_t = simulate_vectorized(data, integrator_state, poses)\n", "\n", "comp_time = time.perf_counter() - now\n", "\n", @@ -251,7 +251,7 @@ " f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n", ")\n", "logging.info(\n", - " f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 *integration_time/comp_time):.2f}\"\n", + " f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 * integration_time / comp_time):.2f}\"\n", ")" ] }, diff --git a/pyproject.toml b/pyproject.toml index 7782fc678..266a1e053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,25 +41,26 @@ preview = true [tool.ruff.lint] # https://docs.astral.sh/ruff/rules/ +select = [ + "B", + "E", + "F", + "I", + "W", + "RUF", + "YTT", +] + ignore = [ "B008", # Function call in default argument "B024", # Abstract base class without abstract methods - "B904", # Raise without from inside exception - "B905", # Zip without explicit strict "E402", # Module level import not at top of file "E501", # Line too long "E731", # Do not assign a `lambda` expression, use a `def` "E741", # Ambiguous variable name "F841", # Local variable is assigned to but never used "I001", # Import block is unsorted or unformatted -] -select = [ - "B", - "E", - "F", - "I", - "W", - "YTT", + "RUF003", # Ambigous unicode character in comment ] [tool.ruff.lint.per-file-ignores] diff --git a/src/jaxsim/__init__.py b/src/jaxsim/__init__.py index 4f15e0ea5..f22281b61 100644 --- a/src/jaxsim/__init__.py +++ b/src/jaxsim/__init__.py @@ -33,17 +33,17 @@ def _is_editable() -> bool: import pathlib import site - # Get the ModuleSpec of jaxsim + # Get the ModuleSpec of jaxsim. jaxsim_spec = importlib.util.find_spec(name="jaxsim") # This can be None. If it's None, assume non-editable installation. if jaxsim_spec.origin is None: return False - # Get the folder containing the jaxsim package + # Get the folder containing the jaxsim package. jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent) - # The installation is editable if the package dir is not in any {site|dist}-packages + # The installation is editable if the package dir is not in any {site|dist}-packages. return jaxsim_package_dir not in site.getsitepackages() @@ -82,10 +82,10 @@ def _get_default_logging_level(env_var: str) -> logging.LoggingLevel: logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL")) -# Configure JAX +# Configure JAX. _jnp_options() -# Initialize the numpy print options +# Initialize the numpy print options. _np_options() del _jnp_options diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index c9ec47c36..42383b5a9 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -1,6 +1,5 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.math @@ -28,7 +27,7 @@ def com_position( W_H_L = js.model.forward_kinematics(model=model, data=data) W_H_B = data.base_transform() - B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix() + B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) def B_p̃_LCoM(i) -> jtp.Vector: m = js.link.mass(model=model, link_index=i) @@ -179,9 +178,9 @@ def locked_centroidal_spatial_inertia( case _: raise ValueError(data.velocity_representation) - B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G) + B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G - B_Xv_G = B_H_G.adjoint() + B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G) G_Xf_B = B_Xv_G.transpose() return G_Xf_B @ B_Mbb_B @ B_Xv_G diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 0c9fcf4f5..d8ba8dd7b 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -8,10 +8,10 @@ import jax import jax.numpy as jnp import jax_dataclasses -import jaxlie from jax_dataclasses import Static import jaxsim.typing as jtp +from jaxsim.math import Adjoint from jaxsim.utils import JaxsimDataclass, Mutability try: @@ -59,7 +59,7 @@ def switch_velocity_representation( try: - # First, we replace the velocity representation + # First, we replace the velocity representation. with self.mutable_context( mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=True, @@ -97,7 +97,7 @@ def inertial_to_other_representation( array: The 6D quantity to convert. other_representation: The representation to convert to. transform: - The `math:W \mathbf{H}_O` transform, where `math:O` is the + The :math:`W \mathbf{H}_O` transform, where :math:`O` is the reference frame of the other representation. is_force: Whether the quantity is a 6D force or a 6D velocity. @@ -122,11 +122,11 @@ def inertial_to_other_representation( case VelRepr.Body: if not is_force: - O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint() + O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) O_array = O_Xv_W @ W_array else: - O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T + O_Xf_W = Adjoint.from_transform(transform=W_H_O).T O_array = O_Xf_W @ W_array return O_array @@ -136,11 +136,11 @@ def inertial_to_other_representation( W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: - OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint() + OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) OW_array = OW_Xv_W @ W_array else: - OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose() + OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T OW_array = OW_Xf_W @ W_array return OW_array @@ -190,11 +190,11 @@ def other_representation_to_inertial( O_array = array if not is_force: - W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint() + W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O) W_array = W_Xv_O @ O_array else: - W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T + W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T W_array = W_Xf_O @ O_array return W_array @@ -205,11 +205,11 @@ def other_representation_to_inertial( W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: - W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint() + W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW) W_array = W_Xv_BW @ BW_array else: - W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T + W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T W_array = W_Xf_BW @ BW_array return W_array diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 3dce69d56..916c8ce51 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -8,7 +8,7 @@ import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsParams +from jaxsim.rbda.contacts.soft import SoftContactsParams from .common import VelRepr @@ -137,9 +137,17 @@ def collidable_point_dynamics( # all collidable points belonging to the robot. W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data) + # Import privately the soft contacts classes. + from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState + # Build the soft contact model. match model.contact_model: - case s if isinstance(s, SoftContacts): + + case SoftContacts(): + + assert isinstance(model.contact_model, SoftContacts) + assert isinstance(data.state.contact, SoftContactsState) + # Build the contact model. soft_contacts = SoftContacts( parameters=data.contacts_params, terrain=model.terrain @@ -337,7 +345,7 @@ def jacobian( The output velocity representation of the free-floating jacobian. Returns: - The stacked 6×(6+n) free-floating jacobians of the frames associated to the + The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the collidable points. Note: diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index c1ba29d23..1ef55158b 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -8,7 +8,6 @@ import jax.numpy as jnp import jax_dataclasses import jaxlie -import numpy as np import jaxsim.api as js import jaxsim.rbda @@ -390,7 +389,7 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix ).astype(float) @jax.jit - def base_transform(self) -> jtp.MatrixJax: + def base_transform(self) -> jtp.Matrix: """ Get the base transform. @@ -625,9 +624,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: W_p_B = base_pose[0:3, 3] - to_wxyz = np.array([3, 0, 1, 2]) - W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3]) # noqa - W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz] + W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3]) return self.reset_base_position(base_position=W_p_B).reset_base_quaternion( base_quaternion=W_Q_B @@ -815,7 +812,7 @@ def random_model_data( physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians( *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi) - ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])] + ).wxyz if model.number_of_joints() > 0: physics_model_state.joint_positions = js.joint.random_joint_positions( diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 0eb0af43c..097ab78c2 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -3,12 +3,11 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js -import jaxsim.math import jaxsim.typing as jtp from jaxsim import exceptions +from jaxsim.math import Adjoint, Transform from .common import VelRepr @@ -189,7 +188,7 @@ def jacobian( frame_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: - """ + r""" Compute the free-floating jacobian of the frame. Args: @@ -200,7 +199,7 @@ def jacobian( The output velocity representation of the free-floating jacobian. Returns: - The 6×(6+n) free-floating jacobian of the frame. + The :math:`6 \times (6+n)` free-floating jacobian of the frame. Note: The input representation of the free-floating jacobian is the active @@ -228,29 +227,29 @@ def jacobian( model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body ) - # Adjust the output representation + # Adjust the output representation. match output_vel_repr: case VelRepr.Inertial: W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint() + W_X_L = Adjoint.from_transform(transform=W_H_L) W_J_WL = W_X_L @ L_J_WL O_J_WL_I = W_J_WL case VelRepr.Body: W_H_L = js.link.transform(model=model, data=data, link_index=L) W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L - F_X_L = jaxlie.SE3.from_matrix(F_H_L).adjoint() + F_H_L = Transform.inverse(W_H_F) @ W_H_L + F_X_L = Adjoint.from_transform(transform=F_H_L) F_J_WL = F_X_L @ L_J_WL O_J_WL_I = F_J_WL case VelRepr.Mixed: W_H_L = js.link.transform(model=model, data=data, link_index=L) W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L + F_H_L = Transform.inverse(W_H_F) @ W_H_L FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3)) FW_H_L = FW_H_F @ F_H_L - FW_X_L = jaxlie.SE3.from_matrix(FW_H_L).adjoint() + FW_X_L = Adjoint.from_transform(transform=FW_H_L) FW_J_WL = FW_X_L @ L_J_WL O_J_WL_I = FW_J_WL diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index f9ffbc0a7..3afa53ef0 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -5,11 +5,10 @@ import jax.lax import jax.numpy as jnp import jax_dataclasses -import jaxlie from jax_dataclasses import Static import jaxsim.typing as jtp -from jaxsim.math import Inertia, JointModel, supported_joint_motion +from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion from jaxsim.parsers.descriptions import JointDescription, ModelDescription from jaxsim.utils import HashedNumpyArray, JaxsimDataclass @@ -168,7 +167,7 @@ def build(model_description: ModelDescription) -> KynDynParameters: for link in ordered_links if link.parent is not None } - parent_array = jnp.array([-1] + list(parent_array_dict.values()), dtype=int) + parent_array = jnp.array([-1, *list(parent_array_dict.values())], dtype=int) # Instead of building the support parent array κ(i) for each link of the model, # that has a variable length depending on the number of links connecting the @@ -432,11 +431,9 @@ def joint_transforms_and_motion_subspaces( # Compute the overall transforms from the parent to the child of each joint by # composing all the components of our joint model. i_X_λ = jax.vmap( - lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: jaxlie.SE3.from_matrix( - λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i + lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform( + transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True ) - .inverse() - .adjoint() )(λ_H_pre, pre_H_suc, suc_H_i) return i_X_λ, S @@ -466,12 +463,12 @@ def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameter def set_link_inertia( self, link_index: int, inertia: jtp.MatrixLike ) -> KynDynParameters: - """ + r""" Set the inertia tensor of a link. Args: link_index: The index of the link. - inertia: The 3×3 inertia tensor of the link. + inertia: The :math:`3 \times 3` inertia tensor of the link. Returns: The updated kinematic and dynamic parameters of the model. @@ -569,7 +566,7 @@ class LinkParameters(JaxsimDataclass): index: The index of the link. mass: The mass of the link. inertia_elements: - The unique elements of the 3×3 inertia tensor of the link. + The unique elements of the :math:`3 \times 3` inertia tensor of the link. center_of_mass: The translation :math:`{}^L \mathbf{p}_{\text{CoM}}` between the origin of the link frame and the link's center of mass, expressed in the @@ -588,12 +585,12 @@ class LinkParameters(JaxsimDataclass): @staticmethod def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters: - """ - Build a LinkParameters object from a 6×6 spatial inertia matrix. + r""" + Build a LinkParameters object from a :math:`6 \times 6` spatial inertia matrix. Args: index: The index of the link. - M: The 6×6 spatial inertia matrix of the link. + M: The :math:`6 \times 6` spatial inertia matrix of the link. Returns: The LinkParameters object. @@ -616,13 +613,13 @@ def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParamet def build_from_inertial_parameters( index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike ) -> LinkParameters: - """ + r""" Build a LinkParameters object from the inertial parameters of a link. Args: index: The index of the link. m: The mass of the link. - I: The 3×3 inertia tensor of the link. + I: The :math:`3 \times 3` inertia tensor of the link. c: The translation between the link frame and the link's center of mass. Returns: @@ -676,14 +673,14 @@ def flat_parameters(params: LinkParameters) -> jtp.Vector: @staticmethod def inertia_tensor(params: LinkParameters) -> jtp.Matrix: - """ - Return the 3×3 inertia tensor of a link. + r""" + Return the :math:`3 \times 3` inertia tensor of a link. Args: params: The link parameters. Returns: - The 3×3 inertia tensor of the link. + The :math:`3 \times 3` inertia tensor of the link. """ return LinkParameters.unflatten_inertia_tensor( @@ -692,14 +689,14 @@ def inertia_tensor(params: LinkParameters) -> jtp.Matrix: @staticmethod def spatial_inertia(params: LinkParameters) -> jtp.Matrix: - """ - Return the 6×6 spatial inertia matrix of a link. + r""" + Return the :math:`6 \times 6` spatial inertia matrix of a link. Args: params: The link parameters. Returns: - The 6×6 spatial inertia matrix of the link. + The :math:`6 \times 6` spatial inertia matrix of the link. """ return Inertia.to_sixd( @@ -710,11 +707,11 @@ def spatial_inertia(params: LinkParameters) -> jtp.Matrix: @staticmethod def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector: - """ - Flatten a 3×3 inertia tensor into a vector of unique elements. + r""" + Flatten a :math:`3 \times 3` inertia tensor into a vector of unique elements. Args: - I: The 3×3 inertia tensor. + I: The :math:`3 \times 3` inertia tensor. Returns: The vector of unique elements of the inertia tensor. @@ -724,14 +721,14 @@ def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector: @staticmethod def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix: - """ - Unflatten a vector of unique elements into a 3×3 inertia tensor. + r""" + Unflatten a vector of unique elements into a :math:`3 \times 3` inertia tensor. Args: inertia_elements: The vector of unique elements of the inertia tensor. Returns: - The 3×3 inertia tensor. + The :math:`3 \times 3` inertia tensor. """ I = jnp.zeros([3, 3]).at[jnp.triu_indices(3)].set(inertia_elements.squeeze()) @@ -792,7 +789,7 @@ def build_from(model_description: ModelDescription) -> ContactParameters: ) # Build the ContactParameters object. - cp = ContactParameters(point=points, body=link_index_of_points) # noqa + cp = ContactParameters(point=points, body=link_index_of_points) assert cp.point.shape[1] == 3, cp.point.shape[1] assert cp.point.shape[0] == len(cp.body), cp.point.shape[0] diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index c77a5acdb..ce38085e7 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -4,12 +4,12 @@ import jax import jax.numpy as jnp import jax.scipy.linalg -import jaxlie import jaxsim.api as js import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import exceptions +from jaxsim.math import Adjoint from .common import VelRepr @@ -134,7 +134,7 @@ def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float: def spatial_inertia( model: js.model.JaxSimModel, *, link_index: jtp.IntLike ) -> jtp.Matrix: - """ + r""" Compute the 6D spatial inertial of the link. Args: @@ -142,7 +142,7 @@ def spatial_inertia( link_index: The index of the link. Returns: - The 6×6 matrix representing the spatial inertia of the link expressed in + The :math:`6 \times 6` matrix representing the spatial inertia of the link expressed in the link frame (body-fixed representation). """ @@ -243,7 +243,7 @@ def jacobian( link_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: - """ + r""" Compute the free-floating jacobian of the link. Args: @@ -254,7 +254,7 @@ def jacobian( The output velocity representation of the free-floating jacobian. Returns: - The 6×(6+n) free-floating jacobian of the link. + The :math:`6 \times (6+n)` free-floating jacobian of the link. Note: The input representation of the free-floating jacobian is the active @@ -287,7 +287,7 @@ def jacobian( match data.velocity_representation: case VelRepr.Inertial: W_H_B = data.base_transform() - B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( B_X_W, jnp.eye(model.dofs()) ) @@ -298,7 +298,7 @@ def jacobian( case VelRepr.Mixed: W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( B_X_BW, jnp.eye(model.dofs()) ) @@ -312,11 +312,11 @@ def jacobian( match output_vel_repr: case VelRepr.Inertial: W_H_B = data.base_transform() - W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint() + W_X_B = Adjoint.from_transform(transform=W_H_B) O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I case VelRepr.Body: - L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint() + L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) L_J_WL_I = L_X_B @ B_J_WL_I O_J_WL_I = L_J_WL_I @@ -325,7 +325,7 @@ def jacobian( W_H_L = W_H_B @ B_H_L LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) - LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint() + LW_X_B = Adjoint.from_transform(transform=LW_H_B) LW_J_WL_I = LW_X_B @ B_J_WL_I O_J_WL_I = LW_J_WL_I @@ -393,7 +393,7 @@ def jacobian_derivative( link_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: - """ + r""" Compute the derivative of the free-floating jacobian of the link. Args: @@ -404,7 +404,7 @@ def jacobian_derivative( The output velocity representation of the free-floating jacobian derivative. Returns: - The derivative of the 6×(6+n) free-floating jacobian of the link. + The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the link. Note: The input representation of the free-floating jacobian derivative is the active diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 5bb45fa9d..072b99859 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -9,14 +9,14 @@ import jax import jax.numpy as jnp import jax_dataclasses -import jaxlie import rod from jax_dataclasses import Static import jaxsim.api as js -import jaxsim.parsers.descriptions +import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim.math import Cross +from jaxsim.math import Adjoint, Cross +from jaxsim.parsers.descriptions import ModelDescription from jaxsim.utils import JaxsimDataclass, Mutability, wrappers from .common import VelRepr @@ -46,12 +46,12 @@ class JaxSimModel(JaxsimDataclass): default=None, repr=False ) - _description: Static[ - wrappers.HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None] - ] = dataclasses.field(default=None, repr=False) + _description: Static[wrappers.HashlessObject[ModelDescription | None]] = ( + dataclasses.field(default=None, repr=False) + ) @property - def description(self) -> jaxsim.parsers.descriptions.ModelDescription: + def description(self) -> ModelDescription: return self._description.get() def __eq__(self, other: JaxSimModel) -> bool: @@ -116,7 +116,7 @@ def build_from_model_description( import jaxsim.parsers.rod # Parse the input resource (either a path to file or a string with the URDF/SDF) - # and build the -intermediate- model description + # and build the -intermediate- model description. intermediate_description = jaxsim.parsers.rod.build_model_description( model_description=model_description, is_urdf=is_urdf ) @@ -128,7 +128,7 @@ def build_from_model_description( considered_joints=considered_joints ) - # Build the model + # Build the model. model = JaxSimModel.build( model_description=intermediate_description, model_name=model_name, @@ -136,7 +136,7 @@ def build_from_model_description( contact_model=contact_model, ) - # Store the origin of the model, in case downstream logic needs it + # Store the origin of the model, in case downstream logic needs it. with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): model.built_from = model_description @@ -144,7 +144,7 @@ def build_from_model_description( @staticmethod def build( - model_description: jaxsim.parsers.descriptions.ModelDescription, + model_description: ModelDescription, model_name: str | None = None, *, terrain: jaxsim.terrain.Terrain | None = None, @@ -169,14 +169,14 @@ def build( """ from jaxsim.rbda.contacts.soft import SoftContacts - # Set the model name (if not provided, use the one from the model description) + # Set the model name (if not provided, use the one from the model description). model_name = model_name if model_name is not None else model_description.name - # Set the terrain (if not provided, use the default flat terrain) + # Set the terrain (if not provided, use the default flat terrain). terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default contact_model = contact_model or SoftContacts(terrain=terrain) - # Build the model + # Build the model. model = JaxSimModel( model_name=model_name, _description=wrappers.HashlessObject(obj=model_description), @@ -361,7 +361,7 @@ def reduce( considered_joints=list(considered_joints) ) - # Build the reduced model + # Build the reduced model. reduced_model = JaxSimModel.build( model_description=reduced_intermediate_description, model_name=model.name(), @@ -369,7 +369,7 @@ def reduce( contact_model=model.contact_model, ) - # Store the origin of the model, in case downstream logic needs it + # Store the origin of the model, in case downstream logic needs it. with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): reduced_model.built_from = model.built_from @@ -493,7 +493,7 @@ def generalized_free_floating_jacobian( case VelRepr.Inertial: W_H_B = data.base_transform() - B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag( B_X_W, jnp.eye(model.dofs()) @@ -507,7 +507,7 @@ def generalized_free_floating_jacobian( W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_J_full_WX_I = B_J_full_WX_BW = ( B_J_full_WX_B @@ -715,7 +715,7 @@ def to_active( # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. - C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB) match data.velocity_representation: @@ -797,21 +797,21 @@ def forward_dynamics_crb( # Prepare data # ============ - # Build joint torques if not provided + # Build joint torques if not provided. τ = ( jnp.atleast_1d(joint_forces) if joint_forces is not None else jnp.zeros_like(data.joint_positions()) ) - # Build external forces if not provided + # Build external forces if not provided. f = ( jnp.atleast_2d(link_forces) if link_forces is not None else jnp.zeros(shape=(model.number_of_links(), 6)) ) - # Compute terms of the floating-base EoM + # Compute terms of the floating-base EoM. M = free_floating_mass_matrix(model=model, data=data) h = free_floating_bias_forces(model=model, data=data) S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T @@ -848,7 +848,7 @@ def forward_dynamics_crb( # 6D transformation X. v̇_WB = ν̇[0:6].squeeze().astype(float) - # Extract the joint accelerations + # Extract the joint accelerations. s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float) return v̇_WB, s̈ @@ -880,7 +880,9 @@ def free_floating_mass_matrix( case VelRepr.Inertial: - B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint() + B_X_W = Adjoint.from_transform( + transform=data.base_transform(), inverse=True + ) invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) return invT.T @ M_body @ invT @@ -888,7 +890,7 @@ def free_floating_mass_matrix( case VelRepr.Mixed: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) return invT.T @ M_body @ invT @@ -1077,8 +1079,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ - W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint() - C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() + W_X_C = Adjoint.from_transform(transform=W_H_C) + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) C_v_WC = C_X_W @ W_v_WC # In Mixed representation, we need to include a cross product in ℝ⁶. @@ -1187,12 +1189,12 @@ def free_floating_gravity_forces( The free-floating gravity forces of the model. """ - # Build a zeroed state + # Build a zeroed state. data_rnea = js.data.JaxSimModelData.zero( model=model, velocity_representation=data.velocity_representation ) - # Set just the generalized position + # Set just the generalized position. with data_rnea.mutable_context( mutability=Mutability.MUTABLE, restore_after_exception=False ): @@ -1237,12 +1239,12 @@ def free_floating_bias_forces( The free-floating bias forces of the model. """ - # Build a zeroed state + # Build a zeroed state. data_rnea = js.data.JaxSimModelData.zero( model=model, velocity_representation=data.velocity_representation ) - # Set the generalized position and generalized velocity + # Set the generalized position and generalized velocity. with data_rnea.mutable_context( mutability=Mutability.MUTABLE, restore_after_exception=False ): @@ -1361,12 +1363,14 @@ def total_momentum_jacobian( B_Jh = B_Jh_B case VelRepr.Inertial: - B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint() + B_X_W = Adjoint.from_transform( + transform=data.base_transform(), inverse=True + ) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) case VelRepr.Mixed: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) case _: @@ -1378,14 +1382,14 @@ def total_momentum_jacobian( case VelRepr.Inertial: W_H_B = data.base_transform() - B_Xv_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint() + B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) W_Xf_B = B_Xv_W.T W_Jh = W_Xf_B @ B_Jh return W_Jh case VelRepr.Mixed: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_Xv_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint() + B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) BW_Xf_B = B_Xv_BW.T BW_Jh = BW_Xf_B @ B_Jh return BW_Jh @@ -1449,7 +1453,7 @@ def average_velocity_jacobian( W_p_CoM = js.com.com_position(model=model, data=data) W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - W_X_GW = jaxlie.SE3.from_matrix(W_H_GW).adjoint() + W_X_GW = Adjoint.from_transform(transform=W_H_GW) return W_X_GW @ GW_J @@ -1461,7 +1465,7 @@ def average_velocity_jacobian( B_R_W = data.base_orientation(dcm=True).transpose() B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) - B_X_GB = jaxlie.SE3.from_matrix(B_H_GB).adjoint() + B_X_GB = Adjoint.from_transform(transform=B_H_GB) return B_X_GB @ GB_J @@ -1472,7 +1476,7 @@ def average_velocity_jacobian( W_p_CoM = js.com.com_position(model=model, data=data) BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) - BW_X_GW = jaxlie.SE3.from_matrix(BW_H_GW).adjoint() + BW_X_GW = Adjoint.from_transform(transform=BW_H_GW) return BW_X_GW @ GW_J @@ -1518,8 +1522,8 @@ def other_representation_to_inertial( expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ - W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint() - C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint() + W_X_C = Adjoint.from_transform(transform=W_H_C) + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. @@ -1606,7 +1610,7 @@ def other_representation_to_inertial( # not remove gravity during the propagation. # Initialize the loop. - Carry = tuple[jtp.MatrixJax, jtp.MatrixJax] + Carry = tuple[jtp.Matrix, jtp.Matrix] carry0: Carry = (L_v_WL, L_v̇_WL) def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]: @@ -1832,8 +1836,8 @@ def step( integrator_state: The state of the integrator. joint_forces: The joint forces to consider. link_forces: - The link 6D forces to consider. - The frame in which they are expressed must be `data.velocity_representation`. + The 6D forces to apply to the links expressed in the frame corresponding to + the velocity representation of `data`. kwargs: Additional kwargs to pass to the integrator. Returns: diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 664d6cbb3..bd6636395 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -96,7 +96,9 @@ def system_velocity_dynamics( model: The model to consider. data: The data of the considered model. joint_forces: The joint forces to apply. - link_forces: The 6D forces to apply to the links. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding to + the velocity representation of `data`. Returns: A tuple containing the derivative of the base 6D velocity in inertial-fixed @@ -105,14 +107,16 @@ def system_velocity_dynamics( the system dynamics evaluation. """ - # Build joint torques if not provided + # Build joint torques if not provided. τ = ( jnp.atleast_1d(joint_forces.squeeze()) if joint_forces is not None else jnp.zeros_like(data.joint_positions()) ).astype(float) - # Build link forces if not provided + # Build link forces if not provided. + # These forces are expressed in the frame corresponding to the velocity + # representation of data. O_f_L = ( jnp.atleast_2d(link_forces.squeeze()) if link_forces is not None @@ -127,16 +131,17 @@ def system_velocity_dynamics( # with the terrain. W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float) - # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points, - # expressed in the world frame. - W_f_Ci = None + # Import privately the soft contacts classes. + from jaxsim.rbda.contacts.soft import SoftContactsState # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}. + assert isinstance(data.state.contact, SoftContactsState) ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float) if len(model.kin_dyn_parameters.contact_parameters.body) > 0: - # Compute the 6D forces applied to each collidable point and the - # corresponding material deformation rates. + + # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point + # and the corresponding material deformation rates. with data.switch_velocity_representation(VelRepr.Inertial): W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data) @@ -178,7 +183,7 @@ def system_velocity_dynamics( model.kin_dyn_parameters.joint_parameters.friction_viscous ).astype(float) - # Compute the joint friction torque + # Compute the joint friction torque. τ_friction = -( jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities) + jnp.diag(kv) @ data.state.physics_model.joint_velocities @@ -188,7 +193,7 @@ def system_velocity_dynamics( # Compute forward dynamics # ======================== - # Compute the total joint forces + # Compute the total joint forces. τ_total = τ + τ_friction + τ_position_limit references = js.references.JaxSimModelReferences.build( @@ -202,7 +207,7 @@ def system_velocity_dynamics( with references.switch_velocity_representation(VelRepr.Inertial): W_f_L = references.link_forces(model=model, data=data) - # Compute the total external 6D forces applied to the links + # Compute the total external 6D forces applied to the links. W_f_L_total = W_f_L + W_f_Li_terrain # - Joint accelerations: s̈ ∈ ℝⁿ @@ -273,7 +278,9 @@ def system_dynamics( model: The model to consider. data: The data of the considered model. joint_forces: The joint forces to apply. - link_forces: The 6D forces to apply to the links. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding to + the velocity representation of `data`. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient used to adjust the norm of the quaternion (only used in integrators not operating on the SO(3) manifold). diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index c34af9ebe..766f7f926 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -31,8 +31,8 @@ class ODEInput(JaxsimDataclass): @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, - joint_forces: jtp.VectorJax | None = None, - link_forces: jtp.MatrixJax | None = None, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, ) -> ODEInput: """ Build an `ODEInput` from a `JaxSimModel`. @@ -160,7 +160,7 @@ def build_from_jaxsim_model( `JaxSimModel` and initialized to zero. """ - # Get the contact model from the `JaxSimModel` + # Get the contact model from the `JaxSimModel`. match model.contact_model: case SoftContacts(): contact = SoftContactsState.build_from_jaxsim_model( @@ -212,7 +212,7 @@ def build( else PhysicsModelState.zero(model=model) ) - # Get the contact model from the `JaxSimModel` + # Get the contact model from the `JaxSimModel`. match contact: case SoftContactsState(): pass @@ -423,7 +423,7 @@ def build( base_angular_velocity=jnp.array(base_angular_velocity, dtype=float), ) - # assert state.valid(physics_model) + # TODO (diegoferigo): assert state.valid(physics_model) return physics_model_state @staticmethod @@ -501,14 +501,14 @@ class PhysicsModelInput(JaxsimDataclass): f_ext: The matrix of external forces applied to the links. """ - tau: jtp.VectorJax - f_ext: jtp.MatrixJax + tau: jtp.Vector + f_ext: jtp.Matrix @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, - joint_forces: jtp.VectorJax | None = None, - link_forces: jtp.MatrixJax | None = None, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, ) -> PhysicsModelInput: """ Build a `PhysicsModelInput` from a `JaxSimModel`. @@ -535,8 +535,8 @@ def build_from_jaxsim_model( @staticmethod def build( - joint_forces: jtp.VectorJax | None = None, - link_forces: jtp.MatrixJax | None = None, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, number_of_dofs: jtp.Int | None = None, number_of_links: jtp.Int | None = None, ) -> PhysicsModelInput: diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 5af59225f..719009571 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -27,8 +27,8 @@ # Generic types # ============= -Time = jax.typing.ArrayLike -TimeStep = jax.typing.ArrayLike +Time = jtp.FloatLike +TimeStep = jtp.FloatLike State = NextState = TypeVar("State") StateDerivative = TypeVar("StateDerivative") PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree) @@ -79,7 +79,7 @@ def build( The integrator object. """ - return cls(dynamics=dynamics, **kwargs) # noqa + return cls(dynamics=dynamics, **kwargs) def step( self, @@ -191,14 +191,14 @@ def init( class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]): # The Runge-Kutta matrix. - A: ClassVar[jax.typing.ArrayLike] + A: ClassVar[jtp.Matrix] # The weights coefficients. # Note that in practice we typically use its transpose `b.transpose()`. - b: ClassVar[jax.typing.ArrayLike] + b: ClassVar[jtp.Matrix] # The nodes coefficients. - c: ClassVar[jax.typing.ArrayLike] + c: ClassVar[jtp.Vector] # Define the order of the solution. # It should have as many elements as the number of rows of `b.transpose()`. @@ -384,7 +384,7 @@ def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]: # Define the computation of the Runge-Kutta stage. def compute_ki() -> jax.Array: - # Compute ∑ⱼ aᵢⱼ kⱼ + # Compute ∑ⱼ aᵢⱼ kⱼ. op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k) sum_ak = jax.tree_util.tree_map(op_sum_ak, K) @@ -440,7 +440,7 @@ def compute_ki() -> jax.Array: @staticmethod def butcher_tableau_is_valid( - A: jax.typing.ArrayLike, b: jax.typing.ArrayLike, c: jax.typing.ArrayLike + A: jtp.Matrix, b: jtp.Matrix, c: jtp.Vector ) -> jtp.Bool: """ Check if the Butcher tableau is valid. @@ -466,7 +466,7 @@ def butcher_tableau_is_valid( return valid @staticmethod - def butcher_tableau_is_explicit(A: jax.typing.ArrayLike) -> jtp.Bool: + def butcher_tableau_is_explicit(A: jtp.Matrix) -> jtp.Bool: """ Check if the Butcher tableau corresponds to an explicit integration scheme. @@ -481,9 +481,9 @@ def butcher_tableau_is_explicit(A: jax.typing.ArrayLike) -> jtp.Bool: @staticmethod def butcher_tableau_supports_fsal( - A: jax.typing.ArrayLike, - b: jax.typing.ArrayLike, - c: jax.typing.ArrayLike, + A: jtp.Matrix, + b: jtp.Matrix, + c: jtp.Vector, index_of_solution: jtp.IntLike = 0, ) -> [bool, int | None]: """ @@ -562,10 +562,9 @@ def post_process_state( # Indices to convert quaternions between serializations. to_xyzw = jnp.array([1, 2, 3, 0]) - to_wxyz = jnp.array([3, 0, 1, 2]) - # Get the initial quaternion. - W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw( + # Get the initial rotation. + W_R_B_t0 = jaxlie.SO3.from_quaternion_xyzw( xyzw=x0.physics_model.base_quaternion[to_xyzw] ) @@ -575,15 +574,13 @@ def post_process_state( # on the SO(3) manifold. W_ω_WB_tf = xf.physics_model.base_angular_velocity - # Integrate the quaternion on SO(3). + # Integrate the orientation on SO(3). # Note that we left-multiply with the exponential map since the angular # velocity is expressed in the inertial frame. - W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0 + W_R_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_R_B_t0 # Replace the quaternion in the final state. return xf.replace( - physics_model=xf.physics_model.replace( - base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz] - ), + physics_model=xf.physics_model.replace(base_quaternion=W_R_B_tf.wxyz), validate=True, ) diff --git a/src/jaxsim/integrators/fixed_step.py b/src/jaxsim/integrators/fixed_step.py index c5b18c071..9ec0ef477 100644 --- a/src/jaxsim/integrators/fixed_step.py +++ b/src/jaxsim/integrators/fixed_step.py @@ -1,10 +1,10 @@ from typing import ClassVar, Generic -import jax import jax.numpy as jnp import jax_dataclasses import jaxsim.api as js +import jaxsim.typing as jtp from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType @@ -18,11 +18,11 @@ @jax_dataclasses.pytree_dataclass class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): - A: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(0).astype(float) + A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float) - b: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(1).astype(float).transpose() + b: ClassVar[jtp.Matrix] = jnp.atleast_2d(1).astype(float).transpose() - c: ClassVar[jax.typing.ArrayLike] = jnp.atleast_1d(0).astype(float) + c: ClassVar[jtp.Vector] = jnp.atleast_1d(0).astype(float) row_index_of_solution: ClassVar[int] = 0 order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,) @@ -31,14 +31,14 @@ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): @jax_dataclasses.pytree_dataclass class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): - A: ClassVar[jax.typing.ArrayLike] = jnp.array( + A: ClassVar[jtp.Matrix] = jnp.array( [ [0, 0], [1, 0], ] ).astype(float) - b: ClassVar[jax.typing.ArrayLike] = ( + b: ClassVar[jtp.Matrix] = ( jnp.atleast_2d( jnp.array([1 / 2, 1 / 2]), ) @@ -46,7 +46,7 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): .transpose() ) - c: ClassVar[jax.typing.ArrayLike] = jnp.array( + c: ClassVar[jtp.Vector] = jnp.array( [0, 1], ).astype(float) @@ -57,7 +57,7 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): @jax_dataclasses.pytree_dataclass class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): - A: ClassVar[jax.typing.ArrayLike] = jnp.array( + A: ClassVar[jtp.Matrix] = jnp.array( [ [0, 0, 0, 0], [1 / 2, 0, 0, 0], @@ -66,7 +66,7 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): ] ).astype(float) - b: ClassVar[jax.typing.ArrayLike] = ( + b: ClassVar[jtp.Matrix] = ( jnp.atleast_2d( jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]), ) @@ -74,7 +74,7 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): .transpose() ) - c: ClassVar[jax.typing.ArrayLike] = jnp.array( + c: ClassVar[jtp.Vector] = jnp.array( [0, 1 / 2, 1 / 2, 1], ).astype(float) diff --git a/src/jaxsim/integrators/variable_step.py b/src/jaxsim/integrators/variable_step.py index 4e238064b..89a991440 100644 --- a/src/jaxsim/integrators/variable_step.py +++ b/src/jaxsim/integrators/variable_step.py @@ -312,9 +312,9 @@ def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState: # Clip the estimated initial step size to the given bounds, if necessary. self.params["dt0"] = jnp.clip( - a=self.params["dt0"], - a_min=jnp.minimum(self.dt_min, self.params["dt0"]), - a_max=jnp.minimum(self.dt_max, self.params["dt0"]), + self.params["dt0"], + jnp.minimum(self.dt_min, self.params["dt0"]), + jnp.minimum(self.dt_max, self.params["dt0"]), ) # ========================================================= @@ -371,7 +371,7 @@ def while_loop_body(carry: Carry) -> Carry: # Shrink the Δt every time by the safety factor (even when accepted). # The β parameters define the bounds of the timestep update factor. - safety = jnp.clip(self.safety, a_min=0.0, a_max=1.0) + safety = jnp.clip(self.safety, 0.0, 1.0) β_min = jnp.maximum(0.0, self.beta_min) β_max = jnp.maximum(β_min, self.beta_max) @@ -383,9 +383,9 @@ def while_loop_body(carry: Carry) -> Carry: # In case of acceptance, Δt_next could either be larger than Δt0, # or slightly smaller than Δt0 depending on the safety factor. Δt_next = Δt0 * jnp.clip( - a=safety * jnp.power(1 / local_error, 1 / (q + 1)), - a_min=β_min, - a_max=β_max, + safety * jnp.power(1 / local_error, 1 / (q + 1)), + β_min, + β_max, ) def accept_step(): @@ -545,14 +545,14 @@ def build( @jax_dataclasses.pytree_dataclass class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin): - A: ClassVar[jax.typing.ArrayLike] = jnp.array( + A: ClassVar[jtp.Matrix] = jnp.array( [ [0, 0], [1, 0], ] ).astype(float) - b: ClassVar[jax.typing.ArrayLike] = ( + b: ClassVar[jtp.Matrix] = ( jnp.atleast_2d( jnp.array( [ @@ -565,7 +565,7 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin): .transpose() ) - c: ClassVar[jax.typing.ArrayLike] = jnp.array( + c: ClassVar[jtp.Vector] = jnp.array( [0, 1], ).astype(float) @@ -578,7 +578,7 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin): @jax_dataclasses.pytree_dataclass class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin): - A: ClassVar[jax.typing.ArrayLike] = jnp.array( + A: ClassVar[jtp.Matrix] = jnp.array( [ [0, 0, 0, 0], [1 / 2, 0, 0, 0], @@ -587,7 +587,7 @@ class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mi ] ).astype(float) - b: ClassVar[jax.typing.ArrayLike] = ( + b: ClassVar[jtp.Matrix] = ( jnp.atleast_2d( jnp.array( [ @@ -600,7 +600,7 @@ class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mi .transpose() ) - c: ClassVar[jax.typing.ArrayLike] = jnp.array( + c: ClassVar[jtp.Vector] = jnp.array( [0, 1 / 2, 3 / 4, 1], ).astype(float) diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py index 6ede6ab91..2e7b9c352 100644 --- a/src/jaxsim/math/__init__.py +++ b/src/jaxsim/math/__init__.py @@ -4,8 +4,9 @@ from .adjoint import Adjoint from .cross import Cross from .inertia import Inertia -from .joint_model import JointModel, supported_joint_motion from .quaternion import Quaternion from .rotation import Rotation from .skew import Skew from .transform import Transform + +from .joint_model import JointModel, supported_joint_motion # isort:skip diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index e770cb42a..16f3914fe 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -11,6 +11,7 @@ from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms from .rotation import Rotation +from .transform import Transform @jax_dataclasses.pytree_dataclass @@ -162,7 +163,7 @@ def child_H_parent( joint_index=joint_index, joint_position=joint_position ) - i_Hi_λ = jaxlie.SE3.from_matrix(λ_Hi_i).inverse().as_matrix() + i_Hi_λ = Transform.inverse(λ_Hi_i) return i_Hi_λ, S diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index 6cd7c9ec9..3e2968dcf 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -58,9 +58,7 @@ def from_dcm(dcm: jtp.Matrix) -> jtp.Vector: Returns: jtp.Vector: Quaternion in XYZW representation. """ - return Quaternion.to_wxyz( - xyzw=jaxlie.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw() - ) + return jaxlie.SO3.from_matrix(matrix=dcm).wxyz @staticmethod def derivative( @@ -165,12 +163,8 @@ def integration( # Integrate the quaternion on the manifold. W_Q_B_tf = jax.lax.select( pred=omega_in_body_fixed, - on_true=Quaternion.to_wxyz( - xyzw=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).as_quaternion_xyzw() - ), - on_false=Quaternion.to_wxyz( - xyzw=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).as_quaternion_xyzw() - ), + on_true=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).wxyz, + on_false=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).wxyz, ) return W_Q_B_tf diff --git a/src/jaxsim/math/transform.py b/src/jaxsim/math/transform.py index f4e049059..fe82de21f 100644 --- a/src/jaxsim/math/transform.py +++ b/src/jaxsim/math/transform.py @@ -46,8 +46,8 @@ def from_quaternion_and_translation( @staticmethod def from_rotation_and_translation( - rotation: jtp.MatrixLike, - translation: jtp.VectorLike, + rotation: jtp.MatrixLike = jnp.eye(3), + translation: jtp.VectorLike = jnp.zeros(3), inverse: jtp.BoolLike = False, ) -> jtp.Matrix: """ diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index d2af6bd0b..b6986796e 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -160,7 +160,7 @@ def convert( considered_joints: list[str] | None = None, plane_normal: tuple[float, float, float] = (0, 0, 1), heightmap: bool | None = None, - cameras: list[dict[str, str]] | dict[str, str] = None, + cameras: list[dict[str, str]] | dict[str, str] | None = None, ) -> tuple[str, dict[str, Any]]: """ Converts a ROD model to a Mujoco MJCF string. @@ -274,7 +274,7 @@ def convert( # Load the URDF model into Mujoco. assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model) - mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets) # noqa + mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets) # Get the joint names. mj_joint_names = set( @@ -306,7 +306,7 @@ def convert( root: ET._Element = tree.getroot() # Find the element (might be the root itself). - mujoco_element: ET._Element = list(root.iter("mujoco"))[0] + mujoco_element: ET._Element = next(iter(root.iter("mujoco"))) # -------------- # Add the motors @@ -516,7 +516,7 @@ def convert( model_name: str | None = None, plane_normal: tuple[float, float, float] = (0, 0, 1), heightmap: bool | None = None, - cameras: list[dict[str, str]] | dict[str, str] = None, + cameras: list[dict[str, str]] | dict[str, str] | None = None, ) -> tuple[str, dict[str, Any]]: """ Converts a URDF file to a Mujoco MJCF string. @@ -558,7 +558,7 @@ def convert( model_name: str | None = None, plane_normal: tuple[float, float, float] = (0, 0, 1), heightmap: bool | None = None, - cameras: list[dict[str, str]] | dict[str, str] = None, + cameras: list[dict[str, str]] | dict[str, str] | None = None, ) -> tuple[str, dict[str, Any]]: """ Converts a SDF file to a Mujoco MJCF string. diff --git a/src/jaxsim/mujoco/model.py b/src/jaxsim/mujoco/model.py index b4f677a3e..6b61177ff 100644 --- a/src/jaxsim/mujoco/model.py +++ b/src/jaxsim/mujoco/model.py @@ -31,16 +31,16 @@ def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None: self.model = model self.data = data if data is not None else mj.MjData(self.model) - # Populate the data with kinematics + # Populate the data with kinematics. mj.mj_forward(self.model, self.data) - # Keep the cache of this method local to improve GC + # Keep the cache of this method local to improve GC. self.mask_qpos = functools.cache(self._mask_qpos) @staticmethod def build_from_xml( mjcf_description: str | pathlib.Path, - assets: dict[str, Any] = None, + assets: dict[str, Any] | None = None, heightmap: HeightmapCallable | None = None, ) -> MujocoModelHelper: """ @@ -56,15 +56,15 @@ def build_from_xml( A MujocoModelHelper object. """ - # Read the XML description if it's a path to file + # Read the XML description if it is a path to file. mjcf_description = ( mjcf_description.read_text() if isinstance(mjcf_description, pathlib.Path) else mjcf_description ) - # Create the Mujoco model from the XML and, optionally, the assets dictionary - model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets) # noqa + # Create the Mujoco model from the XML and, optionally, the assets dictionary. + model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets) data = mj.MjData(model) if heightmap: diff --git a/src/jaxsim/mujoco/visualizer.py b/src/jaxsim/mujoco/visualizer.py index 810391f36..bbf2f400d 100644 --- a/src/jaxsim/mujoco/visualizer.py +++ b/src/jaxsim/mujoco/visualizer.py @@ -81,6 +81,9 @@ def record_frame(self, camera_name: str | None = None) -> None: def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None: """Writes the video to a file.""" + # Resolve the path to the video. + path = path.expanduser().resolve() + if path.is_dir(): raise IsADirectoryError(f"The path '{path}' is a directory.") diff --git a/src/jaxsim/parsers/__init__.py b/src/jaxsim/parsers/__init__.py index 8238b3e4e..e69de29bb 100644 --- a/src/jaxsim/parsers/__init__.py +++ b/src/jaxsim/parsers/__init__.py @@ -1 +0,0 @@ -from . import descriptions, kinematic_graph diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 610b8c4fc..00967381c 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -26,7 +26,7 @@ class JointGenericAxis: A joint requiring the specification of a 3D axis. """ - #: The axis of rotation or translation of the joint (must have norm 1). + # The axis of rotation or translation of the joint (must have norm 1). axis: jtp.Vector def __hash__(self) -> int: diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 41f5399df..1c486e5cb 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -4,11 +4,11 @@ import jax.numpy as jnp import jax_dataclasses -import jaxlie import numpy as np from jax_dataclasses import Static import jaxsim.typing as jtp +from jaxsim.math import Adjoint from jaxsim.utils import JaxsimDataclass @@ -102,12 +102,11 @@ def lump_with( The combined link. """ - # Get the 6D inertia of the link to remove + # Get the 6D inertia of the link to remove. I_removed = link.inertia # Create the SE3 object. Note the inverse. - r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse() - r_X_l = r_H_l.adjoint() + r_X_l = Adjoint.from_transform(transform=lumped_H_removed, inverse=True) # Move the inertia I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index ac488104d..3ba4b0206 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -176,7 +176,7 @@ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription: frames=self.frames, collisions=tuple(self.collision_shapes), fixed_base=self.fixed_base, - base_link_name=list(iter(self))[0].name, + base_link_name=next(iter(self)).name, model_pose=self.root_pose, considered_joints=considered_joints, ) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 1e289d931..9edf811b4 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -12,7 +12,8 @@ from jaxsim import logging from jaxsim.utils import Mutability -from . import descriptions +from .descriptions.joint import JointDescription, JointType +from .descriptions.link import LinkDescription @dataclasses.dataclass @@ -61,7 +62,7 @@ def __eq__(self, other: RootPose) -> bool: @dataclasses.dataclass(frozen=True) -class KinematicGraph(Sequence[descriptions.LinkDescription]): +class KinematicGraph(Sequence[LinkDescription]): """ Class storing a kinematic graph having links as nodes and joints as edges. @@ -72,11 +73,11 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]): root_pose: The pose of the kinematic graph's root. """ - root: descriptions.LinkDescription - frames: list[descriptions.LinkDescription] = dataclasses.field( + root: LinkDescription + frames: list[LinkDescription] = dataclasses.field( default_factory=list, hash=False, compare=False ) - joints: list[descriptions.JointDescription] = dataclasses.field( + joints: list[JointDescription] = dataclasses.field( default_factory=list, hash=False, compare=False ) @@ -89,26 +90,26 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]): # Private attribute storing the unconnected joints from the parsed model and # the joints removed after model reduction. - _joints_removed: list[descriptions.JointDescription] = dataclasses.field( + _joints_removed: list[JointDescription] = dataclasses.field( default_factory=list, repr=False, hash=False, compare=False ) @functools.cached_property - def links_dict(self) -> dict[str, descriptions.LinkDescription]: + def links_dict(self) -> dict[str, LinkDescription]: return {l.name: l for l in iter(self)} @functools.cached_property - def frames_dict(self) -> dict[str, descriptions.LinkDescription]: + def frames_dict(self) -> dict[str, LinkDescription]: return {f.name: f for f in self.frames} @functools.cached_property - def joints_dict(self) -> dict[str, descriptions.JointDescription]: + def joints_dict(self) -> dict[str, JointDescription]: return {j.name: j for j in self.joints} @functools.cached_property def joints_connection_dict( self, - ) -> dict[tuple[str, str], descriptions.JointDescription]: + ) -> dict[tuple[str, str], JointDescription]: return {(j.parent.name, j.child.name): j for j in self.joints} def __post_init__(self) -> None: @@ -158,9 +159,9 @@ def __post_init__(self) -> None: @staticmethod def build_from( - links: list[descriptions.LinkDescription], - joints: list[descriptions.JointDescription], - frames: list[descriptions.LinkDescription] | None = None, + links: list[LinkDescription], + joints: list[JointDescription], + frames: list[LinkDescription] | None = None, root_link_name: str | None = None, root_pose: RootPose = RootPose(), ) -> KinematicGraph: @@ -186,7 +187,7 @@ def build_from( logging.debug(msg=f"Assuming '{root_link_name}' as the root link") # Couple links and joints and create the graph of links. - # Note that the pose of the frames is not updated; it's the caller's + # Note that the pose of the frames is not updated; it is the caller's # responsibility to update their pose if they want to use them. ( graph_root_node, @@ -218,17 +219,17 @@ def build_from( @staticmethod def _create_graph( - links: list[descriptions.LinkDescription], - joints: list[descriptions.JointDescription], + links: list[LinkDescription], + joints: list[JointDescription], root_link_name: str, - frames: list[descriptions.LinkDescription] | None = None, + frames: list[LinkDescription] | None = None, ) -> tuple[ - descriptions.LinkDescription, - list[descriptions.JointDescription], - list[descriptions.LinkDescription], - list[descriptions.LinkDescription], - list[descriptions.JointDescription], - list[descriptions.LinkDescription], + LinkDescription, + list[JointDescription], + list[LinkDescription], + list[LinkDescription], + list[JointDescription], + list[LinkDescription], ]: """ Low-level creator of kinematic graph components. @@ -248,7 +249,7 @@ def _create_graph( """ # Create a dictionary that maps the link name to the link, for easy retrieval. - links_dict: dict[str, descriptions.LinkDescription] = { + links_dict: dict[str, LinkDescription] = { l.name: l.mutable(validate=False) for l in links } @@ -280,7 +281,7 @@ def _create_graph( # Couple links and joints creating the kinematic graph. for joint in joints: - # Get the parent and child links of the joint + # Get the parent and child links of the joint. parent_link = links_dict[joint.parent.name] child_link = links_dict[joint.child.name] @@ -293,7 +294,7 @@ def _create_graph( # Assign link's children and make sure they are unique. if child_link.name not in {l.name for l in parent_link.children}: with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION): - parent_link.children = parent_link.children + (child_link,) + parent_link.children = (*parent_link.children, child_link) # Collect all the links of the kinematic graph. all_links_in_graph = list( @@ -641,7 +642,7 @@ def print_tree(self) -> None: ) @property - def joints_removed(self) -> list[descriptions.JointDescription]: + def joints_removed(self) -> list[JointDescription]: """ Get the list of joints removed during the graph reduction. @@ -653,9 +654,9 @@ def joints_removed(self) -> list[descriptions.JointDescription]: @staticmethod def breadth_first_search( - root: descriptions.LinkDescription, + root: LinkDescription, sort_children: Callable[[Any], Any] | None = lambda link: link.name, - ) -> Iterable[descriptions.LinkDescription]: + ) -> Iterable[LinkDescription]: """ Perform a breadth-first search (BFS) traversal of the kinematic graph. @@ -698,25 +699,25 @@ def breadth_first_search( # Sequence protocol # ================= - def __iter__(self) -> Iterable[descriptions.LinkDescription]: + def __iter__(self) -> Iterable[LinkDescription]: yield from KinematicGraph.breadth_first_search(root=self.root) - def __reversed__(self) -> Iterable[descriptions.LinkDescription]: + def __reversed__(self) -> Iterable[LinkDescription]: yield from reversed(list(iter(self))) def __len__(self) -> int: return len(list(iter(self))) - def __contains__(self, item: str | descriptions.LinkDescription) -> bool: + def __contains__(self, item: str | LinkDescription) -> bool: if isinstance(item, str): return item in self.link_names() - if isinstance(item, descriptions.LinkDescription): + if isinstance(item, LinkDescription): return item in set(iter(self)) raise TypeError(type(item).__name__) - def __getitem__(self, key: int | str) -> descriptions.LinkDescription: + def __getitem__(self, key: int | str) -> LinkDescription: if isinstance(key, str): if key not in self.link_names(): raise KeyError(key) @@ -731,12 +732,10 @@ def __getitem__(self, key: int | str) -> descriptions.LinkDescription: raise TypeError(type(key).__name__) - def count(self, value: descriptions.LinkDescription) -> int: + def count(self, value: LinkDescription) -> int: return list(iter(self)).count(value) - def index( - self, value: descriptions.LinkDescription, start: int = 0, stop: int = -1 - ) -> int: + def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int: return list(iter(self)).index(value, start, stop) @@ -906,7 +905,7 @@ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: @staticmethod def pre_H_suc( - joint_type: descriptions.JointType, + joint_type: JointType, joint_axis: npt.NDArray, joint_position: float | None = None, ) -> npt.NDArray: diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 752345cd3..5fd0f4822 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -54,19 +54,19 @@ def extract_model_data( if isinstance(model_description, rod.Model): sdf_model = model_description else: - # Parse the SDF resource + # Parse the SDF resource. sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) if len(sdf_element.models()) == 0: raise RuntimeError("Failed to find any model in SDF resource") - # Assume the SDF resource has only one model, or the desired model name is given + # Assume the SDF resource has only one model, or the desired model name is given. sdf_models = {m.name: m for m in sdf_element.models()} sdf_model = ( sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name] ) - # Log model name + # Log model name. logging.debug(msg=f"Found model '{sdf_model.name}' in SDF resource") # Jaxsim supports only models compatible with URDF, i.e. those having all links @@ -75,7 +75,7 @@ def extract_model_data( # pose is expressed wrt the parent link they are rigidly attached to. sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf) - # Log type of base link + # Log type of base link. logging.debug( msg="Model '{}' is {}".format( sdf_model.name, @@ -83,7 +83,7 @@ def extract_model_data( ) ) - # Log detected base link + # Log detected base link. logging.debug(msg=f"Considering '{sdf_model.get_canonical_link()}' as base link") # Pose of the model @@ -101,7 +101,7 @@ def extract_model_data( # Parse links # =========== - # Parse the links (unconnected) + # Parse the links (unconnected). links = [ descriptions.LinkDescription( name=l.name, @@ -113,14 +113,14 @@ def extract_model_data( if l.inertial.mass > 0 ] - # Create a dictionary to find easily links + # Create a dictionary to find easily links. links_dict: Dict[str, descriptions.LinkDescription] = {l.name: l for l in links} # ============ # Parse frames # ============ - # Parse the frames (unconnected) + # Parse the frames (unconnected). frames = [ descriptions.LinkDescription( name=f.name, @@ -138,7 +138,7 @@ def extract_model_data( # ========================= # In this case, we need to get the pose of the joint that connects the base link - # to the world and combine their pose + # to the world and combine their pose. if sdf_model.is_fixed_base(): # Create a massless word link world_link = descriptions.LinkDescription( @@ -200,7 +200,7 @@ def extract_model_data( # Parse joints # ============ - # Check that all joint poses are expressed w.r.t. their parent link + # Check that all joint poses are expressed w.r.t. their parent link. for j in sdf_model.joints(): if j.pose is None: continue @@ -215,7 +215,7 @@ def extract_model_data( msg = "Pose of joint '{}' is not expressed wrt its parent link '{}'" raise ValueError(msg.format(j.name, j.parent)) - # Parse the joints + # Parse the joints. joints = [ descriptions.JointDescription( name=j.name, @@ -278,10 +278,10 @@ def extract_model_data( and j.child in links_dict.keys() ] - # Create a dictionary to find the parent joint of the links + # Create a dictionary to find the parent joint of the links. joint_dict = {j.child.name: j.name for j in joints} - # Check that all the link poses are expressed wrt their parent joint + # Check that all the link poses are expressed wrt their parent joint. for l in sdf_model.links(): if l.name not in links_dict: continue @@ -354,7 +354,7 @@ def build_model_description( The parsed model description. """ - # Parse data from the SDF assuming it contains a single model + # Parse data from the SDF assuming it contains a single model. sdf_data = extract_model_data( model_description=model_description, model_name=None, is_urdf=is_urdf ) diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index d5ebbb15e..242bb0a2e 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,12 +1,11 @@ import os -import jaxlie import numpy as np import numpy.typing as npt import rod import jaxsim.typing as jtp -from jaxsim.math import Inertia +from jaxsim.math import Adjoint, Inertia from jaxsim.parsers import descriptions @@ -21,10 +20,10 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: The 6D inertia matrix of the link expressed in the link frame. """ - # Extract the "mass" element + # Extract the "mass" element. m = inertial.mass - # Extract the "inertia" element + # Extract the "inertia" element. inertia_element = inertial.inertia ixx = inertia_element.ixx @@ -34,7 +33,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0 iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0 - # Build the 3x3 inertia matrix expressed in the CoM + # Build the 3x3 inertia matrix expressed in the CoM. I_CoM = np.array( [ [ixx, ixy, ixz], @@ -43,17 +42,16 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: ] ) - # Build the 6x6 generalized inertia at the CoM + # Build the 6x6 generalized inertia at the CoM. M_CoM = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM) - # Compute the transform from the inertial frame (CoM) to the link frame + # Compute the transform from the inertial frame (CoM) to the link frame. L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4) - # We need its inverse - CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse() - CoM_X_L = CoM_H_L.adjoint() + # We need its inverse. + CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True) - # Express the CoM inertia matrix in the link frame L + # Express the CoM inertia matrix in the link frame L. M_L = CoM_X_L.T @ M_CoM @ CoM_X_L return M_L.astype(dtype=float) diff --git a/src/jaxsim/rbda/aba.py b/src/jaxsim/rbda/aba.py index 355850c2e..1e8a9c510 100644 --- a/src/jaxsim/rbda/aba.py +++ b/src/jaxsim/rbda/aba.py @@ -102,7 +102,7 @@ def aba( i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) i_X_0 = i_X_0.at[0].set(jnp.eye(6)) - # Initialize base quantities + # Initialize base quantities. if model.floating_base(): # Base velocity v₀ in body-fixed representation. @@ -121,10 +121,7 @@ def aba( # Pass 1 # ====== - Pass1Carry = tuple[ - jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax - ] - + Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0) # Propagate kinematics and initialize AB inertia and AB bias forces. @@ -178,10 +175,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: d = jnp.zeros(shape=(model.number_of_links(), 1)) u = jnp.zeros(shape=(model.number_of_links(), 1)) - Pass2Carry = tuple[ - jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax - ] - + Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_2_carry: Pass2Carry = (U, d, u, MA, pA) def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: @@ -204,8 +198,8 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: # Propagate them to the parent, handling the base link. def propagate( - MA_pA: tuple[jtp.MatrixJax, jtp.MatrixJax] - ) -> tuple[jtp.MatrixJax, jtp.MatrixJax]: + MA_pA: tuple[jtp.Matrix, jtp.Matrix] + ) -> tuple[jtp.Matrix, jtp.Matrix]: MA, pA = MA_pA @@ -248,7 +242,7 @@ def propagate( s̈ = jnp.zeros_like(s) a = jnp.zeros_like(v).at[0].set(a0) - Pass3Carry = tuple[jtp.MatrixJax, jtp.VectorJax] + Pass3Carry = tuple[jtp.Matrix, jtp.Vector] pass_3_carry = (a, s̈) def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index aa5e6135b..4e6800782 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -80,7 +80,7 @@ def collidable_points_pos_vel( # Propagate kinematics # ==================== - PropagateTransformsCarry = tuple[jtp.MatrixJax, jtp.Matrix] + PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix] propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi) def propagate_kinematics( @@ -97,7 +97,7 @@ def propagate_kinematics( W_Xi_i = W_X_i[λ[i]] @ λi_X_i W_X_i = W_X_i.at[i].set(W_Xi_i) - # Propagate the 6D velocity + # Propagate the 6D velocity. W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze() W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi) @@ -118,14 +118,15 @@ def propagate_kinematics( # ================================================== def process_point_kinematics( - Li_p_C: jtp.VectorJax, parent_body: jtp.Int - ) -> tuple[jtp.VectorJax, jtp.VectorJax]: - # Compute the position of the collidable point + Li_p_C: jtp.Vector, parent_body: jtp.Int + ) -> tuple[jtp.Vector, jtp.Vector]: + + # Compute the position of the collidable point. W_p_Ci = ( Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1]) )[0:3] - # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci} + # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}. CW_vl_WCi = ( jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()]) @ W_v_Wi[parent_body].squeeze() @@ -133,7 +134,7 @@ def process_point_kinematics( return W_p_Ci, CW_vl_WCi - # Process all the collidable points in parallel + # Process all the collidable points in parallel. W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)( model.kin_dyn_parameters.contact_parameters.point, jnp.array(model.kin_dyn_parameters.contact_parameters.body), diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 0328eef76..08e5b7bac 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -105,24 +105,24 @@ def build_default_from_jaxsim_model( - ξ < 1.0: under-damped """ - # Use symbols for input parameters + # Use symbols for input parameters. ξ = damping_ratio δ_max = max_penetration μc = static_friction_coefficient - # Compute the total mass of the model + # Compute the total mass of the model. m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() - # Rename the standard gravity + # Rename the standard gravity. g = standard_gravity - # Compute the average support force on each collidable point + # Compute the average support force on each collidable point. f_average = m * g / number_of_active_collidable_points_steady_state - # Compute the stiffness to get the desired steady-state penetration + # Compute the stiffness to get the desired steady-state penetration. K = f_average / jnp.power(δ_max, 3 / 2) - # Compute the damping using the damping ratio + # Compute the damping using the damping ratio. critical_damping = 2 * jnp.sqrt(K * m) D = ξ * critical_damping @@ -151,14 +151,16 @@ class SoftContacts(ContactModel): default_factory=SoftContactsParams ) - terrain: Terrain = dataclasses.field(default_factory=FlatTerrain) + terrain: jax_dataclasses.Static[Terrain] = dataclasses.field( + default_factory=FlatTerrain + ) def compute_contact_forces( self, position: jtp.Vector, velocity: jtp.Vector, tangential_deformation: jtp.Vector, - ) -> tuple[jtp.Vector, tuple[jtp.Vector, None]]: + ) -> tuple[jtp.Vector, tuple[jtp.Vector]]: """ Compute the contact forces and material deformation rate. @@ -188,18 +190,18 @@ def compute_contact_forces( # Normal force computation # ======================== - # Unpack the position of the collidable point + # Unpack the position of the collidable point. px, py, pz = W_p_C = position.squeeze() vx, vy, vz = W_ṗ_C = velocity.squeeze() - # Compute the terrain normal and the contact depth + # Compute the terrain normal and the contact depth. n̂ = self.terrain.normal(x=px, y=py).squeeze() h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz]) - # Compute the penetration depth normal to the terrain + # Compute the penetration depth normal to the terrain. δ = jnp.maximum(0.0, jnp.dot(h, n̂)) - # Compute the penetration normal velocity + # Compute the penetration normal velocity. δ̇ = -jnp.dot(W_ṗ_C, n̂) # Non-linear spring-damper model. @@ -210,10 +212,10 @@ def compute_contact_forces( on_false=jnp.array(0.0), ) - # Prevent negative normal forces that might occur when δ̇ is largely negative + # Prevent negative normal forces that might occur when δ̇ is largely negative. force_normal_mag = jnp.maximum(0.0, force_normal_mag) - # Compute the 3D linear force in C[W] frame + # Compute the 3D linear force in C[W] frame. force_normal = force_normal_mag * n̂ # ==================================== @@ -230,11 +232,11 @@ def compute_contact_forces( ) def with_no_friction(): - # Compute 6D mixed force in C[W] + # Compute 6D mixed force in C[W]. CW_f_lin = force_normal CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)]) - # Compute lin-ang 6D forces (inertial representation) + # Compute lin-ang 6D forces (inertial representation). W_f = W_Xf_CW @ CW_f return W_f, (ṁ,) @@ -258,32 +260,32 @@ def above_terrain(): return jnp.zeros(6), (ṁ,) def below_terrain(): - # Decompose the velocity in normal and tangential components + # Decompose the velocity in normal and tangential components. v_normal = jnp.dot(W_ṗ_C, n̂) * n̂ v_tangential = W_ṗ_C - v_normal - # Compute the tangential force. If inside the friction cone, the contact + # Compute the tangential force. If inside the friction cone, the contact. f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential) def sticking_contact(): - # Sum the normal and tangential forces, and create the 6D force + # Sum the normal and tangential forces, and create the 6D force. CW_f_stick = force_normal + f_tangential CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)]) - # In this case the 3D material deformation is the tangential velocity + # In this case the 3D material deformation is the tangential velocity. ṁ = v_tangential # Return the 6D force in the contact frame and - # the deformation derivative + # the deformation derivative. return CW_f, ṁ def slipping_contact(): - # Project the force to the friction cone boundary + # Project the force to the friction cone boundary. f_tangential_projected = (μ * force_normal_mag) * ( f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9) ) - # Sum the normal and tangential forces, and create the 6D force + # Sum the normal and tangential forces, and create the 6D force. CW_f_slip = force_normal + f_tangential_projected CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)]) @@ -297,7 +299,7 @@ def slipping_contact(): ṁ = (f_tangential_projected - α * m) / β # Return the 6D force in the contact frame and - # the deformation derivative + # the deformation derivative. return CW_f, ṁ CW_f, ṁ = jax.lax.cond( @@ -307,10 +309,10 @@ def slipping_contact(): operand=None, ) - # Express the 6D force in the world frame + # Express the 6D force in the world frame. W_f = W_Xf_CW @ CW_f - # Return the 6D force in the world frame and the deformation derivative + # Return the 6D force in the world frame and the deformation derivative. return W_f, (ṁ,) # (W_f, (ṁ,)) @@ -321,7 +323,7 @@ def slipping_contact(): operand=None, ) - # (W_f, ṁ) + # (W_f, (ṁ,)) return jax.lax.cond( pred=(μ == 0.0), true_fun=lambda _: with_no_friction(), diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index 27ee83042..904048832 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -45,7 +45,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat # Propagate kinematics # ==================== - ForwardPassCarry = tuple[jtp.MatrixJax] + ForwardPassCarry = tuple[jtp.Matrix] forward_pass_carry: ForwardPassCarry = (i_X_0,) def propagate_kinematics( @@ -71,7 +71,7 @@ def propagate_kinematics( M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs())) - BackwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax] + BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix] backward_pass_carry: BackwardPassCarry = (Mc, M) def backward_pass( @@ -90,7 +90,7 @@ def backward_pass( j = i - CarryInnerFn = tuple[jtp.Int, jtp.MatrixJax, jtp.MatrixJax] + CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix] carry_inner_fn = (j, Fi, M) def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn: diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 8bcab038a..cdfbc35a3 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -61,7 +61,7 @@ def forward_kinematics_model( # Propagate the kinematics # ======================== - PropagateKinematicsCarry = tuple[jtp.MatrixJax] + PropagateKinematicsCarry = tuple[jtp.Matrix] propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i,) def propagate_kinematics( diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index 531c921a2..197a45ee2 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -50,7 +50,7 @@ def jacobian( # Propagate kinematics # ==================== - PropagateKinematicsCarry = tuple[jtp.MatrixJax] + PropagateKinematicsCarry = tuple[jtp.Matrix] propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,) def propagate_kinematics( @@ -86,9 +86,9 @@ def propagate_kinematics( # Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True. κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index] - def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> tuple[jtp.MatrixJax, None]: + def compute_jacobian(J: jtp.Matrix, i: jtp.Int) -> tuple[jtp.Matrix, None]: - def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax: + def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix: ii = i - 1 @@ -155,16 +155,16 @@ def jacobian_full_doubly_left( B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) B_X_i = B_X_i.at[0].set(jnp.eye(6)) - # ============================= - # Compute doubly-left Jacobian - # ============================= + # ================================= + # Compute doubly-left full Jacobian + # ================================= # Allocate the Jacobian matrix. # The Jbb section of the doubly-left Jacobian is an identity matrix. J = jnp.zeros(shape=(6, 6 + model.dofs())) J = J.at[0:6, 0:6].set(jnp.eye(6)) - ComputeFullJacobianCarry = tuple[jtp.MatrixJax, jtp.MatrixJax] + ComputeFullJacobianCarry = tuple[jtp.Matrix, jtp.Matrix] compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J) def compute_full_jacobian( @@ -261,7 +261,7 @@ def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix: J̇ = jnp.zeros(shape=(6, 6 + model.dofs())) ComputeFullJacobianDerivativeCarry = tuple[ - jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax + jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix ] compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = ( diff --git a/src/jaxsim/rbda/rnea.py b/src/jaxsim/rbda/rnea.py index b5f927d1e..625f8fede 100644 --- a/src/jaxsim/rbda/rnea.py +++ b/src/jaxsim/rbda/rnea.py @@ -132,7 +132,7 @@ def rnea( # Pass 1 # ====== - ForwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax] + ForwardPassCarry = Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f) def forward_pass( @@ -186,7 +186,7 @@ def forward_pass( τ = jnp.zeros_like(s) - BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax] + BackwardPassCarry = Tuple[jtp.Vector, jtp.Matrix] backward_pass_carry: BackwardPassCarry = (τ, f) def backward_pass( @@ -201,7 +201,7 @@ def backward_pass( τ = τ.at[ii].set(τ_i.squeeze()) # Propagate the force to the parent link. - def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax: + def update_f(f: jtp.Matrix) -> jtp.Matrix: f_λi = f[λ[i]] + i_X_λi[i].T @ f[i] f = f.at[λ[i]].set(f_λi) diff --git a/src/jaxsim/rbda/utils.py b/src/jaxsim/rbda/utils.py index 19e4e87f3..0f209db78 100644 --- a/src/jaxsim/rbda/utils.py +++ b/src/jaxsim/rbda/utils.py @@ -19,7 +19,7 @@ def process_inputs( joint_accelerations: jtp.VectorLike | None = None, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, - standard_gravity: jtp.VectorLike | None = None, + standard_gravity: jtp.ScalarLike | None = None, ) -> tuple[ jtp.Vector, jtp.Vector, diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index c1f44081e..9f7378476 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import abc +import dataclasses import jax.numpy as jnp import jax_dataclasses @@ -7,22 +10,23 @@ class Terrain(abc.ABC): + delta = 0.010 @abc.abstractmethod - def height(self, x: float, y: float) -> float: + def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: pass - def normal(self, x: float, y: float) -> jtp.Vector: + def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: """ Compute the normal vector of the terrain at a specific (x, y) location. Args: - x (float): The x-coordinate of the location. - y (float): The y-coordinate of the location. + x: The x-coordinate of the location. + y: The y-coordinate of the location. Returns: - jtp.Vector: The normal vector of the terrain surface at the specified location. + The normal vector of the terrain surface at the specified location. """ # https://stackoverflow.com/a/5282364 @@ -40,43 +44,117 @@ def normal(self, x: float, y: float) -> jtp.Vector: @jax_dataclasses.pytree_dataclass class FlatTerrain(Terrain): - def height(self, x: float, y: float) -> float: - return 0.0 + + z: float = dataclasses.field(default=0.0, kw_only=True) + + @staticmethod + def build(height: jtp.FloatLike) -> FlatTerrain: + + return FlatTerrain(z=float(height)) + + def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: + + return jnp.array(self.z, dtype=float) + + def __hash__(self) -> int: + + return hash(self.z) + + def __eq__(self, other: FlatTerrain) -> bool: + + if not isinstance(other, FlatTerrain): + return False + + return self.z == other.z @jax_dataclasses.pytree_dataclass -class PlaneTerrain(Terrain): - plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0]) +class PlaneTerrain(FlatTerrain): + + plane_normal: tuple[float, float, float] = jax_dataclasses.field( + default=(0.0, 0.0, 0.0), kw_only=True + ) @staticmethod - def build(plane_normal: list) -> "PlaneTerrain": + def build( + plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0 + ) -> PlaneTerrain: """ Create a PlaneTerrain instance with a specified plane normal vector. Args: - plane_normal (list): The normal vector of the terrain plane. + plane_normal: The normal vector of the terrain plane. + plane_height_over_origin: The height of the plane over the origin. Returns: PlaneTerrain: A PlaneTerrain instance. """ - if not isinstance(plane_normal, list): - raise TypeError( - f"Expected a list for the plane normal vector, got: {type(plane_normal)}." - ) - return PlaneTerrain(plane_normal=plane_normal) + plane_normal = jnp.array(plane_normal, dtype=float) + plane_height_over_origin = jnp.array(plane_height_over_origin, dtype=float) + + if plane_normal.shape != (3,): + msg = "Expected a 3D vector for the plane normal, got '{}'." + raise ValueError(msg.format(plane_normal.shape)) - def height(self, x: float, y: float) -> float: + # Make sure that the plane normal is a unit vector. + plane_normal = plane_normal / jnp.linalg.norm(plane_normal) + + return PlaneTerrain( + z=float(plane_height_over_origin), + plane_normal=tuple(plane_normal.tolist()), + ) + + def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: """ Compute the height of the terrain at a specific (x, y) location on a plane. Args: - x (float): The x-coordinate of the location. - y (float): The y-coordinate of the location. + x: The x-coordinate of the location. + y: The y-coordinate of the location. Returns: - float: The height of the terrain at the specified location on the plane. + The height of the terrain at the specified location on the plane. """ - a, b, c = self.plane_normal - return -(a * x + b * y) / c + # Equation of the plane: A x + B y + C z + D = 0 + # Normal vector coordinates: (A, B, C) + # The height over the origin: -D/C + + # Get the plane equation coefficients from the terrain normal. + A, B, C = self.plane_normal + + # Compute the final coefficient D considering the terrain height. + D = -C * self.z + + # Invert the plane equation to get the height at the given (x, y) coordinates. + return jnp.array(-(A * x + B * y + D) / C).astype(float) + + def __hash__(self) -> int: + + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + hash(self.z), + HashedNumpyArray.hash_of_array( + array=jnp.array(self.plane_normal, dtype=float) + ), + ) + ) + + def __eq__(self, other: PlaneTerrain) -> bool: + + if not isinstance(other, PlaneTerrain): + return False + + if not ( + jnp.allclose(self.z, other.z) + and jnp.allclose( + jnp.array(self.plane_normal, dtype=float), + jnp.array(other.plane_normal, dtype=float), + ) + ): + return False + + return True diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 5d56467c0..9b392836d 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -7,14 +7,14 @@ # JAX types # ========= -ScalarJax = jax.Array -IntJax = ScalarJax -BoolJax = ScalarJax -FloatJax = ScalarJax +Array = jax.Array +Scalar = Array +Vector = Array +Matrix = Array -ArrayJax = jax.Array -VectorJax = ArrayJax -MatrixJax = ArrayJax +Int = Scalar +Bool = Scalar +Float = Scalar PyTree = ( dict[Hashable, "PyTree"] | list["PyTree"] | tuple["PyTree"] | None | jax.Array | Any @@ -24,19 +24,11 @@ # Mixed JAX / NumPy types # ======================= -Array = jax.typing.ArrayLike -Scalar = Array -Vector = Array -Matrix = Array +ArrayLike = jax.typing.ArrayLike | tuple +ScalarLike = int | float | Scalar | ArrayLike +VectorLike = Vector | ArrayLike | tuple +MatrixLike = Matrix | ArrayLike -Int = int | IntJax -Bool = bool | ArrayJax -Float = float | FloatJax - -ScalarLike = Scalar | int | float -ArrayLike = Array -VectorLike = Vector -MatrixLike = Matrix -IntLike = Int -BoolLike = Bool -FloatLike = Float +IntLike = int | Int | jax.typing.ArrayLike +BoolLike = bool | Bool | jax.typing.ArrayLike +FloatLike = float | Float | jax.typing.ArrayLike diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 850959840..0b3801b1e 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -134,7 +134,7 @@ def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]: not a numpy-like array. """ - return tuple( # noqa + return tuple( leaf.shape if hasattr(leaf, "shape") else None for leaf in jax.tree_util.tree_leaves(tree) if hasattr(leaf, "shape") @@ -326,7 +326,7 @@ def replace(self: Self, validate: bool = True, **kwargs) -> Self: return obj - def flatten(self) -> jtp.VectorJax: + def flatten(self) -> jtp.Vector: """ Flatten the object into a 1D vector. @@ -337,7 +337,7 @@ def flatten(self) -> jtp.VectorJax: return self.flatten_fn()(self) @classmethod - def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]: + def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.Vector]: """ Return a function to flatten the object into a 1D vector. @@ -347,7 +347,7 @@ def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]: return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0] - def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]: + def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]: """ Return a function to unflatten a 1D vector into the object. diff --git a/src/jaxsim/utils/wrappers.py b/src/jaxsim/utils/wrappers.py index 1750d2c37..1944bf79f 100644 --- a/src/jaxsim/utils/wrappers.py +++ b/src/jaxsim/utils/wrappers.py @@ -106,7 +106,11 @@ def __eq__(self, other: HashedNumpyArray) -> bool: return False if self.large_array: - return np.array_equal(self.array, other.array) + return np.allclose( + self.array, + other.array, + **({dict(atol=self.precision)} if self.precision is not None else {}), + ) return hash(self) == hash(other) diff --git a/tests/conftest.py b/tests/conftest.py index b16d08208..a75d3890e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -227,12 +227,12 @@ def jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSimModel: model_full = jaxsim_model_ergocub - # Get the names of the joints to keep + # Get the names of the joints to keep. reduced_joints = tuple( j for j in model_full.joint_names() if "camera" not in j - # Remove head and hands + # Remove head and hands. and "neck" not in j and "wrist" not in j and "thumb" not in j @@ -240,7 +240,7 @@ def jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSimModel: and "middle" not in j and "ring" not in j and "pinkie" not in j - # Remove upper body + # Remove upper body. and "torso" not in j and "elbow" not in j and "shoulder" not in j ) diff --git a/tests/test_api_data.py b/tests/test_api_data.py index 12d9b8260..9db5e6561 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -68,7 +68,7 @@ def test_data_switch_velocity_representation( new_base_linear_velocity = jnp.array([1.0, -2.0, 3.0]) old_base_linear_velocity = data.state.physics_model.base_linear_velocity - # The following should not change the original `data` object since it raises + # The following should not change the original `data` object since it raises. with pytest.raises(RuntimeError): with data.switch_velocity_representation( velocity_representation=VelRepr.Inertial @@ -81,7 +81,7 @@ def test_data_switch_velocity_representation( old_base_linear_velocity ) - # The following instead should result to an updated `data` object + # The following instead should result to an updated `data` object. with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): with data.mutable_context(mutability=Mutability.MUTABLE): data.state.physics_model.base_linear_velocity = new_base_linear_velocity diff --git a/tests/test_api_model.py b/tests/test_api_model.py index c89a700af..24654f726 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -286,7 +286,7 @@ def test_model_rbda( # Tests # ===== - # Support both fixed-base and floating-base models by slicing the first six rows + # Support both fixed-base and floating-base models by slicing the first six rows. sl = np.s_[0:] if model.floating_base() else np.s_[6:] # Mass matrix @@ -494,7 +494,7 @@ def test_model_fd_id_consistency( # Tests # ===== - # Create random references (joint torques and link forces) + # Create random references (joint torques and link forces). _, subkey1, subkey2 = jax.random.split(key, num=3) references = js.references.JaxSimModelReferences.build( model=model, @@ -504,7 +504,7 @@ def test_model_fd_id_consistency( velocity_representation=data.velocity_representation, ) - # Remove the force applied to the base link if the model is fixed-base + # Remove the force applied to the base link if the model is fixed-base. if not model.floating_base(): references = references.apply_link_forces( forces=jnp.atleast_2d(jnp.zeros(6)), @@ -514,7 +514,7 @@ def test_model_fd_id_consistency( additive=False, ) - # Compute forward dynamics with ABA + # Compute forward dynamics with ABA. v̇_WB_aba, s̈_aba = js.model.forward_dynamics_aba( model=model, data=data, @@ -522,7 +522,7 @@ def test_model_fd_id_consistency( link_forces=references.link_forces(model=model, data=data), ) - # Compute forward dynamics with CRB + # Compute forward dynamics with CRB. v̇_WB_crb, s̈_crb = js.model.forward_dynamics_crb( model=model, data=data, diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 31281ce9f..702064f49 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -320,7 +320,7 @@ def close_over_inputs_and_parameters( order=AD_ORDER, modes=["rev", "fwd"], eps=ε, - # On GPU, the tolerance needs to be increased + # On GPU, the tolerance needs to be increased. rtol=0.02 if "gpu" in {d.platform for d in p.devices()} else None, ) @@ -373,14 +373,14 @@ def test_ad_integration( # Function exposing only the parameters to be differentiated. def step( - W_p_B: jax.typing.ArrayLike, - W_Q_B: jax.typing.ArrayLike, - s: jax.typing.ArrayLike, - W_v_WB: jax.typing.ArrayLike, - ṡ: jax.typing.ArrayLike, - m: jax.typing.ArrayLike, - τ: jax.typing.ArrayLike, - W_f_L: jax.typing.ArrayLike, + W_p_B: jtp.Vector, + W_Q_B: jtp.Vector, + s: jtp.Vector, + W_v_WB: jtp.Vector, + ṡ: jtp.Vector, + m: jtp.Vector, + τ: jtp.Vector, + W_f_L: jtp.Matrix, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: # When JAX tests against finite differences, the injected ε will make the diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 7040fcab7..3dc0df833 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -124,26 +124,26 @@ class KinDynComputations: @staticmethod def build( urdf: pathlib.Path | str, - considered_joints: list[str] = None, + considered_joints: list[str] | None = None, vel_repr: VelRepr = VelRepr.Inertial, gravity: npt.NDArray = np.array([0, 0, -10.0]), removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: - # Read the URDF description + # Read the URDF description. urdf_string = urdf.read_text() if isinstance(urdf, pathlib.Path) else urdf - # Create the model loader + # Create the model loader. mdl_loader = idt.ModelLoader() - # Handle removed_joint_positions if None + # Handle removed_joint_positions if None. removed_joint_positions = ( {str(name): float(pos) for name, pos in removed_joint_positions.items()} if removed_joint_positions is not None else dict() ) - # Load the URDF description + # Load the URDF description. if not ( mdl_loader.loadModelFromString(urdf_string) if considered_joints is None @@ -153,7 +153,7 @@ def build( ): raise RuntimeError("Failed to load URDF description") - # Create KinDynComputations and insert the model + # Create KinDynComputations and insert the model. kindyn = idt.KinDynComputations() if not kindyn.loadRobotModel(mdl_loader.model()): @@ -165,7 +165,7 @@ def build( VelRepr.Mixed: idt.MIXED_REPRESENTATION, } - # Configure the frame representation + # Configure the frame representation. if not kindyn.setFrameVelocityRepresentation(vel_repr_to_idyntree[vel_repr]): raise RuntimeError("Failed to set the frame representation") @@ -225,7 +225,7 @@ def set_robot_state( if not self.kin_dyn.setRobotState(world_H_base, s, v_WB, s_dot, g): raise RuntimeError("Failed to set the robot state") - # Update stored gravity + # Update stored gravity. self.gravity = gravity def dofs(self) -> int: