
A stateful pytree library for training neural networks.


NTT123/pax


Introduction | Getting started | Functional programming | Examples | Modules | Fine-tuning


Introduction

PAX is a JAX-based library for training neural networks.

PAX modules are registered as JAX pytrees; therefore, they can be inputs or outputs of JAX transformations such as jax.jit and jax.grad. This makes programming with modules convenient and easy to understand.
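This sketch shows why being a pytree matters. It is written in plain JAX and is not PAX's implementation. A class registered with jax.tree_util.register_pytree_node_class exposes its arrays as pytree leaves. jax.jit and jax.grad can therefore take the whole object as an argument, and the gradients come back with the same class structure. The class name LinearState is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class LinearState:
    """Hypothetical toy 'module': a weight and a bias stored as pytree leaves."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight * x + self.bias

    # Tell JAX how to split the object into leaves (children) and static aux data.
    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)


model = LinearState(jnp.array(0.0), jnp.array(0.0))
x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])

# The object goes straight through jax.jit / jax.grad; gradients come back as a LinearState.
grads = jax.jit(jax.grad(loss_fn))(model, x, y)
print(grads.weight, grads.bias)  # -10.0 -6.0
```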

Installation

Install from PyPI:

pip install pax3

Or install the latest version from Github:

pip install git+https://github.com/ntt123/pax.git

## or, to run the tests and examples, install with the test extras
pip install git+https://github.com/ntt123/pax.git#egg=pax3[test]

Getting started

Below is a simple example of a Linear module.

import jax.numpy as jnp
import pax

class Linear(pax.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray
    parameters = pax.parameters_method("weight", "bias")

    def __init__(self):
        super().__init__()
        self.weight = jnp.array(0.0)
        self.bias = jnp.array(0.0)

    def __call__(self, x):
        return self.weight * x + self.bias

The implementation is very similar to a normal Python class. However, we need one additional line

    parameters = pax.parameters_method("weight", "bias")

to declare that weight and bias are trainable parameters of the Linear module.
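Declaring which attributes are trainable lets a library tell parameters apart from other state, so that gradients are taken only with respect to the parameters. The sketch below imitates that idea in plain JAX with a dictionary. The TRAINABLE tuple and the split helper are hypothetical, and PAX's own bookkeeping works differently.

```python
import jax
import jax.numpy as jnp

# Hypothetical analog of `parameters_method("weight", "bias")`.
TRAINABLE = ("weight", "bias")


def split(state):
    """Partition a dict of arrays into (trainable, non-trainable) parts."""
    params = {k: v for k, v in state.items() if k in TRAINABLE}
    others = {k: v for k, v in state.items() if k not in TRAINABLE}
    return params, others


def loss_fn(params, others, x, y):
    s = {**params, **others}  # reassemble the full state
    pred = s["weight"] * x + s["bias"]
    return jnp.mean((pred - y) ** 2)


state = {"weight": jnp.array(0.0), "bias": jnp.array(0.0), "step": jnp.array(0)}
params, others = split(state)

# jax.grad differentiates w.r.t. the first argument only, so "step" gets no gradient.
grads = jax.grad(loss_fn)(params, others, jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0]))
print(sorted(grads))  # ['bias', 'weight']
```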

PAX functional programming

pax.pure

A PAX module can have internal states. For example, below is a simple Counter module with an internal counter.

class Counter(pax.Module):
    count: jnp.ndarray

    def __init__(self):
        super().__init__()
        self.count = jnp.array(0)

    def __call__(self):
        self.count = self.count + 1
        return self.count

However, PAX aims to guarantee that modules have no side effects when viewed from the outside. Therefore, modifications of internal states are restricted. For example, calling a Counter directly raises an error:

counter = Counter()
count = counter()
# ...
# ----> 9         self.count = self.count + 1
# ...
# ValueError: Cannot modify a module in immutable mode.
# Please do this computation inside a function decorated by `pax.pure`.

Only functions decorated by pax.pure are allowed to modify the internal states of their input modules.

@pax.pure
def update_counter(counter: Counter):
    count = counter()
    return counter, count

counter, count = update_counter(counter)
print(counter.count, count)
# 1 1

Note that update_counter must return counter; otherwise, the updated counter is lost. This is because pax.pure only passes update_counter a copy of the counter object, so the caller's original object is never modified.
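This copy-and-return behavior can be imitated in plain Python. In the sketch below, ToyModule and toy_pure are hypothetical stand-ins that only mimic what is described above. A module refuses attribute writes unless it is in mutable mode, and the decorator hands the function a deep copy with mutation temporarily enabled. This is not how pax.pure is implemented.

```python
import copy


class ToyModule:
    """Hypothetical stand-in for a PAX module (not PAX's implementation)."""

    _mutable = False  # class-level default: instances start immutable

    def __setattr__(self, name, value):
        if not self._mutable:
            raise ValueError("Cannot modify a module in immutable mode.")
        object.__setattr__(self, name, value)


class Counter(ToyModule):
    def __init__(self):
        # Bypass the immutability check while constructing the object.
        object.__setattr__(self, "count", 0)

    def __call__(self):
        self.count = self.count + 1
        return self.count


def toy_pure(fn):
    """Run `fn` on a deep copy of the module with mutation temporarily allowed."""

    def wrapper(module, *args):
        module = copy.deepcopy(module)  # the caller's object is never touched
        object.__setattr__(module, "_mutable", True)
        try:
            return fn(module, *args)
        finally:
            # The returned copy is immutable again once fn has finished.
            object.__setattr__(module, "_mutable", False)

    return wrapper


@toy_pure
def update_counter(counter):
    count = counter()
    return counter, count


counter = Counter()
new_counter, count = update_counter(counter)
print(counter.count, new_counter.count, count)  # 0 1 1
```

Forgetting to return the module from update_counter would leave the caller holding only the untouched original.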

pax.purecall

For convenience, PAX provides the pax.purecall function. It calls a module and returns both the updated module and the output. Roughly, it is a shortcut for pax.pure(lambda f, *args: (f, f(*args))).

Instead of implementing an update_counter function, we can do the same thing with:

counter, count = pax.purecall(counter)
print(counter.count, count)
# 2 2
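The shortcut is just a pure-decorated function that calls the module and returns it alongside the output. Below is a plain-Python sketch where toy_pure and toy_purecall are hypothetical stand-ins for the PAX functions. They only copy their inputs and do not enforce any immutable mode.

```python
import copy


def toy_pure(fn):
    """Hypothetical stand-in for pax.pure: `fn` only ever sees deep copies."""

    def wrapper(*args, **kwargs):
        return fn(*copy.deepcopy(args), **copy.deepcopy(kwargs))

    return wrapper


def toy_purecall(module, *args, **kwargs):
    """Call `module` purely and return (updated module, output)."""
    return toy_pure(lambda f, *a, **kw: (f, f(*a, **kw)))(module, *args, **kwargs)


class Counter:
    def __init__(self):
        self.count = 0

    def __call__(self):
        self.count += 1
        return self.count


counter = Counter()
counter, count = toy_purecall(counter)
counter, count = toy_purecall(counter)
print(counter.count, count)  # 2 2
```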

Replacing parts

PAX provides utility methods to modify a module in a functional way.

The replace method creates a new module with attributes replaced. For example, to replace weight and bias of a pax.Linear module:

fc = pax.Linear(2, 2)
fc = fc.replace(weight=jnp.ones((2,2)), bias=jnp.zeros((2,)))
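Python's standard library has the same functional-update pattern: dataclasses.replace builds a new instance of a frozen dataclass with selected fields changed and leaves the original alone. The sketch below uses that only as an analogy. LinearParams is hypothetical, and PAX's replace is a method on its own modules.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class LinearParams:
    """Hypothetical immutable container, used only to illustrate the pattern."""

    weight: jnp.ndarray
    bias: jnp.ndarray


fc = LinearParams(weight=jnp.zeros((2, 2)), bias=jnp.ones((2,)))

# Build a new object with some fields swapped; `fc` itself is unchanged.
new_fc = dataclasses.replace(fc, weight=jnp.ones((2, 2)), bias=jnp.zeros((2,)))
print(fc.weight.sum(), new_fc.weight.sum())  # 0.0 4.0
```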

The replace_node method replaces a pytree node of a module:

f = pax.Sequential(
    pax.Linear(2, 3),
    pax.Linear(3, 4),
)

f = f.replace_node(f[-1], pax.Linear(3, 5))
print(f.summary())
# Sequential
# ├── Linear(in_dim=2, out_dim=3, with_bias=True)
# └── Linear(in_dim=3, out_dim=5, with_bias=True)
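Conceptually, replacing a node means rebuilding the pytree with one subtree swapped out. Below is a plain-JAX sketch using jax.tree_util.tree_map with an is_leaf predicate. It assumes the node to replace is matched by object identity, which is an assumption for illustration and not a description of PAX's code. Here replace_node is a hypothetical free function, not the PAX method.

```python
import jax
import jax.numpy as jnp

# Hypothetical "model": a list of layers, each layer a dict of arrays.
model = [
    {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))},
    {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))},
]


def replace_node(tree, old, new):
    """Return a copy of `tree` in which the subtree `old` (matched by identity) is `new`."""
    return jax.tree_util.tree_map(
        lambda node: new if node is old else node,
        tree,
        is_leaf=lambda node: node is old,  # stop descending at the node to replace
    )


new_model = replace_node(model, model[-1], {"w": jnp.ones((3, 5)), "b": jnp.zeros((5,))})
print([layer["w"].shape for layer in new_model])  # [(2, 3), (3, 5)]
```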

PAX and other libraries

PAX borrows many ideas from other libraries:

  • PAX borrows the idea that a module is also a pytree from treex and equinox.
  • PAX uses the concept of trainable parameters and non-trainable states from dm-haiku.
  • PAX has similar methods to PyTorch such as model.apply(), model.parameters(), model.eval(), etc.
  • PAX uses objax's approach to implement optimizers as modules.
  • PAX uses the jmp library to support mixed precision.
  • And of course, PAX is heavily influenced by JAX's functional programming approach.

Examples

A good way to learn about PAX is to see examples in the examples/ directory.

  • char_rnn.py: train an RNN language model on TPU.
  • transformer/: train a Transformer language model on TPU.
  • mnist.py: train an image classifier on the MNIST dataset.
  • notebooks/VAE.ipynb: train a variational autoencoder.
  • notebooks/DCGAN.ipynb: train a DCGAN model on the Celeb-A dataset.
  • notebooks/fine_tuning_resnet18.ipynb: fine-tune a pretrained ResNet18 model on the cats vs. dogs dataset.
  • notebooks/mixed_precision.ipynb: train a U-Net image segmentation model with mixed precision.
  • mnist_mixed_precision.py: train an image classifier with mixed precision.
  • wave_gru/: train a WaveGRU vocoder that converts mel-spectrograms to waveforms.
  • denoising_diffusion/: train a denoising diffusion model on the Celeb-A dataset.

Modules

At the moment, PAX includes:

  • pax.Embed,
  • pax.Linear,
  • pax.{GRU, LSTM},
  • pax.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm},
  • pax.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose},
  • pax.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}.

Optimizers

PAX's optimizers are implemented in a separate library, opax, which supports many common optimizers such as adam, adamw, sgd and rmsprop. Visit opax's GitHub repository for more information.
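PAX follows objax in treating optimizers as modules. Whatever the concrete API, an optimizer of this kind reduces to a pytree of state plus a pure update function. The sketch below is a plain-JAX SGD with momentum written that way. The function names are hypothetical, and this is not opax's API.

```python
import jax
import jax.numpy as jnp


def sgd_momentum_init(params):
    """Hypothetical optimizer state: one zero velocity per parameter leaf."""
    return jax.tree_util.tree_map(jnp.zeros_like, params)


def sgd_momentum_step(params, grads, velocity, lr=0.1, momentum=0.9):
    """Return (new_params, new_velocity); the inputs are never mutated."""
    velocity = jax.tree_util.tree_map(lambda v, g: momentum * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity


params = {"weight": jnp.array(1.0), "bias": jnp.array(0.0)}
grads = {"weight": jnp.array(2.0), "bias": jnp.array(-1.0)}
velocity = sgd_momentum_init(params)
params, velocity = sgd_momentum_step(params, grads, velocity)
print(params["weight"], params["bias"])
```

Because the state is just a pytree, the step function can be wrapped in jax.jit like any other pure function.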

Fine-tuning models

PAX provides the pax.freeze_parameters transformation, which converts all trainable parameters of a module into non-trainable states.

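import jax
import pax

# A small MLP for flattened 28x28 inputs and 10 classes.
net = pax.Sequential(
    pax.Linear(28*28, 64),
    jax.nn.relu,
    pax.Linear(64, 10),
)

# Turn every trainable parameter into a non-trainable state...
net = pax.freeze_parameters(net)
# ...then swap in a fresh, trainable output layer for a 2-class task.
net = net.set(-1, pax.Linear(64, 2))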

After this, net.parameters() will only return trainable parameters of the last layer.
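The net effect of freezing is that gradients only reach the unfrozen layers. This plain-JAX sketch reproduces that effect with jax.lax.stop_gradient applied to the frozen layers. The layer dicts, the frozen flags and the forward function are hypothetical. PAX itself takes a different route: it turns frozen parameters into non-trainable states, so they no longer appear in net.parameters().

```python
import jax
import jax.numpy as jnp


def forward(layers, frozen, x):
    """Apply a stack of linear layers; frozen layers are cut off from the gradient."""
    for layer, is_frozen in zip(layers, frozen):
        if is_frozen:
            layer = jax.lax.stop_gradient(layer)  # gradient w.r.t. this layer is zero
        x = x @ layer["w"] + layer["b"]
    return x


def loss_fn(layers, frozen, x, y):
    return jnp.mean((forward(layers, frozen, x) - y) ** 2)


layers = [
    {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))},  # "pretrained" layer, frozen
    {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))},  # new head, trainable
]
frozen = (True, False)
x = jnp.ones((5, 3))
y = jnp.zeros((5, 2))

grads = jax.grad(loss_fn)(layers, frozen, x, y)
print(float(jnp.abs(grads[0]["w"]).sum()), float(jnp.abs(grads[1]["w"]).sum()) > 0)  # 0.0 True
```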