functorch 0.1.0
We’re excited to announce the first beta release of functorch. Heavily inspired by Google JAX, functorch is a library that adds composable function transforms to PyTorch. It provides vmap (vectorization) and autodiff transforms that compose with each other, work with PyTorch modules and PyTorch autograd, and deliver good eager-mode performance.
Composable function transforms can help with a number of use cases that are tricky to do in PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians as well as batched ones
Composing vmap (vectorization), vjp (reverse-mode AD), and jvp (forward-mode AD) transforms lets us express all of the above without designing a separate library for each; two such compositions are sketched below.
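For example, per-sample gradients fall out of composing grad (which differentiates a function with respect to its first argument) with vmap (which maps that computation over a batch dimension). Here is a minimal sketch; the linear model, loss_fn, and the tensor shapes are illustrative rather than part of the functorch API:

```python
import torch
from functorch import grad, vmap

def loss_fn(weights, example, target):
    # Toy linear model: one example in, a scalar loss out.
    prediction = example @ weights
    return ((prediction - target) ** 2).mean()

weights = torch.randn(3)
examples = torch.randn(64, 3)  # batch of 64 examples
targets = torch.randn(64)

# grad(loss_fn) computes the gradient w.r.t. weights (argument 0);
# vmap maps it over the leading dimension of examples and targets,
# leaving weights unbatched (in_dims=None).
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weights, examples, targets)
print(per_sample_grads.shape)  # torch.Size([64, 3]): one gradient per example
```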
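Jacobians and Hessians compose the same way: jacrev is built on vjp, jacfwd on jvp, and stacking one on the other yields second-order derivatives, with vmap providing the batched versions. A minimal sketch, where the function f and the input shapes are illustrative:

```python
import torch
from functorch import jacrev, jacfwd, vmap

def f(x):
    return x.sin().sum()

x = torch.randn(5)

# Reverse-mode Jacobian (here just the gradient, since f returns a scalar).
jac = jacrev(f)(x)                  # shape [5]

# Hessian as a forward-over-reverse composition.
hess = jacfwd(jacrev(f))(x)         # shape [5, 5]

# Batched Hessians: add vmap on top for a batch of inputs.
xs = torch.randn(8, 5)
batched_hess = vmap(jacfwd(jacrev(f)))(xs)  # shape [8, 5, 5]
```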
For more details, please see our documentation, tutorials, and installation instructions.