Skip to content

Commit

Permalink
Use jaxsim.exceptions module to handle dynamic checks
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Aug 22, 2024
1 parent 9c8162f commit dc175c9
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 105 deletions.
148 changes: 53 additions & 95 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,48 +187,36 @@ def link_forces(
# serialization.
if model is None:

def inertial() -> jtp.Array:
if link_names is not None:
raise ValueError("Link names cannot be provided without a model")

return self.input.physics_model.f_ext

return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=inertial,
false_fun=lambda: jax.pure_callback(
callback=lambda: (_ for _ in ()).throw(
ValueError(
"Missing model to use a representation different from `VelRepr.Inertial`"
)
),
result_shape_dtypes=self.input.physics_model.f_ext,
),
exceptions.raise_value_error_if(
condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial),
msg="Missing model to use a representation different from `VelRepr.Inertial`",
)

if link_names is not None:
raise ValueError("Link names cannot be provided without a model")

return self.input.physics_model.f_ext

# If we have the model, we can extract the link names, if not provided.
link_names = link_names if link_names is not None else model.link_names()
link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

def check_not_inertial() -> None:
if data is None:
raise ValueError(
"Missing model data to use a representation different from `VelRepr.Inertial`"
)

if not_tracing(self.input.physics_model.f_ext) and not data.valid(
model=model
):
raise ValueError("The provided data is not valid for the model")

# If not inertial-fixed representation, we need the model data.
jax.lax.cond(
pred=(self.velocity_representation != VelRepr.Inertial),
true_fun=lambda: jax.pure_callback(
callback=check_not_inertial,
result_shape_dtypes=None,
exceptions.raise_value_error_if(
condition=jnp.logical_and(
jnp.not_equal(self.velocity_representation, VelRepr.Inertial),
data is None,
),
false_fun=lambda: None,
msg="Missing model data to use a representation different from `VelRepr.Inertial`",
)

# The f_L output is either L_f_L or LW_f_L, depending on the representation.
exceptions.raise_value_error_if(
condition=jnp.logical_and(
jnp.not_equal(self.velocity_representation, VelRepr.Inertial),
data is None,
),
msg="Missing model data to use a representation different from `VelRepr.Inertial`",
)

def not_inertial(velocity_representation: jtp.VelRepr) -> jtp.Matrix:
Expand All @@ -245,30 +233,16 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:
)
)(W_f_L, W_H_L)

# The f_L output is either L_f_L or LW_f_L, depending on the representation.
W_H_L = jax.lax.cond(
pred=(data is None),
true_fun=lambda: jax.pure_callback(
callback=lambda: (_ for _ in ()).throw(
ValueError(
"Missing model data to use a representation different from `VelRepr.Inertial`"
)
),
result_shape_dtypes=jnp.empty(
shape=(model.number_of_links(), 4, 4)
),
),
false_fun=lambda: js.model.forward_kinematics(
model=model, data=data or JaxSimModelData.zero(model=model)
),
W_H_L = js.model.forward_kinematics(
model=model, data=data or JaxSimModelData.zero(model=model)
)
f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])

return f_L

# In inertial-fixed representation, we already have the link forces.
return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
pred=jnp.equal(self.velocity_representation, VelRepr.Inertial),
true_fun=lambda _: W_f_L[link_idxs, :],
false_fun=not_inertial,
operand=self.velocity_representation,
Expand Down Expand Up @@ -418,35 +392,28 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:

# In this case, we allow only to set the inertial 6D forces to all links
# using the implicit link serialization.
if model is None:

def inertial() -> JaxSimModelReferences:
if link_names is not None:
raise ValueError("Link names cannot be provided without a model")
exceptions.raise_value_error_if(
condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial)
& (model is None),
msg="Missing model to use a representation different from `VelRepr.Inertial`",
)

W_f_L = f_L
exceptions.raise_value_error_if(
condition=jnp.logical_and(link_names is not None, model is None),
msg="Link names cannot be provided without a model",
)

W_f0_L = (
jnp.zeros_like(W_f_L)
if not additive
else self.input.physics_model.f_ext
)
if model is None:
W_f_L = f_L

return replace(forces=W_f0_L + W_f_L)

return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=inertial,
false_fun=lambda: jax.pure_callback(
callback=lambda: (_ for _ in ()).throw(
ValueError(
"Missing model to use a representation different from `VelRepr.Inertial`"
)
),
result_shape_dtypes=self,
),
W_f0_L = (
jnp.zeros_like(W_f_L)
if not additive
else self.input.physics_model.f_ext
)

return replace(forces=W_f0_L + W_f_L)

# If we have the model, we can extract the link names if not provided.
link_names = link_names if link_names is not None else model.link_names()

Expand All @@ -467,6 +434,14 @@ def inertial() -> JaxSimModelReferences:
else self.input.physics_model.f_ext[link_idxs, :]
)

exceptions.raise_value_error_if(
condition=jnp.logical_and(
jnp.not_equal(self.velocity_representation, VelRepr.Inertial),
data is None,
),
msg="Missing model data to use a representation different from `VelRepr.Inertial`",
)

# If inertial-fixed representation, we can directly store the link forces.
def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences:
W_f_L = f_L
Expand All @@ -492,28 +467,11 @@ def convert_using_link_frame(
)
)(f_L, W_H_L)

# If not inertial-fixed representation, we need the model data.
W_H_L = jax.lax.cond(
pred=(data is None),
true_fun=lambda: jax.pure_callback(
callback=lambda: (_ for _ in ()).throw(
ValueError(
"Missing model data to use a representation different from `VelRepr.Inertial`"
)
),
result_shape_dtypes=jnp.empty(
shape=(model.number_of_links(), 4, 4)
),
),
false_fun=lambda: js.model.forward_kinematics(
model=model, data=data or JaxSimModelData.zero(model=model)
),
W_H_L = js.model.forward_kinematics(
model=model, data=data or JaxSimModelData.zero(model=model)
)

# The f_L input is either L_f_L or LW_f_L, depending on the representation.
# W_H_L = js.model.forward_kinematics(
# model=model, data=data or JaxSimModelData.zero(model=model)
# )
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])

return replace(
Expand All @@ -523,7 +481,7 @@ def convert_using_link_frame(
)

return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
pred=jnp.equal(self.velocity_representation, VelRepr.Inertial),
true_fun=inertial,
false_fun=not_inertial,
operand=self.velocity_representation,
Expand Down
11 changes: 1 addition & 10 deletions tests/test_api_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_raise_errors_apply_link_forces(

# `model` is None
with pytest.raises(
ValueError,
XlaRuntimeError,
match="Link names cannot be provided without a model",
):
references_inertial.apply_link_forces(
Expand All @@ -124,15 +124,6 @@ def test_raise_errors_apply_link_forces(
model=model, data=data, velocity_representation=VelRepr.Body, key=subkey2
)

# `model` is None
with pytest.raises(
ValueError,
match="Link names cannot be provided without a model",
):
references_body.apply_link_forces(
forces=jnp.zeros(6), model=None, data=None, link_names=model.link_names()
)

# `model` is None
with pytest.raises(
XlaRuntimeError,
Expand Down

0 comments on commit dc175c9

Please sign in to comment.