Give the user the possibility of choosing what to JIT compile #301
Just to understand: did you try whether the same problem happens if we use decorators instead of using …
That's a good observation! However, the behavior remains:
```python
import jax

def fn(x: int, y: int):
    def multiply(x: int, y: int):
        return x * y

    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)
```

```
Finished tracing + transforming multiply for pjit in 0.00031876564025878906 sec
Finished tracing + transforming fn for pjit in 0.0007829666137695312 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002622365951538086 sec
Finished XLA compilation of jit(fn) in 0.04105210304260254 sec
```
```python
import jax

def fn(x: int, y: int):
    @jax.jit
    def multiply(x: int, y: int):
        return x * y

    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)
```

```
Finished tracing + transforming multiply for pjit in 0.0002593994140625 sec
Finished tracing + transforming multiply for pjit in 0.0005724430084228516 sec
Finished tracing + transforming fn for pjit in 0.0009195804595947266 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002263784408569336 sec
Finished XLA compilation of jit(fn) in 0.04090547561645508 sec
```

As you can see, the function …
Check out the `remove_jit` branch (`main...remove_jit`) for testing.
Note that in all your examples you are applying the jit transformation to temporary objects. Try decorating the outer function instead (it may be cached automatically after the first run, but it's worth checking).
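A minimal sketch of that suggestion, reusing the toy `fn`/`multiply` pair from the examples above: the outer function is decorated once, at definition time, and then called repeatedly. The counter `fn_traces` is added only for illustration. Python side effects run only while JAX traces the function, so the counter should stay at 1 as long as the argument shapes and dtypes do not change.

```python
import jax

fn_traces = 0  # illustration only: counts how many times JAX traces `fn`

def multiply(x, y):
    return x * y

@jax.jit  # decorate the outer function once, instead of calling jax.jit(fn) inline
def fn(x, y):
    global fn_traces
    fn_traces += 1  # this Python side effect runs only during tracing
    return multiply(x, y)  # plain inner function: inlined into fn's trace

with jax.log_compiles():
    first = fn(4, 3)   # first call: trace + compile
    second = fn(5, 2)  # same shapes/dtypes: reuses the cached executable
```

Checking `fn_traces` after both calls is a quick way to confirm whether the second call was served from the cache.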
Generally speaking, we have jit decorators on all APIs because:
1. people who use the project interactively do not need to remember to apply the jit transformation (or understand how to do it);
2. it's the simplest way to define static function arguments, and this cannot be left to the user because it would introduce an additional burden.

Removing only the decorators that do not use `partial` may introduce asymmetries in the APIs.
Currently, we decorate every method and function inside `jaxsim.api` with `jax.jit`. However, this introduces overhead, as the inner functions get compiled multiple times. While this can be convenient when using JaxSim as a multibody dynamics library, it can lead to unexpected results or to additional overhead that could be removed.
FYI @traversaro @CarlottaSartore @diegoferigo @xela-95