Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add safe norm function and refactor usages #319

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,8 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix
# we introduce a Baumgarte stabilization to let the quaternion converge to
# a unit quaternion. In this case, it is not guaranteed that the quaternion
# stored in the state is a unit quaternion.
W_Q_B = jnp.where(
jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B)
)
norm = jaxsim.math.safe_norm(W_Q_B)
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
float
Expand Down Expand Up @@ -611,11 +610,8 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:

W_Q_B = jnp.array(base_quaternion, dtype=float)

W_Q_B = jax.lax.select(
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
on_true=W_Q_B,
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
)
norm = jaxsim.math.safe_norm(W_Q_B)
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

return self.replace(
validate=True,
Expand Down
1 change: 1 addition & 0 deletions src/jaxsim/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from .rotation import Rotation
from .skew import Skew
from .transform import Transform
from .utils import safe_norm

from .joint_model import JointModel, supported_joint_motion # isort:skip
11 changes: 4 additions & 7 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import jaxsim.typing as jtp

from .utils import safe_norm


class Quaternion:
@staticmethod
Expand Down Expand Up @@ -111,18 +113,13 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix:
operand=quaternion,
)

norm_ω = jax.lax.cond(
pred=ω.dot(ω) < (1e-6) ** 2,
true_fun=lambda _: 1e-6,
false_fun=lambda _: jnp.linalg.norm(ω),
operand=None,
)
norm_ω = safe_norm(ω)

qd = 0.5 * (
Q
@ jnp.hstack(
[
K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
K * norm_ω * (1 - safe_norm(quaternion)),
ω,
]
)
Expand Down
17 changes: 4 additions & 13 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jaxsim.typing as jtp

from .skew import Skew
from .utils import safe_norm


class Rotation:
Expand Down Expand Up @@ -67,7 +68,7 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

v = axis
theta = jnp.linalg.norm(v)
theta = safe_norm(v)

s = jnp.sin(theta)
c = jnp.cos(theta)
Expand All @@ -81,19 +82,9 @@ def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

return R.transpose()

# Use the double-where trick to prevent JAX problems when the
# jax.jit and jax.grad transforms are applied.
return jnp.where(
jnp.linalg.norm(vector) > 0,
theta_is_not_zero(
axis=jnp.where(
jnp.linalg.norm(vector) > 0,
vector,
# The following line is a workaround to prevent division by 0.
# Considering the outer where, this branch is never executed.
jnp.ones(3),
)
),
jnp.allclose(vector, 0.0),
# Return an identity rotation matrix when the input vector is zero.
jnp.eye(3),
theta_is_not_zero(axis=vector),
)
22 changes: 22 additions & 0 deletions src/jaxsim/math/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import jax.numpy as jnp

import jaxsim.typing as jtp


def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
"""
Provides a calculation for an array norm so that it is safe
to compute the gradient and the NaNs are handled.

Args:
array: The array for which to compute the norm.
axis: The axis for which to compute the norm.
"""

is_zero = jnp.allclose(array, 0.0)

array = jnp.where(is_zero, jnp.ones_like(array), array)
xela-95 marked this conversation as resolved.
Show resolved Hide resolved

norm = jnp.linalg.norm(array, axis=axis)

return jnp.where(is_zero, 0.0, norm)
11 changes: 4 additions & 7 deletions src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,16 @@ def hunt_crossley_contact_model(

# Compute the direction of the tangential force.
# To prevent dividing by zero, we use a switch statement.
# The ε, instead, is needed to make AD happy.
f_tangential_direction = jnp.where(
f_tangential.dot(f_tangential) != 0,
f_tangential / jnp.linalg.norm(f_tangential + ε),
jnp.zeros(3),
norm = jaxsim.math.safe_norm(f_tangential)
f_tangential_direction = f_tangential / (
norm + jnp.finfo(float).eps * (norm == 0)
)

# Project the tangential force to the friction cone if slipping.
f_tangential = jnp.where(
sticking,
f_tangential,
jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε))
* f_tangential_direction,
jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
)

# Set the tangential force to zero if there is no contact.
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/terrain/terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax_dataclasses
import numpy as np

import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions

Expand Down Expand Up @@ -41,7 +42,7 @@ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
[(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0]
)

return n / jnp.linalg.norm(n)
return n / jaxsim.math.safe_norm(n)


@jax_dataclasses.pytree_dataclass
Expand Down