Skip to content

Commit

Permalink
Merge pull request #132 from ami-iit/update_examples_functional
Browse files Browse the repository at this point in the history
Update notebook examples to the functional API
  • Loading branch information
flferretti authored Apr 3, 2024
2 parents 546e91f + 9ca6ebc commit e917034
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 172 deletions.
204 changes: 92 additions & 112 deletions examples/PD_controller.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"JAXsim offers a simple high-level API in order to extract quantities needed in most robotic applications. "
"JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
"- `integrator`: an object that defines the integration method.\n",
"- `integrator_state`: an object that contains the state of the integrator."
]
},
{
Expand All @@ -77,11 +82,23 @@
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.high_level.model import Model\n",
"import jaxsim.api as js\n",
"from jaxsim import integrators\n",
"\n",
"dt = 0.01\n",
"\n",
"model = Model.build_from_model_description(\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_urdf_string, is_urdf=True\n",
")"
")\n",
"data = js.data.JaxSimModelData.build(model=model)\n",
"integrator = integrators.fixed_step.RungeKutta4SO3.build(\n",
" dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
" model=model,\n",
" data=data,\n",
" system_dynamics=js.ode.system_dynamics,\n",
" ),\n",
")\n",
"integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)"
]
},
{
Expand All @@ -101,7 +118,7 @@
" minval=-1.0, maxval=1.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n",
")\n",
"\n",
"model.reset_joint_positions(positions=random_positions)"
"data = data.reset_joint_positions(positions=random_positions)"
]
},
{
Expand All @@ -118,17 +135,11 @@
"outputs": [],
"source": [
"# @title Set up MuJoCo renderer\n",
"!{sys.executable} -m pip install -U -q mujoco\n",
"!{sys.executable} -m pip install -q mediapy\n",
"\n",
"import mediapy as media\n",
"import tempfile\n",
"import xml.etree.ElementTree as ET\n",
"import numpy as np\n",
"from jaxsim.mujoco.visualizer import MujocoVisualizer\n",
"from jaxsim.mujoco import RodModelToMjcf, MujocoModelHelper, MujocoVideoRecorder\n",
"from jaxsim.mujoco.loaders import UrdfToMjcf\n",
"\n",
"import distutils.util\n",
"import os\n",
"import subprocess\n",
"\n",
"if IS_COLAB:\n",
" if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n",
Expand Down Expand Up @@ -171,66 +182,28 @@
" 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
" )\n",
"\n",
"camera = {\n",
" \"name\":\"cartpole_camera\",\n",
" \"mode\":\"fixed\",\n",
" \"pos\":\"3.954 3.533 2.343\",\n",
" \"xyaxes\":\"-0.594 0.804 -0.000 -0.163 -0.120 0.979\",\n",
" \"fovy\":\"60\",\n",
"}\n",
"\n",
"def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):\n",
" def to_mjcf_string(list_to_str):\n",
" return \" \".join(map(str, list_to_str))\n",
"\n",
" mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n",
" path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n",
" mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n",
" # Add camera in mujoco model\n",
" tree = ET.parse(path_temp_xml)\n",
" for elem in tree.getroot().iter(\"worldbody\"):\n",
" worldbody_elem = elem\n",
" camera_elem = ET.Element(\"camera\")\n",
" # Set attributes\n",
" camera_elem.set(\"name\", \"side\")\n",
" camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n",
" camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n",
" camera_elem.set(\"mode\", \"fixed\")\n",
" worldbody_elem.append(camera_elem)\n",
"\n",
" # Save new model\n",
" mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n",
" mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n",
" return mj_model\n",
"\n",
"\n",
"def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):\n",
" mujocoqposaddr2jaxindex = {}\n",
" for jaxjnt in jaxsimmodel.joints():\n",
" jntname = jaxjnt.name()\n",
" mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index() - 1\n",
"\n",
" mujoco_jointpos = jaxsim_jointpos\n",
" for i in range(0, len(mujoco_jointpos)):\n",
" mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n",
"\n",
" return mujoco_jointpos\n",
"\n",
"\n",
"# To get a good camera location, you can use \"Copy camera\" functionality in MuJoCo GUI\n",
"mj_model = load_mujoco_model_with_camera(\n",
" model_urdf_string,\n",
" [3.954, 3.533, 2.343],\n",
" [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],\n",
")\n",
"renderer = mujoco.Renderer(mj_model, height=480, width=640)\n",
"mjcf_string, assets = UrdfToMjcf.convert(urdf=model.built_from, cameras=camera)\n",
"\n",
"mj_model_helper = MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mjcf_string, assets=assets\n",
")\n",
"\n",
"def get_image(camera, mujocojointpos) -> np.ndarray:\n",
" \"\"\"Renders the environment state.\"\"\"\n",
" # Copy joint data in mjdata state\n",
" d = mujoco.MjData(mj_model)\n",
" d.qpos = mujocojointpos\n",
"\n",
" # Forward kinematics\n",
" mujoco.mj_forward(mj_model, d)\n",
"\n",
" # use the mjData object to update the renderer\n",
" renderer.update_scene(d, camera=camera)\n",
" return renderer.render()"
"# Create the video recorder.\n",
"recorder = MujocoVideoRecorder(\n",
" model=mj_model_helper.model,\n",
" data=mj_model_helper.data,\n",
" fps=int(1 / 0.010),\n",
" width=320 * 4,\n",
" height=240 * 4,\n",
")"
]
},
{
Expand All @@ -246,24 +219,27 @@
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.simulation.ode_integration import IntegratorType\n",
"\n",
"sim_images = []\n",
"timestep = 0.01\n",
"for _ in range(300):\n",
" sim_images.append(\n",
" get_image(\n",
" \"side\",\n",
" from_jaxsim_to_mujoco_pos(\n",
" np.array(model.joint_positions()), mj_model, model\n",
" ),\n",
" )\n",
"import mediapy as media\n",
"\n",
"for _ in range(500):\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=None,\n",
" link_forces=None,\n",
" )\n",
" model.integrate(\n",
" t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
"\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions(), joint_names=model.joint_names()\n",
" )\n",
"\n",
"media.show_video(sim_images, fps=1 / timestep)"
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"media.show_video(recorder.frames, fps=1 / dt)\n",
"recorder.frames = []"
]
},
{
Expand All @@ -290,13 +266,17 @@
"KP = 10.0\n",
"KD = 6.0\n",
"\n",
"# Compute the gravity compensation term\n",
"H = model.free_floating_bias_forces()[6:]\n",
"\n",
"\n",
"def pd_controller(\n",
" q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array\n",
" data: js.data.JaxSimModelData, q_d: jax.Array, q_dot_d: jax.Array\n",
") -> jax.Array:\n",
"\n",
" # Compute the gravity compensation term\n",
" H = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n",
"\n",
" q = data.joint_positions()\n",
" q_dot = data.joint_velocities()\n",
"\n",
" return H + KP * (q_d - q) + KD * (q_dot_d - q_dot)"
]
},
Expand All @@ -313,31 +293,31 @@
"metadata": {},
"outputs": [],
"source": [
"sim_images = []\n",
"timestep = 0.01\n",
"\n",
"for _ in range(300):\n",
" sim_images.append(\n",
" get_image(\n",
" \"side\",\n",
" from_jaxsim_to_mujoco_pos(\n",
" np.array(model.joint_positions()), mj_model, model\n",
" ),\n",
" )\n",
"for _ in range(500):\n",
" control_torques = pd_controller(\n",
" data=data,\n",
" q_d=jnp.array([0.0, 0.0]),\n",
" q_dot_d=jnp.array([0.0, 0.0]),\n",
" )\n",
" model.set_joint_generalized_force_targets(\n",
" forces=pd_controller(\n",
" q=model.joint_positions(),\n",
" q_d=jnp.array([0.0, 0.0]),\n",
" q_dot=model.joint_velocities(),\n",
" q_dot_d=jnp.array([0.0, 0.0]),\n",
" )\n",
"\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=control_torques,\n",
" link_forces=None,\n",
" )\n",
" model.integrate(\n",
" t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
"\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions(), joint_names=model.joint_names()\n",
" )\n",
"\n",
"media.show_video(sim_images, fps=1 / timestep)"
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"media.show_video(recorder.frames, fps=1 / dt)\n",
"recorder.frames = []"
]
}
],
Expand Down Expand Up @@ -370,7 +350,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit e917034

Please sign in to comment.