
Give the user the possibility of choosing what to JIT compile #301

Open
flferretti opened this issue Nov 27, 2024 · 5 comments

@flferretti
Collaborator

Currently, we are decorating every method and function inside jaxsim.api with jax.jit. However, this introduces overhead, because the inner functions get traced multiple times:

Single JIT:

>>> import jax
>>>
>>> def fn(x: int, y: int):
...     return x * y
...
>>> with jax.log_compiles():
...     jax.jit(fn)(4, 3)
...
Finished tracing + transforming multiply for pjit in 0.0005085468292236328 sec
Finished tracing + transforming fn for pjit in 0.0011935234069824219 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.0018908977508544922 sec
Finished XLA compilation of jit(fn) in 0.03702688217163086 sec

Multiple JIT:

>>> import jax
>>>
>>> def fn(x: int, y: int):
...     return x * y
...
>>> with jax.log_compiles():
...     jax.jit(jax.jit(jax.jit(jax.jit(jax.jit(fn)))))(4, 3)
...
Finished tracing + transforming multiply for pjit in 0.0003635883331298828 sec
Finished tracing + transforming fn for pjit in 0.0006585121154785156 sec
Finished tracing + transforming fn for pjit in 0.0009508132934570312 sec
Finished tracing + transforming fn for pjit in 0.0011334419250488281 sec
Finished tracing + transforming fn for pjit in 0.0013675689697265625 sec
Finished tracing + transforming fn for pjit in 0.0017311573028564453 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.0028264522552490234 sec
Finished XLA compilation of jit(fn) in 0.040222883224487305 sec

While this is convenient when using JaxSim directly as a multibody dynamics library, it can lead to unexpected results or to additional overhead that could be avoided.
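One possible shape for the option this issue asks for is sketched below, under stated assumptions: `maybe_jit` and the `JAXSIM_JIT_API` environment variable are hypothetical names, not existing JaxSim or JAX APIs. JAX's own `jax.disable_jit()` context manager does exist, but it turns off every jit, including the user's outer one, which is coarser than choosing what the library itself compiles.

```python
import functools
import os

import jax

# HYPOTHETICAL switch: `JAXSIM_JIT_API` is not an existing JaxSim setting.
# In this sketch it is read once, at import time, and decides whether
# API functions are wrapped with jax.jit.
_JIT_API = os.environ.get("JAXSIM_JIT_API", "1") == "1"


def maybe_jit(fn=None, **jit_kwargs):
    """Hypothetical stand-in for `jax.jit` on public API functions.

    Works both as `@maybe_jit` and as `@maybe_jit(static_argnames=...)`.
    When the switch is off, the plain Python function is returned, so the
    caller can jit one larger outer function instead.
    """
    if fn is None:
        # Used with keyword arguments only: return a decorator.
        return functools.partial(maybe_jit, **jit_kwargs)
    return jax.jit(fn, **jit_kwargs) if _JIT_API else fn


@maybe_jit
def multiply(x, y):
    return x * y


@maybe_jit(static_argnames=("n",))
def multiply_n_times(x, y, n):
    # `n` drives a Python loop, so it must be static when jitted.
    for _ in range(n):
        x = multiply(x, y)
    return x
```

Since the flag is read at import time in this sketch, a user would set `JAXSIM_JIT_API=0` before importing the library and then apply `jax.jit` once to their own top-level function.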

FYI @traversaro @CarlottaSartore @diegoferigo @xela-95

@flferretti flferretti self-assigned this Nov 27, 2024
@traversaro
Contributor

Just to understand: did you check whether the same problem happens when using decorators instead of calling the jax.jit function directly? I would expect the decorator and the function call to behave the same, but you know that the devil is in the details.
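At the Python level the two spellings are interchangeable: `@jax.jit` placed above `def f` is equivalent to writing `f = jax.jit(f)` after the definition. A minimal check, where the function names are illustrative only:

```python
import jax


def multiply(x, y):
    return x * y


# Function-call form: wrap an already defined function.
multiply_call = jax.jit(multiply)


# Decorator form: syntactic sugar for `multiply_deco = jax.jit(multiply_deco)`.
@jax.jit
def multiply_deco(x, y):
    return x * y


print(int(multiply_call(4, 3)), int(multiply_deco(4, 3)))  # prints: 12 12
```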

@flferretti
Collaborator Author

flferretti commented Nov 27, 2024

That's a good observation! Yet, the behavior remains:

  • Without the jax.jit decorator on the inner function:
import jax

def fn(x: int, y: int):

    def multiply(x: int, y: int):
        return x * y

    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)

Finished tracing + transforming multiply for pjit in 0.00031876564025878906 sec
Finished tracing + transforming fn for pjit in 0.0007829666137695312 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002622365951538086 sec
Finished XLA compilation of jit(fn) in 0.04105210304260254 sec
  • With the jax.jit decorator on the inner function:
import jax

def fn(x: int, y: int):

    @jax.jit
    def multiply(x: int, y: int):
        return x * y

    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)

Finished tracing + transforming multiply for pjit in 0.0002593994140625 sec
Finished tracing + transforming multiply for pjit in 0.0005724430084228516 sec
Finished tracing + transforming fn for pjit in 0.0009195804595947266 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002263784408569336 sec
Finished XLA compilation of jit(fn) in 0.04090547561645508 sec

As you can see, the log now reports tracing multiply twice.
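One detail worth checking: the single-JIT log above already contains a "tracing + transforming multiply" entry even though that `fn` defines no `multiply`, which suggests one of the two entries comes from the `*` operator (JAX's own `multiply`) rather than from the user-defined function. A way to count how often the user function is actually traced is a Python side effect, which only runs during tracing. This is a sketch; the `trace_counts` dictionary is illustrative.

```python
import jax

# Python-level counters: function bodies run only while JAX traces them,
# so these count traces, not executions.
trace_counts = {"fn": 0, "multiply": 0}


def fn(x, y):
    trace_counts["fn"] += 1

    @jax.jit
    def multiply(x, y):
        trace_counts["multiply"] += 1
        return x * y

    return multiply(x, y)


fn_jit = jax.jit(fn)
fn_jit(4, 3)
fn_jit(5, 2)  # same shapes and dtypes, so the cached trace is reused
print(trace_counts)
```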

@flferretti
Collaborator Author

Check out the remove_jit branch (main...remove_jit) for testing.

@diegoferigo
Member

diegoferigo commented Nov 29, 2024

Note that in all your examples you are applying the jit transformation to temporary objects. Try decorating the outer function instead (it may automatically hit the cache after the first run, but it's worth checking).
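A quick way to run that check is to count traces with a Python side effect. The sketch below does not assume whether temporary wrappers hit the cache; it only reports the counts. `make_counted_multiply` is an illustrative helper, not part of JaxSim, and it returns a fresh function object each time so the two patterns cannot share a cache.

```python
import jax


def make_counted_multiply():
    """Return a fresh function plus a counter of how often JAX traced it."""
    counter = {"traces": 0}

    def multiply(x, y):
        # Runs only while JAX traces `multiply`, so it counts traces, not calls.
        counter["traces"] += 1
        return x * y

    return multiply, counter


# Pattern used in the examples above: a new jax.jit wrapper on every call.
fn_tmp, tmp_counter = make_counted_multiply()
for _ in range(3):
    jax.jit(fn_tmp)(4, 3)

# Suggested pattern: wrap (or decorate) once and reuse the jitted object.
fn_once, once_counter = make_counted_multiply()
fn_once_jit = jax.jit(fn_once)
for _ in range(3):
    fn_once_jit(4, 3)

print("temporary wrappers:", tmp_counter["traces"], "trace(s)")
print("reused wrapper:", once_counter["traces"], "trace(s)")
```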

@diegoferigo
Member

diegoferigo commented Nov 29, 2024

Generally speaking, we have jit decorators on all APIs because 1) people who use the project interactively do not need to remember to apply the jit transformation (or understand how to do it); 2) it's the simplest way to define static function arguments, which cannot be left to the user since it would introduce an additional burden.

Removing only the decorators that do not use partial (i.e., those that declare no static arguments) may introduce asymmetries in the APIs.
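For reference, a minimal sketch of the `functools.partial` pattern mentioned here: a single decorator both jits the function and declares which arguments are static. The function `repeated_multiply` and its argument `n` are illustrative, not taken from jaxsim.api.

```python
import functools

import jax


@functools.partial(jax.jit, static_argnames=("n",))
def repeated_multiply(x, y, n: int):
    # `n` controls a Python loop, so it must be known at trace time; with a
    # plain jax.jit, `range(n)` on a traced value would raise an error.
    # Each distinct value of `n` triggers a separate trace and compilation.
    for _ in range(n):
        x = x * y
    return x


print(int(repeated_multiply(2, 3, n=2)))  # prints: 18 (2 * 3 * 3)
```

Without the decorator, every caller would have to write `jax.jit(repeated_multiply, static_argnames=("n",))` themselves, which is the burden described above.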

Labels
None yet
Projects
None yet

3 participants