Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finalize minor changes for v0.4 release #186

Merged
merged 29 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
df1a74e
Remove unused imports from notebooks
flferretti Jun 21, 2024
1446fbe
Notebook minor improvements and update to latest API
flferretti Jun 21, 2024
4b089cc
Remove unused type guard from match-statement
flferretti Jun 21, 2024
2c702f0
Avoid implicit `Optional`
flferretti Jun 21, 2024
47c6ba9
Remove unused `noqa` directives
flferretti Jun 21, 2024
a87e918
Avoid `+` operator to concatenate collections
flferretti Jun 21, 2024
9f5cda7
Remove unnecessary iterable allocation for first element
flferretti Jun 21, 2024
9a8b0b3
Add `ruff` specific group check
flferretti Jun 21, 2024
ede7182
Use `jaxsim.typing` for already defined types
flferretti Jun 25, 2024
999e756
Remove circular import in `jaxsim.parsers`
flferretti Jun 25, 2024
e0f70f6
Simplify logic to extract quaternion wxyz from SO3
flferretti Jun 26, 2024
02140e9
Fix LaTeX math in raw strings
flferretti Jun 26, 2024
73e62b4
Update ignored groups in `ruff` check
flferretti Jun 26, 2024
4351b7b
Use `jaxsim.math.Quaternion` for extracting quaternions
flferretti Jun 27, 2024
fa9046d
Update typehints in `integrators` modules
flferretti Jun 27, 2024
0ee6ee9
Prefer `jaxsim.math` operations to `jaxlie`
flferretti Jun 27, 2024
5114371
Address suggestions from code review
flferretti Jun 28, 2024
f725b21
Homogenize comments style
flferretti Jun 28, 2024
0283f55
Minor docstring fix
diegoferigo Jun 28, 2024
35204f1
Initialize default arguments of Transform.from_rotation_and_translation
diegoferigo Jun 28, 2024
43eee5f
Fix typing of SoftContacts.compute_contact_forces
diegoferigo Jun 28, 2024
390c928
Update comparison of large numpy arrays
diegoferigo Jun 28, 2024
92cc272
Support building PlaneTerrain objects not passing through the origin
diegoferigo Jul 2, 2024
e86255e
Add explicit hash and eq to terrain classes
diegoferigo Jul 2, 2024
1ec9506
Make SoftContacts.terrain a static attribute
diegoferigo Jul 2, 2024
67f65eb
Update jaxsim.typing module
diegoferigo Jun 28, 2024
7036086
Remove usage of kwargs in jax.numpy.clip
diegoferigo Jul 1, 2024
c2dfd3f
Clarify docstrings and cleanup unused variable
diegoferigo Jul 2, 2024
89c5056
Expand the output path in MujocoVideoRecorder.write_video
diegoferigo Jul 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions examples/PD_controller.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"outputs": [],
"source": [
"# @title Imports and setup\n",
"from IPython.display import clear_output, HTML, display\n",
"from IPython.display import clear_output\n",
"import sys\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
Expand Down Expand Up @@ -100,6 +100,8 @@
"from jaxsim import integrators\n",
"\n",
"dt = 0.01\n",
"integration_time = 5.0\n",
"num_steps = int(integration_time / dt)\n",
"\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_urdf_string, is_urdf=True\n",
Expand Down Expand Up @@ -150,7 +152,7 @@
"source": [
"# @title Set up MuJoCo renderer\n",
"\n",
"from jaxsim.mujoco import RodModelToMjcf, MujocoModelHelper, MujocoVideoRecorder\n",
"from jaxsim.mujoco import MujocoModelHelper, MujocoVideoRecorder\n",
"from jaxsim.mujoco.loaders import UrdfToMjcf\n",
"\n",
"import os\n",
Expand Down Expand Up @@ -227,7 +229,7 @@
"source": [
"import mediapy as media\n",
"\n",
"for _ in range(500):\n",
"for _ in range(num_steps):\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
Expand Down Expand Up @@ -299,7 +301,7 @@
"metadata": {},
"outputs": [],
"source": [
"for _ in range(500):\n",
"for _ in range(num_steps):\n",
" control_torques = pd_controller(\n",
" data=data,\n",
" q_d=jnp.array([0.0, 0.0]),\n",
Expand Down
44 changes: 22 additions & 22 deletions examples/Parallel_computing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
"# @title Imports and setup\n",
"import sys\n",
"\n",
"from IPython.display import HTML, clear_output, display\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
Expand All @@ -40,15 +38,12 @@
"%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
"\n",
"import time\n",
"from typing import Dict, Tuple\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_dataclasses\n",
"import rod\n",
"from rod.builder.primitives import SphereBuilder\n",
"\n",
"import jaxsim.typing as jtp\n",
"from jaxsim import logging\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.INFO)\n",
Expand Down Expand Up @@ -105,7 +100,7 @@
"from jaxsim import integrators\n",
"\n",
"dt = 0.001\n",
"integration_time = 1500\n",
"integration_time = 1.5 # seconds\n",
"\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string\n",
Expand All @@ -129,7 +124,9 @@
"\n",
"By default, in JaxSim a sphere primitive has 250 collision points. This can be modified by setting the `JAXSIM_COLLISION_SPHERE_POINTS` environment variable.\n",
"\n",
"Given that at its steady-state the sphere will act on two or three points, we can estimate the ground parameters by explicitly setting the number of active points to these values."
"Given that at its steady-state the sphere will act on two or three points, we can estimate the ground parameters by explicitly setting the number of active points to these values. \n",
"\n",
"Eventually, you can specify the maximum penetration depth of the sphere into the terrain by passing `max_penetraion` to the `estimate_good_soft_contacts_parameters` function."
]
},
{
Expand All @@ -139,8 +136,10 @@
"outputs": [],
"source": [
"data = data.replace(\n",
" soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n",
" model, number_of_active_collidable_points_steady_state=3\n",
" contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n",
" model=model,\n",
" number_of_active_collidable_points_steady_state=3,\n",
" max_penetration=None,\n",
" )\n",
")"
]
Expand All @@ -149,7 +148,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a position vector for a 3x3 grid. Every sphere will be placed at a different height."
"Let's create a position vector for a 4x4 grid. Every sphere will be placed at a different height."
]
},
{
Expand All @@ -169,14 +168,13 @@
"def grid(edge_len, envs_per_row):\n",
" edge = jnp.linspace(-edge_len, edge_len, envs_per_row)\n",
" xx, yy = jnp.meshgrid(edge, edge)\n",
"\n",
" poses = [\n",
" [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]\n",
" for i in range(xx.shape[0])\n",
" for j in range(yy.shape[0])\n",
" ]\n",
"\n",
" return jnp.array(poses)\n",
" zz = 0.2 + 0.1 * (\n",
" jnp.arange(envs_per_row**2) % envs_per_row\n",
" + jnp.arange(envs_per_row**2) // envs_per_row\n",
" )\n",
" zz = zz.reshape(envs_per_row, envs_per_row)\n",
" poses = jnp.stack([xx, yy, zz], axis=-1).reshape(envs_per_row**2, 3)\n",
" return poses\n",
"\n",
"\n",
"logging.info(f\"Simulating {envs_per_row**2} environments\")\n",
Expand All @@ -200,11 +198,13 @@
"def simulate(\n",
" data: js.data.JaxSimModelData, integrator_state: dict, pose: jnp.array\n",
") -> tuple:\n",
"\n",
" # Set the base position to the initial pose\n",
" data = data.reset_base_position(base_position=pose)\n",
"\n",
" # Create a list to store the base position over time\n",
" x_t_i = []\n",
"\n",
" for _ in range(integration_time):\n",
" for _ in range(int(integration_time // dt)):\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
Expand Down Expand Up @@ -243,15 +243,15 @@
"# Run and time the simulation\n",
"now = time.perf_counter()\n",
"\n",
"x_t = simulate_vectorized(data, integrator_state, poses[:, 0])\n",
"x_t = simulate_vectorized(data, integrator_state, poses)\n",
"\n",
"comp_time = time.perf_counter() - now\n",
"\n",
"logging.info(\n",
" f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n",
")\n",
"logging.info(\n",
" f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 *integration_time/comp_time):.2f}\"\n",
" f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 * integration_time / comp_time):.2f}\"\n",
")"
]
},
Expand Down
21 changes: 11 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,26 @@ preview = true

[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
select = [
"B",
"E",
"F",
"I",
"W",
"RUF",
"YTT",
]

ignore = [
"B008", # Function call in default argument
"B024", # Abstract base class without abstract methods
"B904", # Raise without from inside exception
"B905", # Zip without explicit strict
"E402", # Module level import not at top of file
"E501", # Line too long
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # Ambiguous variable name
"F841", # Local variable is assigned to but never used
"I001", # Import block is unsorted or unformatted
]
select = [
"B",
"E",
"F",
"I",
"W",
"YTT",
"RUF003", # Ambigous unicode character in comment
]

[tool.ruff.lint.per-file-ignores]
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ def _is_editable() -> bool:
import pathlib
import site

# Get the ModuleSpec of jaxsim
# Get the ModuleSpec of jaxsim.
jaxsim_spec = importlib.util.find_spec(name="jaxsim")

# This can be None. If it's None, assume non-editable installation.
if jaxsim_spec.origin is None:
return False

# Get the folder containing the jaxsim package
# Get the folder containing the jaxsim package.
jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent)

# The installation is editable if the package dir is not in any {site|dist}-packages
# The installation is editable if the package dir is not in any {site|dist}-packages.
return jaxsim_package_dir not in site.getsitepackages()


Expand Down Expand Up @@ -82,10 +82,10 @@ def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))


# Configure JAX
# Configure JAX.
_jnp_options()

# Initialize the numpy print options
# Initialize the numpy print options.
_np_options()

del _jnp_options
Expand Down
7 changes: 3 additions & 4 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.math
Expand Down Expand Up @@ -28,7 +27,7 @@ def com_position(

W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_B = data.base_transform()
B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)

def B_p̃_LCoM(i) -> jtp.Vector:
m = js.link.mass(model=model, link_index=i)
Expand Down Expand Up @@ -179,9 +178,9 @@ def locked_centroidal_spatial_inertia(
case _:
raise ValueError(data.velocity_representation)

B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G)
B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G

B_Xv_G = B_H_G.adjoint()
B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)
G_Xf_B = B_Xv_G.transpose()

return G_Xf_B @ B_Mbb_B @ B_Xv_G
Expand Down
22 changes: 11 additions & 11 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.utils import JaxsimDataclass, Mutability

try:
Expand Down Expand Up @@ -59,7 +59,7 @@ def switch_velocity_representation(

try:

# First, we replace the velocity representation
# First, we replace the velocity representation.
with self.mutable_context(
mutability=Mutability.MUTABLE_NO_VALIDATION,
restore_after_exception=True,
Expand Down Expand Up @@ -97,7 +97,7 @@ def inertial_to_other_representation(
array: The 6D quantity to convert.
other_representation: The representation to convert to.
transform:
The `math:W \mathbf{H}_O` transform, where `math:O` is the
The :math:`W \mathbf{H}_O` transform, where :math:`O` is the
reference frame of the other representation.
is_force: Whether the quantity is a 6D force or a 6D velocity.

Expand All @@ -122,11 +122,11 @@ def inertial_to_other_representation(
case VelRepr.Body:

if not is_force:
O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint()
O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
O_array = O_Xv_W @ W_array

else:
O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T
O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
O_array = O_Xf_W @ W_array

return O_array
Expand All @@ -136,11 +136,11 @@ def inertial_to_other_representation(
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)

if not is_force:
OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint()
OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
OW_array = OW_Xv_W @ W_array

else:
OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose()
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
OW_array = OW_Xf_W @ W_array

return OW_array
Expand Down Expand Up @@ -190,11 +190,11 @@ def other_representation_to_inertial(
O_array = array

if not is_force:
W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint()
W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
W_array = W_Xv_O @ O_array

else:
W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T
W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
W_array = W_Xf_O @ O_array

return W_array
Expand All @@ -205,11 +205,11 @@ def other_representation_to_inertial(
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)

if not is_force:
W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint()
W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
W_array = W_Xv_BW @ BW_array

else:
W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T
W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
W_array = W_Xf_BW @ BW_array

return W_array
Expand Down
14 changes: 11 additions & 3 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsParams
from jaxsim.rbda.contacts.soft import SoftContactsParams

from .common import VelRepr

Expand Down Expand Up @@ -137,9 +137,17 @@ def collidable_point_dynamics(
# all collidable points belonging to the robot.
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)

# Import privately the soft contacts classes.
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState

# Build the soft contact model.
match model.contact_model:
case s if isinstance(s, SoftContacts):

case SoftContacts():

assert isinstance(model.contact_model, SoftContacts)
flferretti marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(data.state.contact, SoftContactsState)

# Build the contact model.
soft_contacts = SoftContacts(
parameters=data.contacts_params, terrain=model.terrain
Expand Down Expand Up @@ -337,7 +345,7 @@ def jacobian(
The output velocity representation of the free-floating jacobian.

Returns:
The stacked (6+n) free-floating jacobians of the frames associated to the
The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
collidable points.

Note:
Expand Down
Loading