Somax

Somax is a library of second-order methods for stochastic optimization written in JAX. It is built on the JAXopt StochasticSolver API and can be used as a drop-in replacement for both JAXopt and Optax solvers.

Currently supported methods:

Future releases:

  • Add support for separate "gradient batches" and "curvature batches" for all solvers;
  • Add support for Optax rate schedules.

⚠️ Since JAXopt is currently being merged into Optax, Somax will eventually switch to the Optax API as well.

*The catfish in the logo is a nod to "сом", the Belarusian word for "catfish", also pronounced as "som".

Installation

pip install python-somax

Requires JAXopt 0.8.2+.

Quick example

from somax import EGN

# initialize the solver
solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

# initialize the solver state
opt_state = solver.init_state(params)

# run the optimization loop
for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)

See more in the examples folder.
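
The snippet above assumes that model, params, batch_x, and batch_y are defined elsewhere. For reference, here is a minimal self-contained sketch of the same loop on a toy linear-regression problem; the predict_fun, the synthetic data, and the parameter pytree are illustrative assumptions rather than part of the Somax API, and exact shape conventions may differ (see the examples folder for canonical usage).

import jax
import jax.numpy as jnp

from somax import EGN

# hypothetical linear model: predict_fun(params, x) -> predictions
def predict_fun(params, x):
    return x @ params['w'] + params['b']

# synthetic regression batch (illustrative only)
key = jax.random.PRNGKey(0)
batch_x = jax.random.normal(key, (32, 3))
batch_y = batch_x @ jnp.array([[1.0], [-2.0], [0.5]])

# parameter pytree matching predict_fun
params = {'w': jnp.zeros((3, 1)), 'b': jnp.zeros((1,))}

# initialize the solver
solver = EGN(
    predict_fun=predict_fun,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

# initialize the solver state
opt_state = solver.init_state(params)

# run the optimization loop
for _ in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)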

Citation

@misc{korbit2024somax,
  author = {Nick Korbit},
  title = {{SOMAX}: a library of second-order methods for stochastic optimization written in {JAX}},
  year = {2024},
  url = {https://github.com/cor3bit/somax},
}

See also

Optimization with JAX
  • Optax: first-order gradient optimisers (SGD, Adam, ...).
  • JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD).

Awesome Projects
  • Awesome JAX: a longer list of various JAX projects.
  • Awesome SOMs: a list of resources for second-order optimization methods in machine learning.

Acknowledgements

Some of the implementation ideas are based on the following repositories: