From a2dff84de78fd9957be543e3a4e731c2b707dfdf Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 16 Dec 2024 15:11:03 +0100 Subject: [PATCH] [WIP] Save some kindyn computation in `JaxSimModelData` --- src/jaxsim/api/contact.py | 3 --- src/jaxsim/api/link.py | 2 +- src/jaxsim/api/model.py | 8 +++++--- src/jaxsim/integrators/common.py | 8 ++++---- src/jaxsim/rbda/contacts/relaxed_rigid.py | 2 +- tests/test_simulations.py | 12 ------------ 6 files changed, 11 insertions(+), 24 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 11ec6dc76..bb482fe2e 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -666,13 +666,11 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: W_J_WL_W = js.model.generalized_free_floating_jacobian( model=model, data=data, - output_vel_repr=VelRepr.Inertial, ) # Compute the Jacobian derivative of the parent link in inertial representation. W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( model=model, data=data, - output_vel_repr=VelRepr.Inertial, ) # Get the Jacobian of the enabled collidable points in the mixed representation. @@ -680,7 +678,6 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: CW_J_WC_BW = jacobian( model=model, data=data, - output_vel_repr=VelRepr.Mixed, ) def compute_O_J̇_WC_I( diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 2719da5d1..ca7a2ded3 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -423,7 +423,7 @@ def jacobian_derivative( ) O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative( - model=model, data=data, output_vel_repr=output_vel_repr + model=model, data=data )[link_index] return O_J̇_WL_I diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index bf8d1cdc1..b86a221f1 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -271,9 +271,9 @@ def build( integrator_cls = integrator integrator = integrator_cls.build( - dynamics=js.ode.wrap_system_dynamics_for_integration( - system_dynamics=js.ode.system_dynamics - ) + # dynamics=js.ode.wrap_system_dynamics_for_integration( + # system_dynamics=js.ode.system_dynamics + # ) ) case _: @@ -2218,6 +2218,8 @@ def forward( data_tf = data.replace(state=state_tf) + data_tf = data_tf.update_kyn_dyn(model=model) + return data_tf diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 912858fd9..482b81306 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -342,11 +342,11 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: ti = t0 + c[i] * Δt # Evaluate the dynamics. - ki, aux_dict = f(x=xi, t=ti) - return ki, aux_dict + ki = f(x=xi, t=ti) + return ki # This selector enables FSAL property in the first iteration (i=0). - ki, aux_dict = jax.lax.cond( + ki = jax.lax.cond( pred=jnp.logical_and(i == 0, self.has_fsal), true_fun=lambda: x0, false_fun=compute_ki, @@ -357,7 +357,7 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: K = jax.tree.map(op, K, ki) carry = K - return carry, aux_dict + return carry, None # Compute the state derivatives kᵢ. K, _ = jax.lax.scan( diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index fd84c6941..642a198e2 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -320,7 +320,7 @@ def compute_contact_forces( ) ) - M = js.model.free_floating_mass_matrix(model=model, data=data) + M = data.kyn_dyn.mass_matrix Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 12026c6a7..2b9aa9ede 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -201,19 +201,8 @@ def run_simulation( return data -@pytest.mark.parametrize( - "integrator", - [ - jaxsim.integrators.fixed_step.ForwardEuler, - jaxsim.integrators.fixed_step.ForwardEulerSO3, - jaxsim.integrators.fixed_step.RungeKutta4, - jaxsim.integrators.fixed_step.RungeKutta4SO3, - jaxsim.integrators.variable_step.BogackiShampineSO3, - ], -) def test_simulation_with_soft_contacts( jaxsim_model_box: js.model.JaxSimModel, - integrator: jaxsim.integrators.Integrator, ): model = jaxsim_model_box @@ -229,7 +218,6 @@ def test_simulation_with_soft_contacts( model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) - model.integrator = integrator.build() assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4