diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 3c16f3df6..a9d9d8aa2 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -187,48 +187,36 @@ def link_forces( # serialization. if model is None: - def inertial() -> jtp.Array: - if link_names is not None: - raise ValueError("Link names cannot be provided without a model") - - return self.input.physics_model.f_ext - - return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=inertial, - false_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=self.input.physics_model.f_ext, - ), + exceptions.raise_value_error_if( + condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + msg="Missing model to use a representation different from `VelRepr.Inertial`", ) + if link_names is not None: + raise ValueError("Link names cannot be provided without a model") + + return self.input.physics_model.f_ext + # If we have the model, we can extract the link names, if not provided. link_names = link_names if link_names is not None else model.link_names() link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) - def check_not_inertial() -> None: - if data is None: - raise ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - - if not_tracing(self.input.physics_model.f_ext) and not data.valid( - model=model - ): - raise ValueError("The provided data is not valid for the model") - # If not inertial-fixed representation, we need the model data. - jax.lax.cond( - pred=(self.velocity_representation != VelRepr.Inertial), - true_fun=lambda: jax.pure_callback( - callback=check_not_inertial, - result_shape_dtypes=None, + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, ), - false_fun=lambda: None, + msg="Missing model data to use a representation different from `VelRepr.Inertial`", + ) + + # The f_L output is either L_f_L or LW_f_L, depending on the representation. + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, + ), + msg="Missing model data to use a representation different from `VelRepr.Inertial`", ) def not_inertial(velocity_representation: jtp.VelRepr) -> jtp.Matrix: @@ -245,22 +233,8 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: ) )(W_f_L, W_H_L) - # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = jax.lax.cond( - pred=(data is None), - true_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=jnp.empty( - shape=(model.number_of_links(), 4, 4) - ), - ), - false_fun=lambda: js.model.forward_kinematics( - model=model, data=data or JaxSimModelData.zero(model=model) - ), + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) ) f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) @@ -268,7 +242,7 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: # In inertial-fixed representation, we already have the link forces. return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), + pred=jnp.equal(self.velocity_representation, VelRepr.Inertial), true_fun=lambda _: W_f_L[link_idxs, :], false_fun=not_inertial, operand=self.velocity_representation, @@ -418,35 +392,28 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # In this case, we allow only to set the inertial 6D forces to all links # using the implicit link serialization. - if model is None: - - def inertial() -> JaxSimModelReferences: - if link_names is not None: - raise ValueError("Link names cannot be provided without a model") + exceptions.raise_value_error_if( + condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial) + & (model is None), + msg="Missing model to use a representation different from `VelRepr.Inertial`", + ) - W_f_L = f_L + exceptions.raise_value_error_if( + condition=jnp.logical_and(link_names is not None, model is None), + msg="Link names cannot be provided without a model", + ) - W_f0_L = ( - jnp.zeros_like(W_f_L) - if not additive - else self.input.physics_model.f_ext - ) + if model is None: + W_f_L = f_L - return replace(forces=W_f0_L + W_f_L) - - return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=inertial, - false_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=self, - ), + W_f0_L = ( + jnp.zeros_like(W_f_L) + if not additive + else self.input.physics_model.f_ext ) + return replace(forces=W_f0_L + W_f_L) + # If we have the model, we can extract the link names if not provided. link_names = link_names if link_names is not None else model.link_names() @@ -467,6 +434,14 @@ def inertial() -> JaxSimModelReferences: else self.input.physics_model.f_ext[link_idxs, :] ) + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, + ), + msg="Missing model data to use a representation different from `VelRepr.Inertial`", + ) + # If inertial-fixed representation, we can directly store the link forces. def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: W_f_L = f_L @@ -492,28 +467,11 @@ def convert_using_link_frame( ) )(f_L, W_H_L) - # If not inertial-fixed representation, we need the model data. - W_H_L = jax.lax.cond( - pred=(data is None), - true_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=jnp.empty( - shape=(model.number_of_links(), 4, 4) - ), - ), - false_fun=lambda: js.model.forward_kinematics( - model=model, data=data or JaxSimModelData.zero(model=model) - ), + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) ) # The f_L input is either L_f_L or LW_f_L, depending on the representation. - # W_H_L = js.model.forward_kinematics( - # model=model, data=data or JaxSimModelData.zero(model=model) - # ) W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace( @@ -523,7 +481,7 @@ def convert_using_link_frame( ) return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), + pred=jnp.equal(self.velocity_representation, VelRepr.Inertial), true_fun=inertial, false_fun=not_inertial, operand=self.velocity_representation, diff --git a/tests/test_api_references.py b/tests/test_api_references.py index b656d2719..7bc35823e 100644 --- a/tests/test_api_references.py +++ b/tests/test_api_references.py @@ -109,7 +109,7 @@ def test_raise_errors_apply_link_forces( # `model` is None with pytest.raises( - ValueError, + XlaRuntimeError, match="Link names cannot be provided without a model", ): references_inertial.apply_link_forces( @@ -124,15 +124,6 @@ def test_raise_errors_apply_link_forces( model=model, data=data, velocity_representation=VelRepr.Body, key=subkey2 ) - # `model` is None - with pytest.raises( - ValueError, - match="Link names cannot be provided without a model", - ): - references_body.apply_link_forces( - forces=jnp.zeros(6), model=None, data=None, link_names=model.link_names() - ) - # `model` is None with pytest.raises( XlaRuntimeError,