Skip to content

Commit

Permalink
[WIP] Save some kindyn computation in JaxSimModelData
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 17, 2024
1 parent 30472ed commit 9a1cc27
Show file tree
Hide file tree
Showing 24 changed files with 198 additions and 120 deletions.
4 changes: 2 additions & 2 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def com_position(

m = js.model.total_mass(model=model)

W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.kyn_dyn.forward_kinematics
W_H_B = data.base_transform()
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)

Expand Down Expand Up @@ -269,7 +269,7 @@ def bias_acceleration(
"""

# Compute the pose of all links with forward kinematics.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.kyn_dyn.forward_kinematics

# Compute the bias acceleration of all links by zeroing the generalized velocity
# in the active representation.
Expand Down
24 changes: 8 additions & 16 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ def collidable_point_kinematics(

W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
base_position=data.base_position(),
base_position=data.base_position,
base_quaternion=data.base_orientation(dcm=False),
joint_positions=data.joint_positions(model=model),
base_linear_velocity=data.base_velocity()[0:3],
base_angular_velocity=data.base_velocity()[3:6],
joint_velocities=data.joint_velocities(model=model),
joint_transforms=data.kyn_dyn.joint_transforms,
motion_subspaces=data.kyn_dyn.motion_subspaces,
)

return W_p_Ci, W_ṗ_Ci
Expand Down Expand Up @@ -460,7 +462,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
)[indices_of_enabled_collidable_points]

# Get the transforms of the parent link of all collidable points.
W_H_L = js.model.forward_kinematics(model=model, data=data)[
W_H_L = data.kyn_dyn.forward_kinematics[
parent_link_idx_of_enabled_collidable_points
]

Expand Down Expand Up @@ -518,9 +520,7 @@ def jacobian(
)[indices_of_enabled_collidable_points]

# Compute the Jacobians of all links.
W_J_WL = js.model.generalized_free_floating_jacobian(
model=model, data=data, output_vel_repr=VelRepr.Inertial
)
W_J_WL = data.kyn_dyn.jacobian

# Compute the contact Jacobian.
# In inertial-fixed output representation, the Jacobian of the parent link is also
Expand Down Expand Up @@ -612,7 +612,7 @@ def jacobian_derivative(
]

# Get the transforms of all the parent links.
W_H_Li = js.model.forward_kinematics(model=model, data=data)
W_H_Li = data.kyn_dyn.forward_kinematics

# =====================================================
# Compute quantities to adjust the input representation
Expand Down Expand Up @@ -670,17 +670,9 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:

with data.switch_velocity_representation(VelRepr.Inertial):
# Compute the Jacobian of the parent link in inertial representation.
W_J_WL_W = js.model.generalized_free_floating_jacobian(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)
W_J_WL_W = data.kyn_dyn.jacobian
# Compute the Jacobian derivative of the parent link in inertial representation.
W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)
W_J̇_WL_W = data.kyn_dyn.jacobian_derivative

# Get the Jacobian of the enabled collidable points in the mixed representation.
with data.switch_velocity_representation(VelRepr.Mixed):
Expand Down
47 changes: 43 additions & 4 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@
from typing_extensions import Self


@jax_dataclasses.pytree_dataclass
class KynDynComputation:

jacobian: jtp.Matrix

jacobian_derivative: jtp.Matrix

motion_subspaces: jtp.Matrix

joint_transforms: jtp.Matrix

mass_matrix: jtp.Matrix

forward_kinematics: jtp.Matrix


@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
"""
Expand All @@ -34,6 +50,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

state: ODEState

kyn_dyn: KynDynComputation

gravity: jtp.Vector

contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
Expand Down Expand Up @@ -232,11 +250,33 @@ def build(
else:
contacts_params = model.contact_model._parameters_class()

n = model.dofs()
n_fb = n + 6 * model.floating_base()

jacobian = jnp.zeros((model.number_of_links(), 6, n_fb))
jacobian_derivative = jnp.zeros((model.number_of_links(), 6, n_fb))
motion_subspaces = jnp.zeros((model.number_of_links(), 6, 1))
joint_transforms = jnp.zeros((model.number_of_links(), 6, 6))
mass_matrix = jnp.zeros((n_fb, n_fb))
forward_kinematics = jnp.zeros((model.number_of_links(), 4, 4))

kyn_dyn = KynDynComputation(
jacobian=jacobian,
jacobian_derivative=jacobian_derivative,
motion_subspaces=motion_subspaces,
joint_transforms=joint_transforms,
mass_matrix=mass_matrix,
forward_kinematics=forward_kinematics,
)

print(jacobian.shape)

return JaxSimModelData(
state=ode_state,
gravity=gravity,
contacts_params=contacts_params,
velocity_representation=velocity_representation,
kyn_dyn=kyn_dyn,
)

# ==================
Expand Down Expand Up @@ -349,8 +389,7 @@ def joint_velocities(

return self.state.physics_model.joint_velocities[joint_idxs]

@js.common.named_scope
@jax.jit
@property
def base_position(self) -> jtp.Vector:
"""
Get the base position.
Expand All @@ -359,7 +398,7 @@ def base_position(self) -> jtp.Vector:
The base position.
"""

return self.state.physics_model.base_position.squeeze()
return self.state.physics_model.base_position

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["dcm"])
Expand Down Expand Up @@ -400,7 +439,7 @@ def base_transform(self) -> jtp.Matrix:
"""

W_R_B = self.base_orientation(dcm=True)
W_p_B = jnp.vstack(self.base_position())
W_p_B = jnp.vstack(self.base_position)

return jnp.vstack(
[
Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def transform(
idx=link_index,
)

return js.model.forward_kinematics(model=model, data=data)[link_index]
return data.kyn_dyn.forward_kinematics[link_index]


@jax.jit
Expand Down Expand Up @@ -276,6 +276,8 @@ def jacobian(
B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
joint_transforms=data.kyn_dyn.joint_transforms,
motion_subspaces=data.kyn_dyn.motion_subspaces,
)

# Compute the actual doubly-left free-floating jacobian of the link.
Expand Down Expand Up @@ -422,9 +424,7 @@ def jacobian_derivative(
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative(
model=model, data=data, output_vel_repr=output_vel_repr
)[link_index]
O_J̇_WL_I = data.kyn_dyn.jacobian_derivative[link_index]

return O_J̇_WL_I

Expand Down
52 changes: 42 additions & 10 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ def build(

integrator_cls = integrator
integrator = integrator_cls.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
system_dynamics=js.ode.system_dynamics
)
# dynamics=js.ode.wrap_system_dynamics_for_integration(
# system_dynamics=js.ode.system_dynamics
# )
)

case _:
Expand Down Expand Up @@ -574,9 +574,10 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp

W_H_LL = jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=data.base_position(),
base_position=data.base_position,
base_quaternion=data.base_orientation(dcm=False),
joint_positions=data.joint_positions(model=model),
joint_transforms=data.kyn_dyn.joint_transforms,
)

return jnp.atleast_3d(W_H_LL).astype(float)
Expand Down Expand Up @@ -616,6 +617,8 @@ def generalized_free_floating_jacobian(
B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
joint_transforms=data.kyn_dyn.joint_transforms,
motion_subspaces=data.kyn_dyn.motion_subspaces,
)

# ======================================================================
Expand Down Expand Up @@ -743,13 +746,17 @@ def generalized_free_floating_jacobian_derivative(
model=model,
joint_positions=data.joint_positions(),
joint_velocities=data.joint_velocities(),
# joint_transforms=data.kyn_dyn.joint_transforms,
# motion_subspaces=data.kyn_dyn.motion_subspaces,
)

# The derivative of the equation to change the input and output representations
# of the Jacobian derivative needs the computation of the plain link Jacobian.
B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
joint_transforms=data.kyn_dyn.joint_transforms,
motion_subspaces=data.kyn_dyn.motion_subspaces,
)

# Compute the actual doubly-left free-floating jacobian derivative of the link
Expand Down Expand Up @@ -1005,7 +1012,7 @@ def forward_dynamics_aba(

# Extract the state in inertial-fixed representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_B = data.base_position()
W_p_B = data.base_position
W_v_WB = data.base_velocity()
W_Q_B = data.base_orientation(dcm=False)
s = data.joint_positions(model=model, joint_names=joint_names)
Expand All @@ -1031,6 +1038,8 @@ def forward_dynamics_aba(
joint_forces=τ,
link_forces=W_f_L,
standard_gravity=data.standard_gravity(),
joint_transforms=data.kyn_dyn.joint_transforms,
motion_subspaces=data.kyn_dyn.motion_subspaces,
)

# =============
Expand Down Expand Up @@ -1201,6 +1210,8 @@ def free_floating_mass_matrix(
M_body = jaxsim.rbda.crba(
model=model,
joint_positions=data.state.physics_model.joint_positions,
joint_transforms=data.kyn_dyn.joint_transforms,
motion_subspaces=data.kyn_dyn.motion_subspaces,
)

match data.velocity_representation:
Expand Down Expand Up @@ -1457,7 +1468,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):

# Extract the state in inertial-fixed representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_B = data.base_position()
W_p_B = data.base_position
W_v_WB = data.base_velocity()
W_Q_B = data.base_orientation(dcm=False)
s = data.joint_positions(model=model, joint_names=joint_names)
Expand All @@ -1484,6 +1495,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
joint_accelerations=,
link_forces=W_f_L,
standard_gravity=data.standard_gravity(),
joint_transforms=data.kyn_dyn.joint_transforms,
motion_subspaces=data.kyn_dyn.motion_subspaces,
)

# =============
Expand Down Expand Up @@ -1792,7 +1805,7 @@ def average_velocity_jacobian(
case VelRepr.Body:

GB_J = G_J
W_p_B = data.base_position()
W_p_B = data.base_position
W_p_CoM = js.com.com_position(model=model, data=data)
B_R_W = data.base_orientation(dcm=True).transpose()

Expand All @@ -1804,7 +1817,7 @@ def average_velocity_jacobian(
case VelRepr.Mixed:

GW_J = G_J
W_p_B = data.base_position()
W_p_B = data.base_position
W_p_CoM = js.com.com_position(model=model, data=data)

BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
Expand Down Expand Up @@ -2006,11 +2019,11 @@ def body_to_other_representation(
)

case VelRepr.Inertial:
C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
C_H_L = W_H_L = data.kyn_dyn.forward_kinematics
L_v_CL = L_v_WL

case VelRepr.Mixed:
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.kyn_dyn.forward_kinematics
LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
C_H_L = LW_H_L
L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
Expand Down Expand Up @@ -2182,6 +2195,25 @@ def forward(
joint_force_references: jtp.VectorLike | None = None,
) -> js.data.JaxSimModelData:

# Kinematics computation.
M = js.model.free_floating_mass_matrix(model=model, data=data)
J = js.model.generalized_free_floating_jacobian(model=model, data=data)
= js.model.generalized_free_floating_jacobian_derivative(model=model, data=data)
i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=data.joint_positions(), base_transform=data.base_transform()
)
FK = js.model.forward_kinematics(model=model, data=data)
kyn_dyn = js.data.KynDynComputation(
jacobian=J,
jacobian_derivative=,
joint_transforms=i_X_λ,
motion_subspaces=S,
mass_matrix=M,
forward_kinematics=FK,
)

data = data.replace(kyn_dyn=kyn_dyn)

# TODO: some contact models here may want to perform a dynamic filtering of
# the enabled collidable points.

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:
)(W_f_L, W_H_L)

# The f_L output is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.kyn_dyn.forward_kinematics
f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])

return f_L
Expand Down Expand Up @@ -450,7 +450,7 @@ def convert_using_link_frame(
)(f_L, W_H_L)

# The f_L input is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.kyn_dyn.forward_kinematics
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])

return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
Expand Down
Loading

0 comments on commit 9a1cc27

Please sign in to comment.