Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor velocity representations as integers #160

Open
wants to merge 27 commits into
base: main
Choose a base branch
from

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented May 23, 2024

This pull request refactors the velocity representation in the JaxSim API. The changes include avoiding the use of enum in VelRepr and making velocity_representation a non-static argument in some methods of jaxsim.api.JaxSimModelData. These changes improve the compatibility with JAX while avoiding breaking changes in the API, making it possible to potentially using jax.vmap on different velocity representations


📚 Documentation preview 📚: https://jaxsim--160.org.readthedocs.build//160/

@flferretti flferretti self-assigned this May 23, 2024
@flferretti flferretti force-pushed the feature/vmap_velrepr branch 3 times, most recently from 9bc8e2c to a262dda Compare May 23, 2024 13:34
@flferretti flferretti force-pushed the feature/vmap_velrepr branch 13 times, most recently from 40457a7 to 5d5ba73 Compare June 10, 2024 20:06
@flferretti flferretti marked this pull request as ready for review June 10, 2024 20:27
@flferretti flferretti requested a review from diegoferigo as a code owner June 10, 2024 20:27
@flferretti
Copy link
Collaborator Author

The AD tests take a bit longer to complete. This can be due to the JAX traceback used in JaxSimModelReferences that can lead to this effect when trying to compute the gradients

@diegoferigo
Copy link
Member

diegoferigo commented Jun 11, 2024

Awesome, thanks @flferretti for this PR. I'd expect the following consequences:

  • With static velocity representations, functions compiled passing different values are different for JAX. Only the section of the removed match-case statement is included in the compiled IR.
  • I expect now a slightly slower compilation of our functions since they include more code, but in most cases the match-case statements only contain a bunch of transforms computation, nothing too heavy. TL;DR slower JIT compilations for a single call, but no recompilations in case multiple representations are needed.
  • I expect, however, a smaller "equivalent binary" size of a large jit-compiled application. In fact, it happens to call the same JIT-compiled function with different representations (sometimes making calculation e.g. in body-fixed is easier even if the active representation is different), therefore with the old static representations the full binary might have contained multiple copies of the same RBDA where the only difference was the match-case entry used to adjust the representation. TL;DR now, smaller binary size that can also mean less memory used especially on GPU.

Not yet sure if my intuition is correct.

This being said, I'd like to tag a release with all the previous improvement. I don't expect surprises here, but being such a large change touching pretty much all our API surface, I prefer being cautious and include this PR in the following release (v0.4.0). I'll be reviewing this shortly.

@diegoferigo
Copy link
Member

The AD tests take a bit longer to complete.

I guess that now AD needs to propagate gradients through all possible branches instead of just one. And yes, this might take longer.

@flferretti
Copy link
Collaborator Author

I agree with your intuitions, the IR should now include the three branches for each velocity representation and the recompilation should not be triggered when we use switch_velocity_representation or from_other_to_inertial or from_inertial_to_other.

I'd like to tag a release with all the previous improvement. I don't expect surprises here, but being such a large change touching pretty much all our API surface, I prefer being cautious and include this PR in the following release (v0.4.0). I'll be reviewing this shortly.

I totally agree, this can be potentially disruptive, so I'd also prefer to be cautious and eventually rebase this onto #172

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just finished a first review pass, things look generally good. I'll going to request some changes, and make a second pass later. I still need to weight some consequences especially on AD performance.

Here below my comments / questions:

  • Most of the match-case seem to me 1:1 with the new lax.switch functions. Are there any sections in which you had to modify the original code? I checked the diff, but since it's quite large, I could have missed something.
  • In most cases, this PR does not alter significantly the readability of the code, I like that. We already had an extra indentation level due to the match-case statements, and the new functions use the same. There are cases, however, where the logic got much more complex and indented, like in JaxSimModelReferences. Let's keep this in mind in case we want to refactor it with a more simple approach.
  • Do you think it could be helpful adding a new jaxsim.typing.VelRepr variable that points to int? Since this removed enum are part of the core public APIs, I'd prefer to make clear that those are not generic integer types.

src/jaxsim/api/references.py Outdated Show resolved Hide resolved
src/jaxsim/api/references.py Show resolved Hide resolved
src/jaxsim/api/references.py Outdated Show resolved Hide resolved
@flferretti
Copy link
Collaborator Author

Most of the match-case seem to me 1:1 with the new lax.switch functions. Are there any sections in which you had to modify the original code? I checked the diff, but since it's quite large, I could have missed something.

No, they should be equivalent

In most cases, this PR does not alter significantly the readability of the code, I like that. We already had an extra indentation level due to the match-case statements, and the new functions use the same. There are cases, however, where the logic got much more complex and indented, like in JaxSimModelReferences. Let's keep this in mind in case we want to refactor it with a more simple approach.

Yes, the logic got more complex since I needed to somehow check the values of some parameters. We can think of a smarter solution to handle that in the future.

Do you think it could be helpful adding a new jaxsim.typing.VelRepr variable that points to int? Since this removed enum are part of the core public APIs, I'd prefer to make clear that those are not generic integer types.

Totally yes! I'll a commit for that

@flferretti flferretti force-pushed the feature/vmap_velrepr branch 2 times, most recently from 8398f0b to 88d5a2d Compare June 14, 2024 14:33
@flferretti flferretti requested a review from diegoferigo June 14, 2024 16:21
@flferretti flferretti force-pushed the feature/vmap_velrepr branch 2 times, most recently from b878836 to cec46de Compare June 17, 2024 08:24
flferretti and others added 25 commits August 23, 2024 07:47
In Articulated Body Algorithm

Co-authored-by: Alessandro Croci <[email protected]>
@flferretti flferretti force-pushed the feature/vmap_velrepr branch from dc175c9 to d580622 Compare August 23, 2024 05:47
@flferretti
Copy link
Collaborator Author

Checks are failing due to a timeout of the CI

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants