
Impossible to run a model with zero DoFs with JAX_DISABLE_JIT set to True #191

Closed
xela-95 opened this issue Jul 1, 2024 · 10 comments · Fixed by #219

@xela-95
Member

xela-95 commented Jul 1, 2024

Related issue on JAX: jax-ml/jax#4668

The error is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /home/acroci/repos/component_alpha/rigid_contacts_analytical.py:11
      7 integration_time = 0.001
      9 representation = jaxsim.VelRepr.Mixed
---> 11 data = js.data.JaxSimModelData.build(
     12     model=model,
     13     velocity_representation=representation,  # standard_gravity=7.0
     14 )
     15 # integrator = integrators.fixed_step.RungeKutta4SO3.build(
     16 # integrator = integrators.fixed_step.ForwardEuler.build(
     17 integrator = integrators.fixed_step.ForwardEulerSO3.build(
     18     dynamics=js.ode.wrap_system_dynamics_for_integration(
     19         model=model,
   (...)
     25     ),
     26 )

File ~/repos/jaxsim/src/jaxsim/api/data.py:186, in JaxSimModelData.build(model, base_position, base_quaternion, joint_positions, base_linear_velocity, base_angular_velocity, joint_velocities, standard_gravity, contact, contacts_params, velocity_representation, time)
    176 time_ns = (
    177     jnp.array(time * 1e9, dtype=jnp.uint64)
    178     if time is not None
    179     else jnp.array(0, dtype=jnp.uint64)
    180 )
    182 if isinstance(model.contact_model, SoftContacts):
    183     contacts_params = (
    184         contacts_params
    185         if contacts_params is not None
--> 186         else js.contact.estimate_good_soft_contacts_parameters(
    187             model=model, standard_gravity=standard_gravity
    188         )
    189     )
    190 else:
    191     contacts_params = model.contact_model.parameters

File ~/repos/jaxsim/src/jaxsim/api/contact.py:270, in estimate_good_soft_contacts_parameters(model, standard_gravity, static_friction_coefficient, number_of_active_collidable_points_steady_state, damping_ratio, max_penetration)
    263         return 2 * (W_pz_CoM - W_pz_C.min())
    265     return 2 * W_pz_CoM
    267 max_δ = (
    268     max_penetration
    269     if max_penetration is not None
--> 270     else 0.005 * estimate_model_height(model=model)
    271 )
    273 nc = number_of_active_collidable_points_steady_state
    275 sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
    276     model=model,
    277     standard_gravity=standard_gravity,
   (...)
    281     damping_ratio=damping_ratio,
    282 )

File ~/repos/jaxsim/src/jaxsim/api/contact.py:259, in estimate_good_soft_contacts_parameters.<locals>.estimate_model_height(model)
    252 """"""
    254 zero_data = js.data.JaxSimModelData.build(
    255     model=model,
    256     contacts_params=SoftContactsParams(),
    257 )
--> 259 W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
    261 if model.floating_base():
    262     W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]

File ~/repos/jaxsim/src/jaxsim/api/com.py:29, in com_position(model, data)
     16 """
     17 Compute the position of the center of mass of the model.
     18 
   (...)
     24     The position of the center of mass of the model w.r.t. the world frame.
     25 """
     27 m = js.model.total_mass(model=model)
---> 29 W_H_L = js.model.forward_kinematics(model=model, data=data)
     30 W_H_B = data.base_transform()
     31 B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()

File ~/repos/jaxsim/src/jaxsim/api/model.py:441, in forward_kinematics(model, data)
    427 @jax.jit
    428 def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    429     """
    430     Compute the SE(3) transforms from the world frame to the frames of all links.
    431 
   (...)
    438         The first axis is the link index.
    439     """
--> 441     W_H_LL = jaxsim.rbda.forward_kinematics_model(
    442         model=model,
    443         base_position=data.base_position(),
    444         base_quaternion=data.base_orientation(dcm=False),
    445         joint_positions=data.joint_positions(model=model),
    446     )
    448     return jnp.atleast_3d(W_H_LL).astype(float)

File ~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:78, in forward_kinematics_model(model, base_position, base_quaternion, joint_positions)
     74     W_X_i = W_X_i.at[i].set(W_X_i_i)
     76     return (W_X_i,), None
---> 78 (W_X_i,), _ = jax.lax.scan(
     79     f=propagate_kinematics,
     80     init=propagate_kinematics_carry,
     81     xs=jnp.arange(start=1, stop=model.number_of_links()),
     82 )
     84 return jax.vmap(Adjoint.to_transform)(W_X_i)

    [... skipping hidden 1 frame]

File ~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:231, in scan(f, init, xs, length, reverse, unroll, _split_transpose)
    229 if config.disable_jit.value:
    230   if length == 0:
--> 231     raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
    232   carry = init
    233   ys = []

ValueError: zero-length scan is not supported in disable_jit() mode because the output type is unknown.
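The error can likely be reproduced without jaxsim. The following is a minimal sketch, assuming a JAX release whose jax.lax.scan still contains the disable_jit() branch shown in the last frame above; the loop body propagate and the function forward_pass are hypothetical stand-ins for forward_kinematics_model, not jaxsim code. With num_links = 1 the scanned sequence jnp.arange(1, 1) is empty: the JIT-compiled call returns the initial carry, while the same call inside jax.disable_jit() (the context-manager counterpart of JAX_DISABLE_JIT) raises the ValueError.

```python
import jax
import jax.numpy as jnp


def propagate(carry, i):
    # Toy loop body standing in for a per-link kinematics update (hypothetical).
    return carry + i, None


def forward_pass(num_links: int) -> jax.Array:
    # Scan over link indices 1..num_links-1, like the traceback above.
    # For a single-link model the scanned sequence is empty.
    carry, _ = jax.lax.scan(
        propagate,
        jnp.zeros((), dtype=jnp.int32),
        jnp.arange(1, num_links, dtype=jnp.int32),
    )
    return carry


# JIT-compiled: the zero-length scan is traced and returns the initial carry.
jit_result = jax.jit(forward_pass, static_argnums=0)(1)

# JIT disabled: scan runs as a Python loop and rejects zero-length inputs.
disable_jit_error = None
try:
    with jax.disable_jit():
        forward_pass(1)
except ValueError as err:
    disable_jit_error = str(err)

print("jit:", jit_result)
print("disable_jit:", disable_jit_error)
```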
@diegoferigo
Member

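I already encountered this problem in the past for similar reasons. Refer to:

* [Add new test suite of functional APIs #106](https://github.com/ami-iit/jaxsim/pull/106)

* [945f04b](https://github.com/ami-iit/jaxsim/commit/945f04b683c3519772ad4ec7bb916bacd4400a3f)

On that occasion, I fixed only the RBDAs that I needed. In your case, there are other ones that fail for similar reasons. Can you try to apply something similar to the following to skip the scan call?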

# Skip the scan when the model has a single link, i.e. when xs would be empty.
(v, c, MA, pA, i_X_0), _ = (
    jax.lax.scan(
        f=loop_body_pass1,
        init=pass_1_carry,
        xs=jnp.arange(start=1, stop=model.number_of_links()),
    )
    if model.number_of_links() > 1
    else [(v, c, MA, pA, i_X_0), None]
)
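The guard can be tried in isolation. The sketch below is a toy analogue (propagate and forward_pass_guarded are hypothetical stand-ins, not jaxsim code): the Python conditional expression is evaluated on a static integer, so with a single link the scan is never called and the initial carry is returned, which works both with and without JIT.

```python
import jax
import jax.numpy as jnp


def propagate(carry, i):
    # Toy loop body (hypothetical): accumulate the link index into the carry.
    return carry + i, None


def forward_pass_guarded(num_links: int) -> jax.Array:
    init = jnp.zeros((), dtype=jnp.int32)
    # num_links is a plain Python int, so this conditional is resolved while
    # tracing; for a single link the scan is never called and init is returned.
    carry, _ = (
        jax.lax.scan(propagate, init, jnp.arange(1, num_links, dtype=jnp.int32))
        if num_links > 1
        else (init, None)
    )
    return carry


with jax.disable_jit():
    single_link = forward_pass_guarded(1)  # scan skipped
    three_links = forward_pass_guarded(3)  # scans over indices 1 and 2

jit_single_link = jax.jit(forward_pass_guarded, static_argnums=0)(1)
```

Returning the tuple (init, None) in the fallback branch mirrors the (carry, ys) pair that jax.lax.scan returns; the list used in the snippet above unpacks the same way.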

@flferretti
Collaborator

This can be added to #186

@diegoferigo
Member

> This can be added to #186

Probably it's time to merge that PR. We're already a bit beyond what I would call "minor changes", as often happens 😄

@traversaro
Contributor

> I already encountered this problem in the past for similar reasons. Refer to:
>
> * [Add new test suite of functional APIs #106](https://github.com/ami-iit/jaxsim/pull/106)
>
> * [945f04b](https://github.com/ami-iit/jaxsim/commit/945f04b683c3519772ad4ec7bb916bacd4400a3f)
>
> On that occasion, I fixed only the RBDAs that I needed. In your case, there are other ones that fail for similar reasons. Can you try to apply something similar to the following to skip the scan call?
>
> (v, c, MA, pA, i_X_0), _ = (
>     jax.lax.scan(
>         f=loop_body_pass1,
>         init=pass_1_carry,
>         xs=jnp.arange(start=1, stop=model.number_of_links()),
>     )
>     if model.number_of_links() > 1
>     else [(v, c, MA, pA, i_X_0), None]
> )

Cool thanks, we had the intuition that some kind of workaround was necessary, but we were a bit clueless. @xela-95, perhaps you can open a PR yourself with the fix proposed by @diegoferigo?

@xela-95
Member Author

xela-95 commented Jul 1, 2024

> Cool thanks, we had the intuition that some kind of workaround was necessary, but we were a bit clueless. @xela-95, perhaps you can open a PR yourself with the fix proposed by @diegoferigo?

Sure, I'll try to see if this fixes the issue and then open a PR :)

@traversaro
Contributor

Cross-referencing other JAX issues:

The fix suggested by @diegoferigo is useful, and this issue may help users who find the related JAX issues through search engines.

@diegoferigo
Member

diegoferigo commented Jul 1, 2024

It's worth noting that (if I'm not mistaken) the fix works in our case only because the condition of the if operates on a static element (reached by following model → kin_dyn_parameters → link_names). I fear that it won't work if the condition cannot be evaluated statically. In that case, using jax.lax.cond might be necessary.
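A sketch of the distinction, using toy functions that are not part of jaxsim (scale_static and scale_traced are hypothetical): when the condition depends only on a static Python value, a plain if is resolved at trace time, as in the fix above; when it depends on a traced array, a Python if raises a concretization error under jax.jit, and jax.lax.cond is one way to stage both branches and select one at run time.

```python
import jax
import jax.numpy as jnp


def scale_static(x: jax.Array, num_links: int) -> jax.Array:
    # num_links is static: the Python `if` is evaluated once, at trace time.
    if num_links > 1:
        return 2.0 * x
    return x


def scale_traced(x: jax.Array, num_links: jax.Array) -> jax.Array:
    # num_links is a traced array: both branches are staged and
    # jax.lax.cond picks one based on the run-time value of the predicate.
    return jax.lax.cond(num_links > 1, lambda v: 2.0 * v, lambda v: v, x)


x = jnp.array([1.0, 2.0])
static_out = jax.jit(scale_static, static_argnums=1)(x, 3)
traced_out = jax.jit(scale_traced)(x, jnp.array(3))

# A Python `if` on a traced argument cannot be converted to a bool under jit.
try:
    jax.jit(scale_static)(x, jnp.array(3))
    traced_if_failed = False
except jax.errors.ConcretizationTypeError:
    traced_if_failed = True
```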

@traversaro
Contributor

Most cases we saw were indeed due to model.number_of_links() > 1 (and probably went unnoticed as it is not so common to integrate a rigid body without joints).

@diegoferigo
Member

> Most cases we saw were indeed due to model.number_of_links() > 1 (and probably went unnoticed as it is not so common to integrate a rigid body without joints).

We actually do support that, and single-body models are also part of our test suite (together with a fixed-base and a floating-base model). This went unnoticed because JIT is automatically enabled in tests, and JIT-compiled jax.lax.scan calls do not complain if there is no actual iteration. They complain only when called either with JAX_DISABLE_JIT set or inside a jax.disable_jit context.
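One way to catch this class of bug, sketched here as an assumption rather than a description of what the jaxsim test suite does: evaluate each function both JIT-compiled and inside jax.disable_jit(). The toy sum_link_indices and the helper run_both_modes are hypothetical; the function has the same unguarded scan pattern, so the JIT path passes for a single link and only the disable_jit() path exposes the zero-length scan.

```python
import jax
import jax.numpy as jnp
import numpy as np


def sum_link_indices(num_links: int) -> jax.Array:
    # Toy stand-in for an RBDA with an unguarded scan over links 1..num_links-1.
    carry, _ = jax.lax.scan(
        lambda c, i: (c + i, None),
        jnp.zeros((), dtype=jnp.int32),
        jnp.arange(1, num_links, dtype=jnp.int32),
    )
    return carry


def run_both_modes(fn, *args):
    # Hypothetical test helper: evaluate fn JIT-compiled and again with JIT
    # disabled, and check that both modes agree. All arguments are marked
    # static only because the toy arguments here are Python ints.
    jitted = jax.jit(fn, static_argnums=tuple(range(len(args))))(*args)
    with jax.disable_jit():
        eager = fn(*args)
    np.testing.assert_allclose(np.asarray(jitted), np.asarray(eager))
    return jitted


outcomes = {}
for n in (1, 3):  # a single-body model and a multi-link model
    try:
        run_both_modes(sum_link_indices, n)
        outcomes[n] = "ok"
    except ValueError:
        outcomes[n] = "fails with JIT disabled"

print(outcomes)
```

In a pytest suite, a similar effect could likely be obtained by parametrizing tests over the two modes, or by running the suite a second time with JAX_DISABLE_JIT=1.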

@traversaro
Contributor

Thanks @flferretti!
