From 803d68bf92384d9389865402d6066ac1ea768e8b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 16 Dec 2024 14:18:17 +0100 Subject: [PATCH] Add function to safely compute the norm --- src/jaxsim/api/data.py | 12 ++++-------- src/jaxsim/math/__init__.py | 1 + src/jaxsim/math/quaternion.py | 11 ++++------- src/jaxsim/math/rotation.py | 17 ++++------------- src/jaxsim/math/utils.py | 19 +++++++++++++++++++ src/jaxsim/rbda/contacts/soft.py | 11 ++++------- src/jaxsim/terrain/terrain.py | 3 ++- 7 files changed, 38 insertions(+), 36 deletions(-) create mode 100644 src/jaxsim/math/utils.py diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index b880547b9..af3ea7ef2 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -382,9 +382,8 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix # we introduce a Baumgarte stabilization to let the quaternion converge to # a unit quaternion. In this case, it is not guaranteed that the quaternion # stored in the state is a unit quaternion. - W_Q_B = jnp.where( - jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B) - ) + norm = jaxsim.math.safe_norm(W_Q_B) + W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype( float @@ -611,11 +610,8 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self: W_Q_B = jnp.array(base_quaternion, dtype=float) - W_Q_B = jax.lax.select( - pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0), - on_true=W_Q_B, - on_false=W_Q_B / jnp.linalg.norm(W_Q_B), - ) + norm = jaxsim.math.safe_norm(W_Q_B) + W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) return self.replace( validate=True, diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py index 2e7b9c352..008a94630 100644 --- a/src/jaxsim/math/__init__.py +++ b/src/jaxsim/math/__init__.py @@ -8,5 +8,6 @@ from .rotation import Rotation from .skew import Skew from .transform import Transform +from .utils import safe_norm from .joint_model import JointModel, supported_joint_motion # isort:skip diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index e9115cb26..4870f1aa0 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -4,6 +4,8 @@ import jaxsim.typing as jtp +from .utils import safe_norm + class Quaternion: @staticmethod @@ -111,18 +113,13 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix: operand=quaternion, ) - norm_ω = jax.lax.cond( - pred=ω.dot(ω) < (1e-6) ** 2, - true_fun=lambda _: 1e-6, - false_fun=lambda _: jnp.linalg.norm(ω), - operand=None, - ) + norm_ω = safe_norm(ω) qd = 0.5 * ( Q @ jnp.hstack( [ - K * norm_ω * (1 - jnp.linalg.norm(quaternion)), + K * norm_ω * (1 - safe_norm(quaternion)), ω, ] ) diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index f445e1d74..471f496b8 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -4,6 +4,7 @@ import jaxsim.typing as jtp from .skew import Skew +from .utils import safe_norm class Rotation: @@ -67,7 +68,7 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix: def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix: v = axis - theta = jnp.linalg.norm(v) + theta = safe_norm(v) s = jnp.sin(theta) c = jnp.cos(theta) @@ -81,19 +82,9 @@ def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix: return R.transpose() - # Use the double-where trick to prevent JAX problems when the - # jax.jit and jax.grad transforms are applied. return jnp.where( - jnp.linalg.norm(vector) > 0, - theta_is_not_zero( - axis=jnp.where( - jnp.linalg.norm(vector) > 0, - vector, - # The following line is a workaround to prevent division by 0. - # Considering the outer where, this branch is never executed. - jnp.ones(3), - ) - ), + jnp.allclose(vector, 0.0), # Return an identity rotation matrix when the input vector is zero. jnp.eye(3), + theta_is_not_zero(axis=vector), ) diff --git a/src/jaxsim/math/utils.py b/src/jaxsim/math/utils.py new file mode 100644 index 000000000..46a9f02e2 --- /dev/null +++ b/src/jaxsim/math/utils.py @@ -0,0 +1,19 @@ +import jax.numpy as jnp + +import jaxsim.typing as jtp + + +def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array: + """ + Provides a calculation for an array norm so that it is safe + to compute the gradient and the NaNs are handled. + + Args: + array: The array for which to compute the norm + axis: The axis for which to compute the norm + """ + is_zero = jnp.allclose(array, 0.0) + array = jnp.where(is_zero, jnp.ones_like(array), array) + + norm = jnp.linalg.norm(array, axis=axis) + return jnp.where(is_zero, 0.0, norm) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index be27fc9da..8d4c0d545 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -309,19 +309,16 @@ def hunt_crossley_contact_model( # Compute the direction of the tangential force. # To prevent dividing by zero, we use a switch statement. - # The ε, instead, is needed to make AD happy. - f_tangential_direction = jnp.where( - f_tangential.dot(f_tangential) != 0, - f_tangential / jnp.linalg.norm(f_tangential + ε), - jnp.zeros(3), + norm = jaxsim.math.safe_norm(f_tangential) + f_tangential_direction = f_tangential / ( + norm + jnp.finfo(float).eps * (norm == 0) ) # Project the tangential force to the friction cone if slipping. f_tangential = jnp.where( sticking, f_tangential, - jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε)) - * f_tangential_direction, + jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction, ) # Set the tangential force to zero if there is no contact. diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index 9b2316425..f6b4ddcc2 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -7,6 +7,7 @@ import jax_dataclasses import numpy as np +import jaxsim.math import jaxsim.typing as jtp from jaxsim import exceptions @@ -41,7 +42,7 @@ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: [(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0] ) - return n / jnp.linalg.norm(n) + return n / jaxsim.math.safe_norm(n) @jax_dataclasses.pytree_dataclass