Fix optimizer docstring
fchollet committed Oct 3, 2024
1 parent 6ec0f46 commit 7004f52
Showing 4 changed files with 71 additions and 15 deletions.
15 changes: 8 additions & 7 deletions keras/src/backend/jax/optimizer.py
@@ -1,17 +1,18 @@
"""A class for JAX specific optimizer logic.
Its purpose is to route around statelessness
requirements in cond ops used for EMA handling
and gradient accumulation handling. We do this
by skipping conditionals entirely.
"""

import jax
from jax import numpy as jnp

from keras.src.optimizers import base_optimizer


class JaxOptimizer(base_optimizer.BaseOptimizer):
"""A class for JAX specific optimizer logic.
Its purpose is to route around statelessness
requirements in cond ops used for EMA handling
and gradient accumulation handling. We do this
by skipping conditionals entirely.
"""

def _backend_apply_gradients(self, grads, trainable_variables):
if self.gradient_accumulation_steps:
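Note: the "skipping conditionals entirely" idea can be pictured with a minimal, hypothetical sketch (assumed names; not the Keras implementation). Instead of branching with `jax.lax.cond`, a gradient-accumulation step computes both the "accumulate" and "apply" paths and blends them with a 0/1 mask, so the update stays a pure function of its inputs.

```python
import jax.numpy as jnp


def accumulate_or_apply(variable, grad, acc, step, accumulation_steps, lr):
    """Hypothetical conditional-free update; returns (new_variable, new_acc)."""
    # 1.0 on steps where the accumulated gradient should be applied, else 0.0.
    mask = jnp.asarray((step + 1) % accumulation_steps == 0, dtype=variable.dtype)
    new_acc = acc + grad
    mean_grad = new_acc / accumulation_steps
    # Plain SGD step, masked so it is a no-op on pure accumulation steps.
    new_variable = variable - mask * lr * mean_grad
    # Reset the accumulator on apply steps, keep accumulating otherwise.
    new_acc = (1.0 - mask) * new_acc
    return new_variable, new_acc
```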
17 changes: 9 additions & 8 deletions keras/src/backend/tensorflow/optimizer.py
@@ -1,3 +1,12 @@
"""A class for Tensorflow specific optimizer logic.
The major behavior change for this class is for tf.distribute.
It will override methods from base Keras core Optimizer,
which provide distribute specific functionality, e.g. variable
creation, loss reduction, etc.
"""

import warnings

import tensorflow as tf
@@ -9,14 +18,6 @@


class TFOptimizer(KerasAutoTrackable, base_optimizer.BaseOptimizer):
"""A class for Tensorflow specific optimizer logic.
The major behavior change for this class is for tf.distribute.
It will override methods from base Keras core Optimizer,
which provide distribute specific functionality, e.g. variable
creation, loss reduction, etc.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
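Note: for context, here is a hypothetical end-to-end sketch (assumed model and data names; not part of this diff) of the setting this class targets: the optimizer is created under a `tf.distribute` strategy scope, so its state is created as distributed variables and gradients are applied per replica.

```python
import tensorflow as tf
import keras

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Model and optimizer state are created as distributed variables.
    model = keras.Sequential([keras.layers.Dense(1)])
    optimizer = keras.optimizers.SGD(learning_rate=0.01)


@tf.function
def train_step(x, y):
    def step_fn(x, y):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        # The TensorFlow-specific optimizer handles per-replica application.
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    return strategy.run(step_fn, args=(x, y))
```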
53 changes: 53 additions & 0 deletions keras/src/optimizers/base_optimizer.py
@@ -12,6 +12,59 @@


class BaseOptimizer(KerasSaveable):
"""Abstract optimizer base class.
If you intend to create your own optimization algorithm, please inherit from
this class and override the following methods:
- `build`: Create your optimizer-related variables, such as momentum
variables in the SGD optimizer.
- `update_step`: Implement your optimizer's variable updating logic.
- `get_config`: serialization of the optimizer.
Example:
```python
class SGD(Optimizer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.momentum = 0.9
def build(self, variables):
super().build(variables)
self.momentums = []
for variable in variables:
self.momentums.append(
self.add_variable_from_reference(
reference_variable=variable, name="momentum"
)
)
def update_step(self, gradient, variable, learning_rate):
learning_rate = ops.cast(learning_rate, variable.dtype)
gradient = ops.cast(gradient, variable.dtype)
m = self.momentums[self._get_variable_index(variable)]
self.assign(
m,
ops.subtract(
ops.multiply(m, ops.cast(self.momentum, variable.dtype)),
ops.multiply(gradient, learning_rate),
),
)
self.assign_add(variable, m)
def get_config(self):
config = super().get_config()
config.update(
{
"momentum": self.momentum,
"nesterov": self.nesterov,
}
)
return config
```
"""

    def __init__(
        self,
        learning_rate,
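Note: a hypothetical usage of the docstring example above (toy values; assumes the `SGD` class from the example is defined and that `ops` refers to `keras.ops`):

```python
import numpy as np
import keras

optimizer = SGD(learning_rate=0.1)
weights = [keras.Variable(np.ones((3,)), name="w")]
grads = [keras.ops.ones((3,))]
# `apply` builds the optimizer state on first use, then calls `update_step`.
optimizer.apply(grads, weights)
print(keras.ops.convert_to_numpy(weights[0]))  # moved opposite the gradient
```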
1 change: 1 addition & 0 deletions keras/src/optimizers/optimizer.py
@@ -23,4 +23,5 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
    pass


Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__
base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args
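Note: the explicit `__doc__` assignment above is needed because Python class docstrings are not inherited by subclasses; a small standalone illustration (plain Python, independent of Keras):

```python
class Base:
    """Base docs."""


class Child(Base):
    pass


print(Child.__doc__)  # None: docstrings do not propagate to subclasses
Child.__doc__ = Base.__doc__
print(Child.__doc__)  # 'Base docs.'
```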
