Skip to content

Commit

Permalink
Save some kindyn computation in JaxSimModelData
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 23, 2024
1 parent be7f69a commit db0fa1c
Show file tree
Hide file tree
Showing 26 changed files with 591 additions and 248 deletions.
6 changes: 3 additions & 3 deletions examples/jaxsim_as_physics_engine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
" )\n",
")(jnp.vstack(subkeys))\n",
"\n",
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])"
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])"
]
},
{
Expand Down Expand Up @@ -398,7 +398,7 @@
"# This operation is called 'tree transpose' in JAX.\n",
"data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
"\n",
"print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")"
"print(f\"W_p_B: shape={data_trajectory.base_position.shape}\")"
]
},
{
Expand All @@ -412,7 +412,7 @@
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n",
"plt.plot(T, data_trajectory.base_position[:, 0:5, 2])\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
Expand Down
2 changes: 2 additions & 0 deletions examples/jaxsim_for_robot_controllers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@
"# Reset the state to the random joint positions.\n",
"data = data_zero.reset_joint_positions(positions=random_joint_positions)\n",
"\n",
"# Update the kyn_dyn cache.\n",
"data = data.update_kyn_dyn(model=model)\n",
"\n",
"for _ in T:\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def com_position(

m = js.model.total_mass(model=model)

W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.kyn_dyn.forward_kinematics
W_H_B = data.base_transform()
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)

Expand Down Expand Up @@ -269,7 +269,7 @@ def bias_acceleration(
"""

# Compute the pose of all links with forward kinematics.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.kyn_dyn.forward_kinematics

# Compute the bias acceleration of all links by zeroing the generalized velocity
# in the active representation.
Expand Down
30 changes: 30 additions & 0 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,33 @@ def other_representation_to_inertial(

case _:
raise ValueError(other_representation)


def convert_mass_matrix(
M: jtp.Matrix,
base_transform: jtp.Matrix,
dofs: jtp.Int,
velocity_representation: VelRepr,
):
# TODO (flferretti): perform the velocity representation conversion here.
return M


def convert_jacobian(
J: jtp.Matrix,
base_transform: jtp.Matrix,
dofs: jtp.Int,
velocity_representation: VelRepr,
):
# TODO (flferretti): save actual Jacobian instead of full doubly left and perform conversion.
return J


def convert_jacobian_derivative(
Jd: jtp.Matrix,
base_transform: jtp.Matrix,
dofs: jtp.Int,
velocity_representation: VelRepr,
):
# TODO (flferretti): save actual Jacobian derivative instead of full doubly left and perform conversion.
return Jd
15 changes: 4 additions & 11 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,8 @@ def collidable_point_kinematics(
# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
with data.switch_velocity_representation(VelRepr.Inertial):

W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
base_position=data.base_position(),
base_quaternion=data.base_orientation(dcm=False),
joint_positions=data.joint_positions(model=model),
base_linear_velocity=data.base_velocity()[0:3],
base_angular_velocity=data.base_velocity()[3:6],
joint_velocities=data.joint_velocities(model=model),
)
W_p_Ci = data.kyn_dyn.collidable_point_positions
W_ṗ_Ci = data.kyn_dyn.collidable_point_velocities

return W_p_Ci, W_ṗ_Ci

Expand Down Expand Up @@ -460,7 +453,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
)[indices_of_enabled_collidable_points]

# Get the transforms of the parent link of all collidable points.
W_H_L = js.model.forward_kinematics(model=model, data=data)[
W_H_L = data.kyn_dyn.forward_kinematics[
parent_link_idx_of_enabled_collidable_points
]

Expand Down Expand Up @@ -612,7 +605,7 @@ def jacobian_derivative(
]

# Get the transforms of all the parent links.
W_H_Li = js.model.forward_kinematics(model=model, data=data)
W_H_Li = data.kyn_dyn.forward_kinematics

# =====================================================
# Compute quantities to adjust the input representation
Expand Down
Loading

0 comments on commit db0fa1c

Please sign in to comment.