From f55d4d97997295f3419e33e3b32fa2fa37078434 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 16 Dec 2024 15:11:03 +0100 Subject: [PATCH] Save some kindyn computation in `JaxSimModelData` --- examples/jaxsim_as_physics_engine.ipynb | 6 +- examples/jaxsim_for_robot_controllers.ipynb | 2 + src/jaxsim/api/com.py | 4 +- src/jaxsim/api/common.py | 49 ++ src/jaxsim/api/contact.py | 15 +- src/jaxsim/api/data.py | 490 ++++++++++++++++---- src/jaxsim/api/link.py | 12 +- src/jaxsim/api/model.py | 57 +-- src/jaxsim/api/references.py | 4 +- src/jaxsim/mujoco/utils.py | 2 +- src/jaxsim/rbda/aba.py | 8 +- src/jaxsim/rbda/collidable_points.py | 17 +- src/jaxsim/rbda/contacts/common.py | 2 +- src/jaxsim/rbda/contacts/visco_elastic.py | 6 +- src/jaxsim/rbda/crba.py | 20 +- src/jaxsim/rbda/forward_kinematics.py | 23 +- src/jaxsim/rbda/jacobian.py | 22 +- src/jaxsim/rbda/rnea.py | 8 +- src/jaxsim/typing.py | 2 +- tests/test_api_contact.py | 97 ++-- tests/test_api_frame.py | 5 +- tests/test_api_link.py | 5 +- tests/test_api_model.py | 78 ++-- tests/test_automatic_differentiation.py | 56 ++- tests/test_pytree.py | 2 +- tests/test_simulations.py | 27 +- 26 files changed, 689 insertions(+), 330 deletions(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index b8ec18cfb..8e56136be 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -299,7 +299,7 @@ " )\n", ")(jnp.vstack(subkeys))\n", "\n", - "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])" + "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])" ] }, { @@ -398,7 +398,7 @@ "# This operation is called 'tree transpose' in JAX.\n", "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n", "\n", - "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")" + "print(f\"W_p_B: shape={data_trajectory.base_position.shape}\")" ] }, { @@ -412,7 +412,7 @@ "import matplotlib.pyplot as plt\n", "\n", "\n", - "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n", + "plt.plot(T, data_trajectory.base_position[:, 0:5, 2])\n", "plt.grid(True)\n", "plt.xlabel(\"Time [s]\")\n", "plt.ylabel(\"Height [m]\")\n", diff --git a/examples/jaxsim_for_robot_controllers.ipynb b/examples/jaxsim_for_robot_controllers.ipynb index dd9353e3f..5d6d480fc 100644 --- a/examples/jaxsim_for_robot_controllers.ipynb +++ b/examples/jaxsim_for_robot_controllers.ipynb @@ -239,6 +239,8 @@ "# Reset the state to the random joint positions.\n", "data = data_zero.reset_joint_positions(positions=random_joint_positions)\n", "\n", + "# Update the kyn_dyn cache.\n", + "data = data.update_kyn_dyn(model=model)\n", "\n", "for _ in T:\n", "\n", diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index f2122ced1..401706739 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -26,7 +26,7 @@ def com_position( m = js.model.total_mass(model=model) - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics W_H_B = data.base_transform() B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) @@ -269,7 +269,7 @@ def bias_acceleration( """ # Compute the pose of all links with forward kinematics. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics # Compute the bias acceleration of all links by zeroing the generalized velocity # in the active representation. diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 7d723120e..163f421ef 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -232,3 +232,52 @@ def other_representation_to_inertial( case _: raise ValueError(other_representation) + + +def convert_mass_matrix( + M: jtp.Matrix, + base_transform: jtp.Matrix, + dofs: jtp.Int, + velocity_representation: VelRepr, +): + + # The mass matrix is always save in body-fixed representation. + + match velocity_representation: + case VelRepr.Body: + return M + + case VelRepr.Inertial: + + B_X_W = Adjoint.from_transform(transform=base_transform, inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(dofs)) + + return invT.T @ M @ invT + + case VelRepr.Mixed: + + BW_H_B = base_transform.at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(dofs)) + + return invT.T @ M @ invT + + +def convert_jacobian( + J: jtp.Matrix, + base_transform: jtp.Matrix, + dofs: jtp.Int, + velocity_representation: VelRepr, +): + # TODO (flferretti): save actual Jacobian instead of full doubly left and perform conversion. + return J + + +def convert_jacobian_derivative( + Jd: jtp.Matrix, + base_transform: jtp.Matrix, + dofs: jtp.Int, + velocity_representation: VelRepr, +): + # TODO (flferretti): save actual Jacobian derivative instead of full doubly left and perform conversion. + return Jd diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 294413f7e..11ec6dc76 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -40,15 +40,8 @@ def collidable_point_kinematics( # Switch to inertial-fixed since the RBDAs expect velocities in this representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( - model=model, - base_position=data.base_position(), - base_quaternion=data.base_orientation(dcm=False), - joint_positions=data.joint_positions(model=model), - base_linear_velocity=data.base_velocity()[0:3], - base_angular_velocity=data.base_velocity()[3:6], - joint_velocities=data.joint_velocities(model=model), - ) + W_p_Ci = data.kyn_dyn.collidable_point_positions + W_ṗ_Ci = data.kyn_dyn.collidable_point_velocities return W_p_Ci, W_ṗ_Ci @@ -460,7 +453,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt )[indices_of_enabled_collidable_points] # Get the transforms of the parent link of all collidable points. - W_H_L = js.model.forward_kinematics(model=model, data=data)[ + W_H_L = data.kyn_dyn.forward_kinematics[ parent_link_idx_of_enabled_collidable_points ] @@ -612,7 +605,7 @@ def jacobian_derivative( ] # Get the transforms of all the parent links. - W_H_Li = js.model.forward_kinematics(model=model, data=data) + W_H_Li = data.kyn_dyn.forward_kinematics # ===================================================== # Compute quantities to adjust the input representation diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index b880547b9..0802df3d4 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -13,11 +13,15 @@ import jaxsim.math import jaxsim.rbda import jaxsim.typing as jtp -from jaxsim.utils import Mutability from jaxsim.utils.tracing import not_tracing from . import common -from .common import VelRepr +from .common import ( + VelRepr, + convert_jacobian, + convert_jacobian_derivative, + convert_mass_matrix, +) from .ode_data import ODEState try: @@ -26,6 +30,130 @@ from typing_extensions import Self +class KynDynProxy: + """ + Proxy class for KynDynComputation that ensures attribute-specific + velocity representation consistency. + """ + + _data: JaxSimModelData + _kyn_dyn: KynDynComputation + + def __init__(self, data, kyn_dyn): + self._data = data + self._kyn_dyn = kyn_dyn + + def __convert_attribute(self, value, name): + + if name in [ + "motion_subspaces", + "joint_transforms", + "forward_kinematics", + "velocity_representation", + "link_body_transforms", + "collidable_point_positions", + "collidable_point_velocities", + ]: + return value + + W_R_B = jaxsim.math.Quaternion.to_dcm( + self._data.state.physics_model.base_quaternion + ) + W_p_B = jnp.vstack(self._data.state.physics_model.base_position) + + W_H_B = jnp.vstack( + [ + jnp.block([W_R_B, W_p_B]), + jnp.array([0, 0, 0, 1]), + ] + ) + + match name: + + case "jacobian_full_doubly_left": + if ( + self._data.velocity_representation + != self._kyn_dyn.velocity_representation + ): + value = convert_jacobian( + J=value, + dofs=len(self._data.state.physics_model.joint_positions), + base_transform=W_H_B, + velocity_representation=self._data.velocity_representation, + ) + + case "jacobian_derivative_full_doubly_left": + if ( + self._data.velocity_representation + != self._kyn_dyn.velocity_representation + ): + value = convert_jacobian_derivative( + Jd=value, + dofs=len(self._data.state.physics_model.joint_positions), + base_transform=W_H_B, + velocity_representation=self._data.velocity_representation, + ) + + case "mass_matrix": + if ( + self._data.velocity_representation + != self._kyn_dyn.velocity_representation + ): + value = convert_mass_matrix( + M=value, + dofs=len(self._data.state.physics_model.joint_positions), + base_transform=W_H_B, + velocity_representation=self._data.velocity_representation, + ) + + case _: + raise AttributeError( + f"'{type(self._kyn_dyn).__name__}' object has no attribute '{name}'" + ) + + return value + + def __getattr__(self, name: str): + + if name in ["_data", "_kyn_dyn"]: + return super().__getattribute__(name) + + value = getattr(self._kyn_dyn, name) + + return self.__convert_attribute(value=value, name=name) + + def __setattr__(self, name, value): + + if name in ["_data", "_kyn_dyn"]: + return super().__setattr__(name, value) + + value = self.__convert_attribute(value=value, name=name) + + # Push the update to JaxSimModelData. + self._data._update_kyn_dyn(name, value) + + +@jax_dataclasses.pytree_dataclass +class KynDynComputation(common.ModelDataWithVelocityRepresentation): + motion_subspaces: jtp.Matrix + + joint_transforms: jtp.Matrix + + forward_kinematics: jtp.Matrix + + jacobian_full_doubly_left: jtp.Matrix + + jacobian_derivative_full_doubly_left: jtp.Matrix + + link_body_transforms: jtp.Matrix + + collidable_point_positions: jtp.Matrix + + collidable_point_velocities: jtp.Matrix + + mass_matrix: jtp.Matrix + + @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): """ @@ -38,6 +166,41 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False) + _kyn_dyn: KynDynComputation | None = dataclasses.field(default=None, repr=False) + + _kyn_dyn_stale: jtp.Array = dataclasses.field( + repr=False, default_factory=lambda: jnp.array(0, dtype=bool) + ) + + @property + def kyn_dyn(self): + + jaxsim.exceptions.raise_runtime_error_if( + self._kyn_dyn_stale, + msg="The `kyn_dyn` cache is invalid, please call the `update_kyn_dyn` method after resetting the `JaxSimModelData` state.", + ) + + # Return proxy object that handles attribute-specific conversions. + return KynDynProxy(data=self, kyn_dyn=self._kyn_dyn) + + @kyn_dyn.setter + def kyn_dyn(self, new_kyn_dyn: KynDynComputation): + + if not isinstance(new_kyn_dyn, KynDynComputation): + raise ValueError("kyn_dyn must be an instance of KynDynComputation") + + self._kyn_dyn = new_kyn_dyn + + def _update_kyn_dyn(self, name: str, value): + """ + Update a specific attribute of `_kyn_dyn` and reset the instance immutably. + """ + # Replace the specific attribute in `_kyn_dyn`. + updated_kyn_dyn = self._kyn_dyn.replace(**{name: value}) + + # Update `_kyn_dyn` immutably. + object.__setattr__(self, "_kyn_dyn", updated_kyn_dyn) + def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray @@ -232,11 +395,76 @@ def build( else: contacts_params = model.contact_model._parameters_class() + base_orientation = jaxsim.math.Quaternion.to_dcm(base_quaternion) + + base_transform = jnp.vstack( + [ + jnp.block([base_orientation, jnp.vstack(base_position)]), + jnp.array([0, 0, 0, 1]), + ] + ) + + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=joint_positions, base_transform=base_transform + ) + + M = jaxsim.rbda.crba( + model=model, + joint_positions=joint_positions, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + J, _ = jaxsim.rbda.jacobian_full_doubly_left( + model=model, + joint_positions=joint_positions, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + J̇, B_H_LL = jaxsim.rbda.jacobian_derivative_full_doubly_left( + model=model, + joint_positions=joint_positions, + joint_velocities=joint_velocities, + ) + + W_H_LL = jaxsim.rbda.forward_kinematics_model( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + joint_transforms=i_X_λ, + ) + + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + base_linear_velocity=base_linear_velocity, + base_angular_velocity=base_angular_velocity, + joint_velocities=joint_velocities, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + + kyn_dyn = KynDynComputation( + velocity_representation=velocity_representation, + jacobian_full_doubly_left=J, + jacobian_derivative_full_doubly_left=J̇, + link_body_transforms=B_H_LL, + motion_subspaces=S, + joint_transforms=i_X_λ, + mass_matrix=M, + forward_kinematics=W_H_LL, + collidable_point_positions=W_p_Ci, + collidable_point_velocities=W_ṗ_Ci, + ) + return JaxSimModelData( state=ode_state, gravity=gravity, contacts_params=contacts_params, velocity_representation=velocity_representation, + _kyn_dyn=kyn_dyn, ) # ================== @@ -287,12 +515,6 @@ def joint_positions( return self.state.physics_model.joint_positions - if not_tracing(self.state.physics_model.joint_positions) and not self.valid( - model=model - ): - msg = "The data object is not compatible with the provided model" - raise ValueError(msg) - joint_idxs = ( js.joint.names_to_idxs(joint_names=joint_names, model=model) if joint_names is not None @@ -335,12 +557,6 @@ def joint_velocities( return self.state.physics_model.joint_velocities - if not_tracing(self.state.physics_model.joint_velocities) and not self.valid( - model=model - ): - msg = "The data object is not compatible with the provided model" - raise ValueError(msg) - joint_idxs = ( js.joint.names_to_idxs(joint_names=joint_names, model=model) if joint_names is not None @@ -349,8 +565,7 @@ def joint_velocities( return self.state.physics_model.joint_velocities[joint_idxs] - @js.common.named_scope - @jax.jit + @property def base_position(self) -> jtp.Vector: """ Get the base position. @@ -359,7 +574,7 @@ def base_position(self) -> jtp.Vector: The base position. """ - return self.state.physics_model.base_position.squeeze() + return self.state.physics_model.base_position @js.common.named_scope @functools.partial(jax.jit, static_argnames=["dcm"]) @@ -401,7 +616,7 @@ def base_transform(self) -> jtp.Matrix: """ W_R_B = self.base_orientation(dcm=True) - W_p_B = jnp.vstack(self.base_position()) + W_p_B = jnp.vstack(self.base_position) return jnp.vstack( [ @@ -429,16 +644,12 @@ def base_velocity(self) -> jtp.Vector: W_H_B = self.base_transform() - return ( - JaxSimModelData.inertial_to_other_representation( - array=W_v_WB, - other_representation=self.velocity_representation, - transform=W_H_B, - is_force=False, - ) - .squeeze() - .astype(float) - ) + return JaxSimModelData.inertial_to_other_representation( + array=W_v_WB, + other_representation=self.velocity_representation, + transform=W_H_B, + is_force=False, + ).astype(float) @js.common.named_scope @jax.jit @@ -504,6 +715,7 @@ def replace(s: jtp.VectorLike) -> JaxSimModelData: joint_positions=jnp.atleast_1d(s.squeeze()).astype(float) ) ), + _kyn_dyn_stale=jnp.array(1, dtype=bool), ) if model is None: @@ -553,6 +765,7 @@ def replace(ṡ: jtp.VectorLike) -> JaxSimModelData: joint_velocities=jnp.atleast_1d(ṡ.squeeze()).astype(float) ) ), + _kyn_dyn_stale=jnp.array(1, dtype=bool), ) if model is None: @@ -594,6 +807,7 @@ def reset_base_position(self, base_position: jtp.VectorLike) -> Self: base_position=jnp.atleast_1d(base_position.squeeze()).astype(float) ) ), + _kyn_dyn_stale=jnp.array(1, dtype=bool), ) @js.common.named_scope @@ -622,6 +836,7 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self: state=self.state.replace( physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B) ), + _kyn_dyn_stale=jnp.array(1, dtype=bool), ) @js.common.named_scope @@ -754,8 +969,89 @@ def reset_base_velocity( base_angular_velocity=W_v_WB[3:6].squeeze().astype(float), ) ), + _kyn_dyn_stale=jnp.array(1, dtype=bool), + ) + + @js.common.named_scope + @jax.jit + def update_kyn_dyn( + self, + model: js.model.JaxSimModel, + ) -> KynDynComputation: + """ + Updates the `kyn_dyn` attribute of `JaxSimModelData`. + + Args: + model: The model to consider. + + Returns: + An instance of `JaxSimModelData` with the updated `kyn_dyn` attribute. + """ + + base_quaternion = self.base_orientation(dcm=False) + base_velocity = self.base_velocity() + base_transform = self.base_transform() + joint_positions = self.joint_positions() + joint_velocities = self.joint_velocities() + + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=joint_positions, base_transform=base_transform + ) + + M = jaxsim.rbda.crba( + model=model, + joint_positions=joint_positions, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + + J, _ = jaxsim.rbda.jacobian_full_doubly_left( + model=model, + joint_positions=joint_positions, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + + J̇, B_H_LL = jaxsim.rbda.jacobian_derivative_full_doubly_left( + model=model, + joint_positions=joint_positions, + joint_velocities=joint_velocities, + ) + + W_H_LL = jaxsim.rbda.forward_kinematics_model( + model=model, + base_position=base_transform[:3, 3], + base_quaternion=base_quaternion, + joint_positions=joint_positions, + joint_transforms=i_X_λ, ) + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( + model=model, + base_position=base_transform[:3, 3], + base_quaternion=base_quaternion, + joint_positions=joint_positions, + base_linear_velocity=base_velocity[0:3], + base_angular_velocity=base_velocity[3:6], + joint_velocities=joint_velocities, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + + data = self.replace(_kyn_dyn_stale=jnp.array(0, dtype=bool)) + + data.kyn_dyn.jacobian_full_doubly_left = J + data.kyn_dyn.jacobian_derivative_full_doubly_left = J̇ + data.kyn_dyn.link_body_transforms = B_H_LL + data.kyn_dyn.motion_subspaces = S + data.kyn_dyn.joint_transforms = i_X_λ + data.kyn_dyn.mass_matrix = M + data.kyn_dyn.forward_kinematics = W_H_LL + data.kyn_dyn.collidable_point_positions = W_p_Ci + data.kyn_dyn.collidable_point_velocities = W_ṗ_Ci + + return data + @functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"]) def random_model_data( @@ -834,74 +1130,60 @@ def random_model_data( ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float) ṡ_min, ṡ_max = joint_vel_bounds - random_data = JaxSimModelData.zero( - model=model, - **( - dict(velocity_representation=velocity_representation) - if velocity_representation is not None - else {} - ), - ) - - with random_data.mutable_context( - mutability=Mutability.MUTABLE, restore_after_exception=False - ): - - physics_model_state = random_data.state.physics_model + base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max) - physics_model_state.base_position = jax.random.uniform( - key=k1, shape=(3,), minval=p_min, maxval=p_max - ) + base_quaternion = jaxsim.math.Quaternion.to_wxyz( + xyzw=jax.scipy.spatial.transform.Rotation.from_euler( + seq=base_rpy_seq, + angles=jax.random.uniform( + key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max + ), + ).as_quat() + ) - physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz( - xyzw=jax.scipy.spatial.transform.Rotation.from_euler( - seq=base_rpy_seq, - angles=jax.random.uniform( - key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max - ), - ).as_quat() + ( + joint_positions, + joint_velocities, + base_linear_velocity, + base_angular_velocity, + standard_gravity, + contacts_params, + ) = (None,) * 6 + + if model.number_of_joints() > 0: + + s_min, s_max = ( + jnp.array(joint_pos_bounds, dtype=float) + if joint_pos_bounds is not None + else (None, None) ) - if model.number_of_joints() > 0: - - s_min, s_max = ( - jnp.array(joint_pos_bounds, dtype=float) - if joint_pos_bounds is not None - else (None, None) - ) - - physics_model_state.joint_positions = ( - js.joint.random_joint_positions(model=model, key=k3) - if (s_min is None or s_max is None) - else jax.random.uniform( - key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max - ) + joint_positions = ( + js.joint.random_joint_positions(model=model, key=k3) + if (s_min is None or s_max is None) + else jax.random.uniform( + key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max ) + ) - physics_model_state.joint_velocities = jax.random.uniform( - key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max - ) + joint_velocities = jax.random.uniform( + key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max + ) - if model.floating_base(): - physics_model_state.base_linear_velocity = jax.random.uniform( - key=k5, shape=(3,), minval=v_min, maxval=v_max - ) + if model.floating_base(): + base_linear_velocity = jax.random.uniform( + key=k5, shape=(3,), minval=v_min, maxval=v_max + ) - physics_model_state.base_angular_velocity = jax.random.uniform( - key=k6, shape=(3,), minval=ω_min, maxval=ω_max - ) + base_angular_velocity = jax.random.uniform( + key=k6, shape=(3,), minval=ω_min, maxval=ω_max + ) - random_data.gravity = ( - jnp.zeros(3, dtype=random_data.gravity.dtype) - .at[2] - .set( - -jax.random.uniform( - key=k7, - shape=(), - minval=standard_gravity_bounds[0], - maxval=standard_gravity_bounds[1], - ) - ) + standard_gravity = jax.random.uniform( + key=k7, + shape=(), + minval=standard_gravity_bounds[0], + maxval=standard_gravity_bounds[1], ) if contacts_params is None: @@ -912,17 +1194,29 @@ def random_model_data( | jaxsim.rbda.contacts.ViscoElasticContacts, ): - random_data = random_data.replace( - contacts_params=js.contact.estimate_good_contact_parameters( - model=model, standard_gravity=random_data.gravity - ), - validate=False, + contacts_params = js.contact.estimate_good_contact_parameters( + model=model, standard_gravity=standard_gravity ) - else: - random_data = random_data.replace( - contacts_params=model.contact_model._parameters_class(), - validate=False, - ) + contacts_params = (model.contact_model._parameters_class(),) - return random_data + return JaxSimModelData.build( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + joint_velocities=joint_velocities, + base_linear_velocity=base_linear_velocity, + base_angular_velocity=base_angular_velocity, + contacts_params=contacts_params, + **( + {"standard_gravity": standard_gravity} + if standard_gravity is not None + else {} + ), + **( + {"velocity_representation": velocity_representation} + if velocity_representation is not None + else {} + ), + ) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 23b4d3732..ca7a2ded3 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -187,7 +187,7 @@ def transform( idx=link_index, ) - return js.model.forward_kinematics(model=model, data=data)[link_index] + return data.kyn_dyn.forward_kinematics[link_index] @jax.jit @@ -272,10 +272,10 @@ def jacobian( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) - # Compute the doubly-left free-floating full jacobian. - B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left( - model=model, - joint_positions=data.joint_positions(), + # Compute the doubly left free-floating full jacobian. + B_J_full_WX_B, B_H_Li = ( + data.kyn_dyn.jacobian_full_doubly_left, + data.kyn_dyn.link_body_transforms, ) # Compute the actual doubly-left free-floating jacobian of the link. @@ -423,7 +423,7 @@ def jacobian_derivative( ) O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative( - model=model, data=data, output_vel_repr=output_vel_repr + model=model, data=data )[link_index] return O_J̇_WL_I diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 35c327791..b4c4cbb67 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -546,12 +546,7 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp The first axis is the link index. """ - W_H_LL = jaxsim.rbda.forward_kinematics_model( - model=model, - base_position=data.base_position(), - base_quaternion=data.base_orientation(dcm=False), - joint_positions=data.joint_positions(model=model), - ) + W_H_LL = data.kyn_dyn.forward_kinematics return jnp.atleast_3d(W_H_LL).astype(float) @@ -587,9 +582,9 @@ def generalized_free_floating_jacobian( ) # Compute the doubly-left free-floating full jacobian. - B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left( - model=model, - joint_positions=data.joint_positions(), + B_J_full_WX_B, B_H_L = ( + data.kyn_dyn.jacobian_full_doubly_left, + data.kyn_dyn.link_body_transforms, ) # ====================================================================== @@ -713,18 +708,14 @@ def generalized_free_floating_jacobian_derivative( ) # Compute the derivative of the doubly-left free-floating full jacobian. - B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left( - model=model, - joint_positions=data.joint_positions(), - joint_velocities=data.joint_velocities(), + B_J̇_full_WX_B, B_H_L = ( + data.kyn_dyn.jacobian_derivative_full_doubly_left, + data.kyn_dyn.link_body_transforms, ) # The derivative of the equation to change the input and output representations # of the Jacobian derivative needs the computation of the plain link Jacobian. - B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left( - model=model, - joint_positions=data.joint_positions(), - ) + B_J_full_WL_B = data.kyn_dyn.jacobian_full_doubly_left # Compute the actual doubly-left free-floating jacobian derivative of the link # by zeroing the columns not in the path π_B(L) using the boolean κ(i). @@ -979,7 +970,7 @@ def forward_dynamics_aba( # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1005,6 +996,8 @@ def forward_dynamics_aba( joint_forces=τ, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ============= @@ -1172,10 +1165,7 @@ def free_floating_mass_matrix( The free-floating mass matrix of the model. """ - M_body = jaxsim.rbda.crba( - model=model, - joint_positions=data.state.physics_model.joint_positions, - ) + M_body = data.kyn_dyn.mass_matrix match data.velocity_representation: case VelRepr.Body: @@ -1431,7 +1421,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1458,6 +1448,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): joint_accelerations=s̈, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ============= @@ -1513,6 +1505,8 @@ def free_floating_gravity_forces( data.state.physics_model.joint_positions ) + data_rnea = data_rnea.update_kyn_dyn(model=model) + return jnp.hstack( inverse_dynamics( model=model, @@ -1578,6 +1572,8 @@ def free_floating_bias_forces( data.state.physics_model.base_angular_velocity ) + data_rnea = data_rnea.update_kyn_dyn(model=model) + return jnp.hstack( inverse_dynamics( model=model, @@ -1766,7 +1762,7 @@ def average_velocity_jacobian( case VelRepr.Body: GB_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) B_R_W = data.base_orientation(dcm=True).transpose() @@ -1778,7 +1774,7 @@ def average_velocity_jacobian( case VelRepr.Mixed: GW_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) @@ -1980,11 +1976,11 @@ def body_to_other_representation( ) case VelRepr.Inertial: - C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data) + C_H_L = W_H_L = data.kyn_dyn.forward_kinematics L_v_CL = L_v_WL case VelRepr.Mixed: - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) C_H_L = LW_H_L L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 @@ -2254,7 +2250,7 @@ def step( f_L = references.link_forces(model=model, data=data) τ_references = references.joint_force_references(model=model) - # Step the dynamics forward. + # Integrate the system dynamics. state_tf, integrator_metadata_tf = integrator.step( x0=state_t0, t0=t0, @@ -2285,6 +2281,8 @@ def step( # Phase 3: post-step # ================== + data_tf = data_tf.update_kyn_dyn(model=model) + # Post process the simulation state, if needed. match model.contact_model: @@ -2344,6 +2342,9 @@ def step( data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6]) data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:]) + # Update the kyn_dyn cache. + data_tf = data.update_kyn_dyn(model=model) + # Restore the input velocity representation. data_tf = data_tf.replace( velocity_representation=data.velocity_representation, validate=False diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 1c0a11078..bdfa4214f 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -242,7 +242,7 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: )(W_f_L, W_H_L) # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) return f_L @@ -450,7 +450,7 @@ def convert_using_link_frame( )(f_L, W_H_L) # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)) diff --git a/src/jaxsim/mujoco/utils.py b/src/jaxsim/mujoco/utils.py index 2afff1732..c80bfa79e 100644 --- a/src/jaxsim/mujoco/utils.py +++ b/src/jaxsim/mujoco/utils.py @@ -59,7 +59,7 @@ def mujoco_data_from_jaxsim( if jaxsim_model.floating_base(): # Set the model position. - model_helper.set_base_position(position=np.array(jaxsim_data.base_position())) + model_helper.set_base_position(position=np.array(jaxsim_data.base_position)) # Set the model orientation. model_helper.set_base_orientation( diff --git a/src/jaxsim/rbda/aba.py b/src/jaxsim/rbda/aba.py index b01f46698..8a0aa1cbb 100644 --- a/src/jaxsim/rbda/aba.py +++ b/src/jaxsim/rbda/aba.py @@ -21,6 +21,8 @@ def aba( joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute forward dynamics using the Articulated Body Algorithm (ABA). @@ -85,12 +87,10 @@ def aba( W_X_B = W_H_B.adjoint() B_X_W = W_H_B.inverse().adjoint() - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index 543be5328..dba3d98b9 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -1,6 +1,5 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp @@ -18,6 +17,8 @@ def collidable_points_pos_vel( base_linear_velocity: jtp.Vector, base_angular_velocity: jtp.Vector, joint_velocities: jtp.Vector, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Matrix]: """ @@ -54,7 +55,7 @@ def collidable_points_pos_vel( if len(indices_of_enabled_collidable_points) == 0: return jnp.array(0).astype(float), jnp.empty(0).astype(float) - W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( + _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, @@ -68,18 +69,10 @@ def collidable_points_pos_vel( # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) - - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 56b403fa7..4641fbd46 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -251,7 +251,7 @@ def link_forces_from_contact_forces( # Compute the link transforms. W_H_L = ( - js.model.forward_kinematics(model=model, data=data) + data.kyn_dyn.forward_kinematics if data.velocity_representation is not jaxsim.VelRepr.Inertial else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index c433fe23d..fc06eb39b 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -827,7 +827,7 @@ def integrate_data_with_average_contact_forces( """ s_t0 = data.joint_positions() - W_p_B_t0 = data.base_position() + W_p_B_t0 = data.base_position W_Q_B_t0 = data.base_orientation(dcm=False) ṡ_t0 = data.joint_velocities() @@ -925,6 +925,8 @@ def integrate_data_with_average_contact_forces( W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial ) + data_tf = data_tf.update_kyn_dyn(model=model) + return data_tf.replace( velocity_representation=data.velocity_representation, validate=False ) @@ -1005,7 +1007,7 @@ def step( # Compute the link transforms. W_H_L = ( - js.model.forward_kinematics(model=model, data=data) + data.kyn_dyn.forward_kinematics if data.velocity_representation is not jaxsim.VelRepr.Inertial else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index adb3506ae..6272be898 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -4,10 +4,14 @@ import jaxsim.api as js import jaxsim.typing as jtp -from . import utils - -def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix: +def crba( + model: js.model.JaxSimModel, + *, + joint_positions: jtp.Vector, + joint_transforms, + motion_subspaces, +) -> jtp.Matrix: """ Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA). @@ -19,10 +23,6 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat The free-floating mass matrix of the model in body-fixed representation. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) - # Get the 6D spatial inertia matrices of all links. Mc = js.model.link_spatial_inertia_matrices(model=model) @@ -30,12 +30,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index d11e9b45d..371b25aaf 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -1,13 +1,10 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint -from . import utils - def forward_kinematics_model( model: js.model.JaxSimModel, @@ -15,6 +12,7 @@ def forward_kinematics_model( base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, + joint_transforms, ) -> jtp.Array: """ Compute the forward kinematics. @@ -29,29 +27,14 @@ def forward_kinematics_model( A 3D array containing the SE(3) transforms of all links belonging to the model. """ - W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - ) - # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) - - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi = joint_transforms # Allocate the buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index e8f44d088..3aedfa5bf 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -14,6 +14,8 @@ def jacobian( *, link_index: jtp.Int, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> jtp.Matrix: """ Compute the free-floating Jacobian of a link. @@ -27,20 +29,14 @@ def jacobian( The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) - # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) @@ -127,6 +123,8 @@ def jacobian_full_doubly_left( model: js.model.JaxSimModel, *, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Array]: r""" Compute the doubly-left full free-floating Jacobian of a model. @@ -144,10 +142,6 @@ def jacobian_full_doubly_left( The doubly-left full free-floating Jacobian of a model. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) - # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array @@ -155,9 +149,7 @@ def jacobian_full_doubly_left( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms base -> link. B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/rnea.py b/src/jaxsim/rbda/rnea.py index 025d85a62..8f3d18036 100644 --- a/src/jaxsim/rbda/rnea.py +++ b/src/jaxsim/rbda/rnea.py @@ -23,6 +23,8 @@ def rnea( joint_accelerations: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA). @@ -88,12 +90,10 @@ def rnea( W_X_B = W_H_B.adjoint() B_X_W = W_H_B.inverse().adjoint() - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 2c7d55aa8..2e4b4eb9b 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -20,9 +20,9 @@ dict[Hashable, TypeVar("PyTree")] | list[TypeVar("PyTree")] | tuple[TypeVar("PyTree")] - | None | jax.Array | Any + | None ) # ======================= diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index 4a0882737..1ec73cef3 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -16,56 +16,59 @@ def test_contact_kinematics( model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) - data = js.data.random_model_data( - model=model, - key=subkey, - velocity_representation=velocity_representation, - ) - - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - - # ===== - # Tests - # ===== - - # Compute the pose of the implicit contact frame associated to the collidable points - # and the transforms of all links. - W_H_C = js.contact.transforms(model=model, data=data) - W_H_L = js.model.forward_kinematics(model=model, data=data) - - # Check that the orientation of the implicit contact frame matches with the - # orientation of the link to which the contact point is attached. - for contact_idx, index_of_parent_link in enumerate( - parent_link_idx_of_enabled_collidable_points - ): - assert W_H_C[contact_idx, 0:3, 0:3] == pytest.approx( - W_H_L[index_of_parent_link][0:3, 0:3] + with jax.disable_jit(): + data = js.data.random_model_data( + model=model, + key=subkey, + velocity_representation=velocity_representation, ) - # Check that the origin of the implicit contact frame is located over the - # collidable point. - W_p_C = js.contact.collidable_point_positions(model=model, data=data) - assert W_p_C == pytest.approx(W_H_C[:, 0:3, 3]) - - # Compute the velocity of the collidable point. - # This quantity always matches with the linear component of the mixed 6D velocity - # of the implicit frame associated to the collidable point. - W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) - # Compute the velocity of the collidable point using the contact Jacobian. - ν = data.generalized_velocity() - CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) - CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + + # ===== + # Tests + # ===== + + # Compute the pose of the implicit contact frame associated to the collidable points + # and the transforms of all links. + W_H_C = js.contact.transforms(model=model, data=data) + W_H_L = js.model.forward_kinematics(model=model, data=data) + + # Check that the orientation of the implicit contact frame matches with the + # orientation of the link to which the contact point is attached. + for contact_idx, index_of_parent_link in enumerate( + parent_link_idx_of_enabled_collidable_points + ): + assert W_H_C[contact_idx, 0:3, 0:3] == pytest.approx( + W_H_L[index_of_parent_link][0:3, 0:3] + ) + + # Check that the origin of the implicit contact frame is located over the + # collidable point. + W_p_C = js.contact.collidable_point_positions(model=model, data=data) + assert W_p_C == pytest.approx(W_H_C[:, 0:3, 3]) + + # Compute the velocity of the collidable point. + # This quantity always matches with the linear component of the mixed 6D velocity + # of the implicit frame associated to the collidable point. + W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + + # Compute the velocity of the collidable point using the contact Jacobian. + ν = data.generalized_velocity() + CW_J_WC = js.contact.jacobian( + model=model, data=data, output_vel_repr=VelRepr.Mixed + ) + CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] - # Compare the two velocities. - assert W_ṗ_C == pytest.approx(CW_vl_WC) + # Compare the two velocities. + assert W_ṗ_C == pytest.approx(CW_vl_WC) def test_contact_jacobian_derivative( @@ -132,7 +135,7 @@ def test_contact_jacobian_derivative( # Rebuild the JaxSim data. data_with_frames = js.data.JaxSimModelData.build( model=model_with_frames, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(), base_linear_velocity=data.base_velocity()[0:3], diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index d20dbc5ff..675ea07e0 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -237,6 +237,9 @@ def J(q, frame_idxs) -> jax.Array: data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7]) data_ad = data_ad.reset_joint_positions(positions=q[7:]) + # Update the kyn_dyn cache. + data_ad = data_ad.update_kyn_dyn(model=model) + O_J_ad_WF_I = jax.vmap( lambda model, data, frame_index: js.frame.jacobian( model=model, data=data, frame_index=frame_index @@ -249,7 +252,7 @@ def J(q, frame_idxs) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( [ - data.base_position(), + data.base_position, data.base_orientation(), data.joint_positions(), ] diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 7f89e0cc5..dff1c8319 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -332,6 +332,9 @@ def J(q) -> jax.Array: data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7]) data_ad = data_ad.reset_joint_positions(positions=q[7:]) + # Update the kyn_dyn cache. + data_ad = data_ad.update_kyn_dyn(model=model) + O_J_WL_I = js.model.generalized_free_floating_jacobian( model=model, data=data_ad ) @@ -341,7 +344,7 @@ def J(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position(), data.base_orientation(), data.joint_positions()] + [data.base_position, data.base_orientation(), data.joint_positions()] ) return q diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 6b51da58d..aa34ec860 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -110,7 +110,7 @@ def test_model_creation_and_reduction( # Build the data of the reduced model. data_reduced = js.data.JaxSimModelData.build( model=model_reduced, - base_position=data_full.base_position(), + base_position=data_full.base_position, base_quaternion=data_full.base_orientation(dcm=False), joint_positions=data_full.joint_positions( model=model_full, joint_names=model_reduced.joint_names() @@ -284,49 +284,50 @@ def test_model_rbda( model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) - data = js.data.random_model_data( - model=model, key=subkey, velocity_representation=velocity_representation - ) + with jax.disable_jit(): + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) - kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( - model=model, data=data - ) + kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) - # ===== - # Tests - # ===== + # ===== + # Tests + # ===== - # Support both fixed-base and floating-base models by slicing the first six rows. - sl = np.s_[0:] if model.floating_base() else np.s_[6:] + # Support both fixed-base and floating-base models by slicing the first six rows. + sl = np.s_[0:] if model.floating_base() else np.s_[6:] - # Mass matrix - M_idt = kin_dyn.mass_matrix() - M_js = js.model.free_floating_mass_matrix(model=model, data=data) - assert pytest.approx(M_idt[sl, sl]) == M_js[sl, sl] + # Mass matrix + M_idt = kin_dyn.mass_matrix() + M_js = js.model.free_floating_mass_matrix(model=model, data=data) + assert pytest.approx(M_idt[sl, sl]) == M_js[sl, sl] - # Gravity forces - g_idt = kin_dyn.gravity_forces() - g_js = js.model.free_floating_gravity_forces(model=model, data=data) - assert pytest.approx(g_idt[sl]) == g_js[sl] + # Gravity forces + g_idt = kin_dyn.gravity_forces() + g_js = js.model.free_floating_gravity_forces(model=model, data=data) + assert pytest.approx(g_idt[sl]) == g_js[sl] - # Bias forces - h_idt = kin_dyn.bias_forces() - h_js = js.model.free_floating_bias_forces(model=model, data=data) - assert pytest.approx(h_idt[sl]) == h_js[sl] + # Bias forces + h_idt = kin_dyn.bias_forces() + h_js = js.model.free_floating_bias_forces(model=model, data=data) + assert pytest.approx(h_idt[sl]) == h_js[sl] - # Forward kinematics - HH_js = js.model.forward_kinematics(model=model, data=data) - HH_idt = jnp.stack( - [kin_dyn.frame_transform(frame_name=name) for name in model.link_names()] - ) - assert pytest.approx(HH_idt) == HH_js + # Forward kinematics + HH_js = js.model.forward_kinematics(model=model, data=data) + HH_idt = jnp.stack( + [kin_dyn.frame_transform(frame_name=name) for name in model.link_names()] + ) + assert pytest.approx(HH_idt) == HH_js - # Bias accelerations - Jν_js = js.model.link_bias_accelerations(model=model, data=data) - Jν_idt = jnp.stack( - [kin_dyn.frame_bias_acc(frame_name=name) for name in model.link_names()] - ) - assert pytest.approx(Jν_idt) == Jν_js + # Bias accelerations + Jν_js = js.model.link_bias_accelerations(model=model, data=data) + Jν_idt = jnp.stack( + [kin_dyn.frame_bias_acc(frame_name=name) for name in model.link_names()] + ) + assert pytest.approx(Jν_idt) == Jν_js def test_model_jacobian( @@ -440,6 +441,9 @@ def M(q) -> jax.Array: data_ad = data_ad.reset_base_quaternion(base_quaternion=q[3:7]) data_ad = data_ad.reset_joint_positions(positions=q[7:]) + # Update the kyn_dyn cache. + data_ad = data_ad.update_kyn_dyn(model=model) + M = js.model.free_floating_mass_matrix(model=model, data=data_ad) return M @@ -447,7 +451,7 @@ def M(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position(), data.base_orientation(), data.joint_positions()] + [data.base_position, data.base_orientation(), data.joint_positions()] ) return q diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 477f6245d..8f7a8ba3c 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -75,11 +75,13 @@ def test_ad_aba( g = jaxsim.math.StandardGravity # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() ṡ = data.joint_velocities(model=model) + i_X_λ = data.kyn_dyn.joint_transforms + S = data.kyn_dyn.motion_subspaces # Inputs. W_f_L = references.link_forces(model=model) @@ -101,6 +103,8 @@ def test_ad_aba( joint_forces=τ, link_forces=W_f_L, standard_gravity=g, + joint_transforms=i_X_λ, + motion_subspaces=S, ) # Check derivatives against finite differences. @@ -129,11 +133,13 @@ def test_ad_rnea( g = jaxsim.math.StandardGravity # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() ṡ = data.joint_velocities(model=model) + i_X_λ = data.kyn_dyn.joint_transforms + S = data.kyn_dyn.motion_subspaces # Inputs. W_f_L = references.link_forces(model=model) @@ -160,12 +166,24 @@ def test_ad_rnea( joint_accelerations=s̈, link_forces=W_f_L, standard_gravity=g, + joint_transforms=i_X_λ, + motion_subspaces=S, ) # Check derivatives against finite differences. check_grads( f=rnea, - args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, W_f_L, g), + args=( + W_p_B, + W_Q_B, + s, + W_v_WB, + ṡ, + W_v̇_WB, + s̈, + W_f_L, + g, + ), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, @@ -186,13 +204,20 @@ def test_ad_crba( # State in VelRepr.Inertial representation. s = data.joint_positions(model=model) + i_X_λ = data.kyn_dyn.joint_transforms + S = data.kyn_dyn.motion_subspaces # ==== # Test # ==== # Get a closure exposing only the parameters to be differentiated. - crba = lambda s: jaxsim.rbda.crba(model=model, joint_positions=s) + crba = lambda s: jaxsim.rbda.crba( + model=model, + joint_positions=s, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) # Check derivatives against finite differences. check_grads( @@ -217,26 +242,28 @@ def test_ad_fk( ) # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) + i_X_λ = data.kyn_dyn.joint_transforms # ==== # Test # ==== # Get a closure exposing only the parameters to be differentiated. - fk = lambda W_p_B, W_Q_B, s: jaxsim.rbda.forward_kinematics_model( + fk = lambda W_p_B, W_Q_B: jaxsim.rbda.forward_kinematics_model( model=model, base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, + joint_transforms=i_X_λ, ) # Check derivatives against finite differences. check_grads( f=fk, - args=(W_p_B, W_Q_B, s), + args=(W_p_B, W_Q_B), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, @@ -257,6 +284,8 @@ def test_ad_jacobian( # State in VelRepr.Inertial representation. s = data.joint_positions(model=model) + i_X_λ = data.kyn_dyn.joint_transforms + S = data.kyn_dyn.motion_subspaces # ==== # Test @@ -269,7 +298,11 @@ def test_ad_jacobian( # We differentiate the jacobian of the last link, likely among those # farther from the base. jacobian = lambda s: jaxsim.rbda.jacobian( - model=model, joint_positions=s, link_index=link_indices[-1] + model=model, + joint_positions=s, + link_index=link_indices[-1], + joint_transforms=i_X_λ, + motion_subspaces=S, ) # Check derivatives against finite differences. @@ -344,7 +377,7 @@ def test_ad_integration( ) # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() @@ -389,6 +422,9 @@ def step( ), ) + # Update the kyn_dyn cache. + data_x0.update_kyn_dyn(model=model) + data_xf, _ = js.model.step( model=model, data=data_x0, @@ -396,7 +432,7 @@ def step( link_forces=W_f_L, ) - xf_W_p_B = data_xf.base_position() + xf_W_p_B = data_xf.base_position xf_W_Q_B = data_xf.base_orientation(dcm=False) xf_s = data_xf.joint_positions(model=model) xf_W_v_WB = data_xf.base_velocity() diff --git a/tests/test_pytree.py b/tests/test_pytree.py index c27179254..c3bf82576 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -51,7 +51,7 @@ def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData): # Return random elements from model and data, just to have something returned. return ( jnp.sum(model.kin_dyn_parameters.link_parameters.mass), - data.base_position(), + data.base_position, ) data1 = js.data.JaxSimModelData.build(model=model1) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 99f90d899..ae6b4f848 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -78,7 +78,7 @@ def test_box_with_external_forces( ) # Check that the box didn't move. - assert data.base_position() == pytest.approx(data0.base_position()) + assert data.base_position == pytest.approx(data0.base_position) assert data.base_orientation() == pytest.approx(data0.base_orientation()) @@ -158,8 +158,8 @@ def test_box_with_zero_gravity( ) # Check that the box moved as expected. - assert data.base_position() == pytest.approx( - data0.base_position() + assert data.base_position == pytest.approx( + data0.base_position + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, abs=1e-3, ) @@ -247,8 +247,8 @@ def test_simulation_with_soft_contacts( data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) - assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) - assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) def test_simulation_with_visco_elastic_contacts( @@ -287,8 +287,8 @@ def test_simulation_with_visco_elastic_contacts( data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) - assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) - assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) def test_simulation_with_rigid_contacts( @@ -339,8 +339,8 @@ def test_simulation_with_rigid_contacts( data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) - assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) - assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) def test_simulation_with_relaxed_rigid_contacts( @@ -393,10 +393,10 @@ def test_simulation_with_relaxed_rigid_contacts( data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) # With this contact model, we need to slightly increase the tolerances. - assert data_tf.base_position()[0:2] == pytest.approx( - data_t0.base_position()[0:2], abs=0.000_010 + assert data_tf.base_position[0:2] == pytest.approx( + data_t0.base_position[0:2], abs=0.000_010 ) - assert data_tf.base_position()[2] + max_penetration == pytest.approx( + assert data_tf.base_position[2] + max_penetration == pytest.approx( box_height / 2, abs=0.000_100 ) @@ -438,6 +438,9 @@ def test_joint_limits( # Test minimum joint position limits. data_t0 = data.reset_joint_positions(positions=position_limits_min - theta) + # Update the kyn_dyn cache. + data_t0 = data_t0.update_kyn_dyn(model=model) + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.005, tf=3.0) assert (