`opax` is an optimizer library for JAX. It is a reimplementation of `optax` using PAX's stateful modules.
To install the latest version:
```sh
pip3 install git+https://github.com/ntt123/opax.git
```
To create an optimizer:
```python
import opax

optimizer = opax.adam(1e-4).init(parameters)
```
Note: `parameters` is a pytree of trainable parameters.
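For instance, a plain nested dictionary of arrays is a valid pytree; the parameter names and shapes below are made up purely for illustration:

```python
import jax.numpy as jnp
import opax

# A hypothetical pytree of parameters: any nested structure of arrays works.
parameters = {
    "linear": {"weight": jnp.zeros((3, 2)), "bias": jnp.zeros((2,))},
}
optimizer = opax.adam(1e-4).init(parameters)
```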
To update parameters:
```python
updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)
```
Note: `gradients` has the same treedef as `parameters`.
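Putting the two steps together, a minimal training step might look like the sketch below; the model, loss function, and data are placeholders, and only `opax.transform_gradients` and `opax.apply_updates` are the API calls shown above:

```python
import jax
import jax.numpy as jnp
import opax

def loss_fn(params, x, y):
    # A made-up linear model; replace with your own.
    pred = x @ params["weight"] + params["bias"]
    return jnp.mean((pred - y) ** 2)

params = {"weight": jnp.zeros((3, 2)), "bias": jnp.zeros((2,))}
optimizer = opax.adam(1e-4).init(params)

x, y = jnp.ones((8, 3)), jnp.ones((8, 2))

# jax.grad returns gradients with the same treedef as params.
gradients = jax.grad(loss_fn)(params, x, y)
updates, optimizer = opax.transform_gradients(gradients, optimizer, params)
params = opax.apply_updates(params, updates)
```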
`opax` follows `optax`'s philosophy of combining `GradientTransformation`s with the `opax.chain` function:
```python
optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(parameters)
```
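The chained optimizer is used exactly like the `adam` optimizer above. Since PAX modules are designed to be pytrees, the optimizer should also pass through `jax.jit` like any other pytree; a sketch of a jitted update step under that assumption (the loss function and data are again placeholders):

```python
import jax
import jax.numpy as jnp
import opax

def loss_fn(params, x, y):
    # Placeholder loss; swap in your own model.
    pred = x @ params["weight"] + params["bias"]
    return jnp.mean((pred - y) ** 2)

params = {"weight": jnp.zeros((3, 2)), "bias": jnp.zeros((2,))}
optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(params)

@jax.jit
def train_step(params, optimizer, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, optimizer = opax.transform_gradients(grads, optimizer, params)
    params = opax.apply_updates(params, updates)
    return params, optimizer, loss

params, optimizer, loss = train_step(params, optimizer, jnp.ones((8, 3)), jnp.ones((8, 2)))
```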
`opax` supports learning rate scheduling with `opax.scale_by_schedule`.
```python
import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the scaling factor every 1,000 steps.
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
).init(parameters)
```
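The schedule above halves the scaling factor every 1,000 steps. Assuming `staircase_schedule_fn` from the snippet above is in scope, a quick check of a few values (the numbers follow directly from the formula):

```python
import jax.numpy as jnp

for step in [0, 999, 1000, 2500, 4000]:
    print(step, float(staircase_schedule_fn(jnp.array(step))))
# 0    -> 1.0     (floor(0 / 1000)    = 0, 2**-0 = 1.0)
# 999  -> 1.0
# 1000 -> 0.5     (floor(1000 / 1000) = 1, 2**-1 = 0.5)
# 2500 -> 0.25
# 4000 -> 0.0625
```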
To get the current global norm:
```python
print(optimizer[0].global_norm)
```
Note: `global_norm` is a property of the `ClipByGlobalNorm` class.
To get the current learning rate:
```python
print(optimizer[-1].learning_rate)
```
Note: `learning_rate` is a property of the `ScaleBySchedule` class.
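Both properties are convenient for logging. A small sketch, assuming the chained optimizer from the scheduling example above (so index 0 is `clip_by_global_norm` and index -1 is `scale_by_schedule`), with `step` as a placeholder counter:

```python
def log_metrics(step, optimizer):
    # Indices refer to the order of transformations passed to opax.chain(...).
    print(
        f"step {step}",
        f"global_norm {float(optimizer[0].global_norm):.4f}",
        f"learning_rate {float(optimizer[-1].learning_rate):.6f}",
    )
```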