Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor velocity representations as integers #160

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a8d0c5b
Avoid to use `enum` in `VelRepr`
flferretti May 23, 2024
acc70ab
Make `velocity_representation` non-static in `api.data`
flferretti May 23, 2024
120b474
Update documentation for `VelRepr`
flferretti May 23, 2024
d346f2e
Make `output_vel_repr` non-static in `api.frame`
flferretti May 23, 2024
e299b59
Make `output_vel_repr` non-static in `api.link`
flferretti May 23, 2024
b4b359d
Use `jax.lax.switch` inside `api.frame` for `VelRepr`
flferretti Jun 4, 2024
9cc3d3c
Use `jax.lax.switch` inside `api.link` for `VelRepr`
flferretti Jun 4, 2024
c545bd7
Use `jax.lax.switch` inside `api.com` for `VelRepr`
flferretti Jun 4, 2024
d96d98d
Make `VelRepr` non-static in `ModelDatawithVelocityRepresentation`
flferretti Jun 4, 2024
41a3f89
Make `output_vel_repr` non-static in `api.common`
flferretti Jun 4, 2024
baf8f6f
Make `output_vel_repr` non-static in `api.model`
flferretti Jun 4, 2024
8230b3e
Use `jax.lax.switch` inside `api.model` for `VelRepr`
flferretti Jun 4, 2024
1c32639
Update error messages in `api.references`
flferretti Jun 4, 2024
0a2d9c3
Use `jax.lax.switch` in `api.contact` for `VelRepr`
flferretti Jun 4, 2024
c3940f3
Fix dimensions in `api.com.bias_acceleration`
flferretti Jun 4, 2024
6c59458
Use equality instead of identity operator for comparing `VelRepr`
flferretti Jun 4, 2024
6163ad9
Update type-hints to `int` for `VelRepr`
flferretti Jun 4, 2024
42cd69e
Use `jax.lax.cond` to check the velocity representation
flferretti Jun 5, 2024
12a6f8d
Use `jax.pure_callback` to throw errors
flferretti Jun 5, 2024
0bc1d76
Add additional type checks
flferretti Jun 6, 2024
68e2f8e
Fix base acceleration transform to `VelRepr.Mixed`
flferretti May 24, 2024
0da3a62
Define `jaxsim.typing.VelRepr` as `Int`
flferretti Jun 14, 2024
06dd88e
Update JIT checks in `JaxSimModelReferences`
flferretti Jun 14, 2024
7543660
Update `result_shape_dtypes` in `JaxSimModelReferences`
flferretti Jun 17, 2024
0aae8fc
Use `jax.lax.switch` for Jacobian derivative computation
flferretti Jul 31, 2024
f307618
Adjust output names and return type hints
flferretti Aug 1, 2024
d580622
Use `jaxsim.exceptions` module to handle dynamic checks
flferretti Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/modules/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ References
Common
~~~~~~

.. autoflag:: jaxsim.api.common.VelRepr
.. autoclass:: jaxsim.api.common.VelRepr
:members:

.. autoclass:: jaxsim.api.common.ModelDataWithVelocityRepresentation
Expand Down
248 changes: 138 additions & 110 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Cross, Transform

from .common import VelRepr

Expand All @@ -27,7 +27,7 @@ def com_position(

W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_B = data.base_transform()
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)
B_H_W = Transform.inverse(transform=W_H_B)

def B_p̃_LCoM(i) -> jtp.Vector:
m = js.link.mass(model=model, link_index=i)
Expand Down Expand Up @@ -131,20 +131,29 @@ def centroidal_momentum_jacobian(
)

W_H_B = data.base_transform()
B_H_W = jaxsim.math.Transform.inverse(W_H_B)
B_H_W = Transform.inverse(W_H_B)

W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
case VelRepr.Body:
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
case _:
raise ValueError(data.velocity_representation)
def to_inertial_and_mixed():
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
return W_H_GW

def to_body():
W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
return W_H_GB

W_H_G = jax.lax.switch(
index=data.velocity_representation,
branches=(
to_body, # VelRepr.Body
to_inertial_and_mixed, # VelRepr.Mixed
to_inertial_and_mixed, # VelRepr.Inertial
),
)

# Compute the transform for 6D forces.
G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T
G_Xf_B = Adjoint.from_transform(transform=B_H_W @ W_H_G).T

return G_Xf_B @ B_Jh

Expand All @@ -170,17 +179,26 @@ def locked_centroidal_spatial_inertia(
W_H_B = data.base_transform()
W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
case VelRepr.Body:
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
case _:
raise ValueError(data.velocity_representation)
def to_inertial_or_mixed() -> jtp.Matrix:
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
return W_H_GW

def to_body() -> jtp.Matrix:
W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
return W_H_GB

W_H_G = jax.lax.switch(
index=data.velocity_representation,
branches=(
to_body, # VelRepr.Body
to_inertial_or_mixed, # VelRepr.Mixed
to_inertial_or_mixed, # VelRepr.Inertial
),
)

B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G
B_H_G = Transform.inverse(W_H_B) @ W_H_G

B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)
B_Xv_G = Adjoint.from_transform(transform=B_H_G)
G_Xf_B = B_Xv_G.transpose()

return G_Xf_B @ B_Mbb_B @ B_Xv_G
Expand Down Expand Up @@ -275,80 +293,86 @@ def other_representation_to_body(
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
"""

L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C)
C_X_L = jaxsim.math.Adjoint.inverse(L_X_C)
L_X_C = Adjoint.from_transform(transform=L_H_C)
C_X_L = Adjoint.inverse(L_X_C)

L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)
L_v̇_WL = L_X_C @ (C_v̇_WL + Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)
return L_v̇_WL

def to_body() -> jtp.Vector:
L_a_bias_WL = v̇_bias_WL

return L_a_bias_WL

def to_inertial() -> jtp.Vector:

W_v̇_bias_WL = v̇_bias_WL
W_v_WW = jnp.zeros(6)

L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))(W_H_L)

L_v_LW = jax.vmap(
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
)(jnp.arange(model.number_of_links()))

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=W_v̇_bias_WL[i],
C_v_WC=W_v_WW,
L_H_C=L_H_W[i],
L_v_LC=L_v_LW[i],
)
)(jnp.arange(model.number_of_links()))

return L_a_bias_WL

def to_mixed() -> jtp.Vector:

LW_v̇_bias_WL = v̇_bias_WL

LW_v_W_LW = jax.vmap(
lambda i: js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
)
.at[3:6]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_H_LW = jax.vmap(
lambda W_H_L: Transform.inverse(W_H_L.at[0:3, 3].set(jnp.zeros(3)))
)(W_H_L)

L_v_L_LW = jax.vmap(
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
.at[0:3]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=LW_v̇_bias_WL[i],
C_v_WC=LW_v_W_LW[i],
L_H_C=L_H_LW[i],
L_v_LC=L_v_L_LW[i],
)
)(jnp.arange(model.number_of_links()))

return L_a_bias_WL

# We need here to get the body-fixed bias acceleration of the links.
# Since it's computed in the active representation, we need to convert it to body.
match data.velocity_representation:

case VelRepr.Body:
L_a_bias_WL = v̇_bias_WL

case VelRepr.Inertial:

C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841

L_H_C = L_H_W = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
)(W_H_L)

L_v_LC = L_v_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
)(jnp.arange(model.number_of_links()))

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC,
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))

case VelRepr.Mixed:

C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841

C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
lambda i: js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
)
.at[3:6]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_H_C = L_H_LW = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(
W_H_L.at[0:3, 3].set(jnp.zeros(3))
)
)(W_H_L)

L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
.at[0:3]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC[i],
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))

case _:
raise ValueError(data.velocity_representation)
L_a_bias_WL = jax.lax.switch(
index=data.velocity_representation,
branches=(
to_body, # VelRepr.Body
to_mixed, # VelRepr.Mixed
to_inertial, # VelRepr.Inertial
),
)

# Compute the bias of the 6D momentum derivative.
def bias_momentum_derivative_term(
Expand All @@ -364,13 +388,11 @@ def bias_momentum_derivative_term(
)

# Compute the world-to-link transformations for 6D forces.
W_Xf_L = jaxsim.math.Adjoint.from_transform(
transform=W_H_L[link_index], inverse=True
).T
W_Xf_L = Adjoint.from_transform(transform=W_H_L[link_index], inverse=True).T

# Compute the contribution of the link to the bias acceleration of the CoM.
W_ḣ_bias_link_contribution = W_Xf_L @ (
L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
L_M_L @ L_a_bias_WL + Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
)

return W_ḣ_bias_link_contribution
Expand All @@ -386,30 +408,36 @@ def bias_momentum_derivative_term(
# Compute the position of the CoM.
W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:

def to_inertial_or_mixed() -> jtp.Vector:
# G := G[W] = (W_p_CoM, [W])
case VelRepr.Inertial | VelRepr.Mixed:

W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_Xf_W = Adjoint.from_transform(W_H_GW).T

GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m
GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m

return GW_v̇l_com_bias
return GW_v̇l_com_bias

def to_body() -> jtp.Vector:
# G := G[B] = (W_p_CoM, [B])
case VelRepr.Body:

GB_Xf_W = jaxsim.math.Adjoint.from_transform(
transform=data.base_transform().at[0:3].set(W_p_CoM)
).T
GB_Xf_W = Adjoint.from_transform(
transform=data.base_transform().at[0:3, 3].set(W_p_CoM)
).T

GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m
GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m

return GB_v̇l_com_bias
return GB_v̇l_com_bias

GB_v̇l_com_bias = jax.lax.switch(
index=data.velocity_representation,
branches=(
to_body, # VelRepr.Body
to_inertial_or_mixed, # VelRepr.Mixed
to_inertial_or_mixed, # VelRepr.Inertial
),
)

case _:
raise ValueError(data.velocity_representation)
return GB_v̇l_com_bias
Loading
Loading