diff --git a/examples/PD_controller.ipynb b/examples/PD_controller.ipynb index 645955fba..38be2edf1 100644 --- a/examples/PD_controller.ipynb +++ b/examples/PD_controller.ipynb @@ -68,7 +68,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "JAXsim offers a simple high-level API in order to extract quantities needed in most robotic applications. " + "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", + "\n", + "- `model`: an object that defines the dynamics of the system.\n", + "- `data`: an object that contains the state of the system.\n", + "- `integrator`: an object that defines the integration method.\n", + "- `integrator_state`: an object that contains the state of the integrator." ] }, { @@ -77,11 +82,23 @@ "metadata": {}, "outputs": [], "source": [ - "from jaxsim.high_level.model import Model\n", + "import jaxsim.api as js\n", + "from jaxsim import integrators\n", + "\n", + "dt = 0.01\n", "\n", - "model = Model.build_from_model_description(\n", + "model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_urdf_string, is_urdf=True\n", - ")" + ")\n", + "data = js.data.JaxSimModelData.build(model=model)\n", + "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n", + " dynamics=js.ode.wrap_system_dynamics_for_integration(\n", + " model=model,\n", + " data=data,\n", + " system_dynamics=js.ode.system_dynamics,\n", + " ),\n", + ")\n", + "integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)" ] }, { @@ -101,7 +118,7 @@ " minval=-1.0, maxval=1.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n", ")\n", "\n", - "model.reset_joint_positions(positions=random_positions)" + "data = data.reset_joint_positions(positions=random_positions)" ] }, { @@ -118,17 +135,11 @@ "outputs": [], "source": [ "# @title Set up MuJoCo renderer\n", - "!{sys.executable} -m pip install -U -q mujoco\n", - "!{sys.executable} -m pip install -q mediapy\n", "\n", - "import mediapy as media\n", - "import tempfile\n", - "import xml.etree.ElementTree as ET\n", - "import numpy as np\n", + "from jaxsim.mujoco.visualizer import MujocoVisualizer\n", + "from jaxsim.mujoco import RodModelToMjcf, MujocoModelHelper, MujocoVideoRecorder\n", + "from jaxsim.mujoco.loaders import UrdfToMjcf\n", "\n", - "import distutils.util\n", - "import os\n", - "import subprocess\n", "\n", "if IS_COLAB:\n", " if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n", @@ -171,66 +182,28 @@ " 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n", " )\n", "\n", + "camera = {\n", + " \"name\":\"cartpole_camera\",\n", + " \"mode\":\"fixed\",\n", + " \"pos\":\"3.954 3.533 2.343\",\n", + " \"xyaxes\":\"-0.594 0.804 -0.000 -0.163 -0.120 0.979\",\n", + " \"fovy\":\"60\",\n", + "}\n", "\n", - "def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):\n", - " def to_mjcf_string(list_to_str):\n", - " return \" \".join(map(str, list_to_str))\n", - "\n", - " mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n", - " path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n", - " mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n", - " # Add camera in mujoco model\n", - " tree = ET.parse(path_temp_xml)\n", - " for elem in tree.getroot().iter(\"worldbody\"):\n", - " worldbody_elem = elem\n", - " camera_elem = ET.Element(\"camera\")\n", - " # Set attributes\n", - " camera_elem.set(\"name\", \"side\")\n", - " camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n", - " camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n", - " camera_elem.set(\"mode\", \"fixed\")\n", - " worldbody_elem.append(camera_elem)\n", - "\n", - " # Save new model\n", - " mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n", - " mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n", - " return mj_model\n", - "\n", - "\n", - "def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):\n", - " mujocoqposaddr2jaxindex = {}\n", - " for jaxjnt in jaxsimmodel.joints():\n", - " jntname = jaxjnt.name()\n", - " mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index() - 1\n", - "\n", - " mujoco_jointpos = jaxsim_jointpos\n", - " for i in range(0, len(mujoco_jointpos)):\n", - " mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n", - "\n", - " return mujoco_jointpos\n", - "\n", - "\n", - "# To get a good camera location, you can use \"Copy camera\" functionality in MuJoCo GUI\n", - "mj_model = load_mujoco_model_with_camera(\n", - " model_urdf_string,\n", - " [3.954, 3.533, 2.343],\n", - " [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],\n", - ")\n", - "renderer = mujoco.Renderer(mj_model, height=480, width=640)\n", + "mjcf_string, assets = UrdfToMjcf.convert(urdf=model.built_from, cameras=camera)\n", "\n", + "mj_model_helper = MujocoModelHelper.build_from_xml(\n", + " mjcf_description=mjcf_string, assets=assets\n", + ")\n", "\n", - "def get_image(camera, mujocojointpos) -> np.ndarray:\n", - " \"\"\"Renders the environment state.\"\"\"\n", - " # Copy joint data in mjdata state\n", - " d = mujoco.MjData(mj_model)\n", - " d.qpos = mujocojointpos\n", - "\n", - " # Forward kinematics\n", - " mujoco.mj_forward(mj_model, d)\n", - "\n", - " # use the mjData object to update the renderer\n", - " renderer.update_scene(d, camera=camera)\n", - " return renderer.render()" + "# Create the video recorder.\n", + "recorder = MujocoVideoRecorder(\n", + " model=mj_model_helper.model,\n", + " data=mj_model_helper.data,\n", + " fps=int(1 / 0.010),\n", + " width=320 * 4,\n", + " height=240 * 4,\n", + ")" ] }, { @@ -246,24 +219,27 @@ "metadata": {}, "outputs": [], "source": [ - "from jaxsim.simulation.ode_integration import IntegratorType\n", - "\n", - "sim_images = []\n", - "timestep = 0.01\n", - "for _ in range(300):\n", - " sim_images.append(\n", - " get_image(\n", - " \"side\",\n", - " from_jaxsim_to_mujoco_pos(\n", - " np.array(model.joint_positions()), mj_model, model\n", - " ),\n", - " )\n", + "import mediapy as media\n", + "\n", + "for _ in range(500):\n", + " data, integrator_state = js.model.step(\n", + " dt=dt,\n", + " model=model,\n", + " data=data,\n", + " integrator=integrator,\n", + " integrator_state=integrator_state,\n", + " joint_forces=None,\n", + " link_forces=None,\n", " )\n", - " model.integrate(\n", - " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n", + "\n", + " mj_model_helper.set_joint_positions(\n", + " positions=data.joint_positions(), joint_names=model.joint_names()\n", " )\n", "\n", - "media.show_video(sim_images, fps=1 / timestep)" + " recorder.record_frame(camera_name=\"cartpole_camera\")\n", + "\n", + "media.show_video(recorder.frames, fps=1 / dt)\n", + "recorder.frames = []" ] }, { @@ -290,13 +266,17 @@ "KP = 10.0\n", "KD = 6.0\n", "\n", - "# Compute the gravity compensation term\n", - "H = model.free_floating_bias_forces()[6:]\n", - "\n", "\n", "def pd_controller(\n", - " q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array\n", + " data: js.data.JaxSimModelData, q_d: jax.Array, q_dot_d: jax.Array\n", ") -> jax.Array:\n", + "\n", + " # Compute the gravity compensation term\n", + " H = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n", + "\n", + " q = data.joint_positions()\n", + " q_dot = data.joint_velocities()\n", + "\n", " return H + KP * (q_d - q) + KD * (q_dot_d - q_dot)" ] }, @@ -313,31 +293,31 @@ "metadata": {}, "outputs": [], "source": [ - "sim_images = []\n", - "timestep = 0.01\n", - "\n", - "for _ in range(300):\n", - " sim_images.append(\n", - " get_image(\n", - " \"side\",\n", - " from_jaxsim_to_mujoco_pos(\n", - " np.array(model.joint_positions()), mj_model, model\n", - " ),\n", - " )\n", + "for _ in range(500):\n", + " control_torques = pd_controller(\n", + " data=data,\n", + " q_d=jnp.array([0.0, 0.0]),\n", + " q_dot_d=jnp.array([0.0, 0.0]),\n", " )\n", - " model.set_joint_generalized_force_targets(\n", - " forces=pd_controller(\n", - " q=model.joint_positions(),\n", - " q_d=jnp.array([0.0, 0.0]),\n", - " q_dot=model.joint_velocities(),\n", - " q_dot_d=jnp.array([0.0, 0.0]),\n", - " )\n", + "\n", + " data, integrator_state = js.model.step(\n", + " dt=dt,\n", + " model=model,\n", + " data=data,\n", + " integrator=integrator,\n", + " integrator_state=integrator_state,\n", + " joint_forces=control_torques,\n", + " link_forces=None,\n", " )\n", - " model.integrate(\n", - " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n", + "\n", + " mj_model_helper.set_joint_positions(\n", + " positions=data.joint_positions(), joint_names=model.joint_names()\n", " )\n", "\n", - "media.show_video(sim_images, fps=1 / timestep)" + " recorder.record_frame(camera_name=\"cartpole_camera\")\n", + "\n", + "media.show_video(recorder.frames, fps=1 / dt)\n", + "recorder.frames = []" ] } ], @@ -370,7 +350,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/examples/Parallel_computing.ipynb b/examples/Parallel_computing.ipynb index f7482824c..d06aac3ec 100644 --- a/examples/Parallel_computing.ipynb +++ b/examples/Parallel_computing.ipynb @@ -88,7 +88,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now, we can create a simulator instance and load the model into it." + "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", + "\n", + "- `model`: an object that defines the dynamics of the system.\n", + "- `data`: an object that contains the state of the system.\n", + "- `integrator`: an object that defines the integration method.\n", + "- `integrator_state`: an object that contains the state of the integrator." ] }, { @@ -97,29 +102,48 @@ "metadata": {}, "outputs": [], "source": [ - "from jaxsim.high_level.model import VelRepr\n", - "from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n", - "from jaxsim.simulation.ode_integration import IntegratorType\n", - "from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n", + "import jaxsim.api as js\n", + "from jaxsim import integrators\n", "\n", - "# Simulation Step Parameters\n", - "integration_time = 3.0 # seconds\n", - "step_size = 0.001\n", - "steps_per_run = 1\n", + "dt = 0.001\n", + "integration_time = 1500\n", "\n", - "simulator = JaxSim.build(\n", - " step_size=step_size,\n", - " steps_per_run=steps_per_run,\n", - " velocity_representation=VelRepr.Body,\n", - " integrator_type=IntegratorType.EulerSemiImplicit,\n", - " simulator_data=SimulatorData(\n", - " contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),\n", + "model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_sdf_string\n", + ")\n", + "data = js.data.JaxSimModelData.build(model=model)\n", + "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n", + " dynamics=js.ode.wrap_system_dynamics_for_integration(\n", + " model=model,\n", + " data=data,\n", + " system_dynamics=js.ode.system_dynamics,\n", " ),\n", ")\n", + "integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It is possible to automatically choose a good set of parameters for the terrain. \n", "\n", + "By default, in JaxSim a sphere primitive has 250 collision points. This can be modified by setting the `JAXSIM_COLLISION_SPHERE_POINTS` environment variable.\n", "\n", - "# Add model to simulator\n", - "model = simulator.insert_model_from_description(model_description=model_sdf_string)" + "Given that at its steady-state the sphere will act on two or three points, we can estimate the ground parameters by explicitly setting the number of active points to these values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = data.replace(\n", + " soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n", + " model, number_of_active_collidable_points_steady_state=3\n", + " )\n", + ")" ] }, { @@ -136,8 +160,9 @@ "outputs": [], "source": [ "# Primary Calculations\n", + "envs_per_row = 4 # @slider(2, 10, 1)\n", + "\n", "env_spacing = 0.5\n", - "envs_per_row = 3\n", "edge_len = env_spacing * (2 * envs_per_row - 1)\n", "\n", "\n", @@ -155,6 +180,7 @@ " return jnp.array(poses)\n", "\n", "\n", + "logging.info(f\"Simulating {envs_per_row**2} environments\")\n", "poses = grid(edge_len, envs_per_row)" ] }, @@ -162,9 +188,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch.\n", - "\n", - "**Note:** [`step_over_horizon`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L432C1-L529C10) is useful only in open-loop simulations and where the horizon is known in advance. Please checkout [`step`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L384C10-L425) for closed-loop simulations." + "In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch." ] }, { @@ -173,35 +197,27 @@ "metadata": {}, "outputs": [], "source": [ - "from jaxsim.simulation import simulator_callbacks\n", - "\n", - "\n", - "# Create a logger to store simulation data\n", - "@jax_dataclasses.pytree_dataclass\n", - "class SimulatorLogger(simulator_callbacks.PostStepCallback):\n", - " def post_step(\n", - " self, sim: JaxSim, step_data: Dict[str, StepData]\n", - " ) -> Tuple[JaxSim, jtp.PyTree]:\n", - " \"\"\"Return the StepData object of each simulated model\"\"\"\n", - " return sim, step_data\n", - "\n", - "\n", "# Define a function to simulate a single model instance\n", - "def simulate(sim: JaxSim, pose) -> JaxSim:\n", - " model.zero()\n", - " model.reset_base_position(position=jnp.array(pose))\n", - "\n", - " with sim.editable(validate=True) as sim:\n", - " m = sim.get_model(model.name())\n", - " m.data = model.data\n", - "\n", - " sim, (cb, (_, step_data)) = simulator.step_over_horizon(\n", - " horizon_steps=integration_time // step_size,\n", - " callback_handler=SimulatorLogger(),\n", - " clear_inputs=True,\n", - " )\n", - "\n", - " return step_data" + "def simulate(\n", + " data: js.data.JaxSimModelData, integrator_state: dict, pose: jnp.array\n", + ") -> tuple:\n", + "\n", + " data = data.reset_base_position(base_position=pose)\n", + " x_t_i = []\n", + "\n", + " for _ in range(integration_time):\n", + " data, integrator_state = js.model.step(\n", + " dt=dt,\n", + " model=model,\n", + " data=data,\n", + " integrator=integrator,\n", + " integrator_state=integrator_state,\n", + " joint_forces=None,\n", + " link_forces=None,\n", + " )\n", + " x_t_i.append(data.base_position())\n", + "\n", + " return x_t_i" ] }, { @@ -213,7 +229,7 @@ "\n", "Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n", "\n", - "`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized, while the second argument is vectorized over the zero-th dimension." + "`in_axes=(None, None, 0)` means that the first two arguments of `simulate` are not vectorized, while the third argument is vectorized over the zero-th dimension." ] }, { @@ -223,12 +239,12 @@ "outputs": [], "source": [ "# Define a function to simulate multiple model instances\n", - "simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))\n", + "simulate_vectorized = jax.vmap(simulate, in_axes=(None, None, 0))\n", "\n", "# Run and time the simulation\n", "now = time.perf_counter()\n", "\n", - "time_history = simulate_vectorized(simulator, poses[:, 0])\n", + "x_t = simulate_vectorized(data, integrator_state, poses[:, 0])\n", "\n", "comp_time = time.perf_counter() - now\n", "\n", @@ -236,7 +252,7 @@ " f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n", ")\n", "logging.info(\n", - " f\"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 *integration_time/comp_time}\"\n", + " f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 *integration_time/comp_time):.2f}\"\n", ")" ] }, @@ -253,13 +269,10 @@ "metadata": {}, "outputs": [], "source": [ - "time_history: Dict[str, StepData]\n", - "x_t = time_history[model.name()].tf_model_state\n", - "\n", - "\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", - "plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)\n", + "plt.plot(np.arange(len(x_t)) * dt, np.array(x_t)[:, :, 2])\n", "plt.grid(True)\n", "plt.xlabel(\"Time [s]\")\n", "plt.ylabel(\"Height [m]\")\n", @@ -297,7 +310,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.1" + "version": "3.11.8" } }, "nbformat": 4,