Callbacks in tf.keras are classes whose instances are usually passed through the callbacks parameter of model.fit. A callback performs extra operations at the start or end of training, of each epoch, or of each batch. Typical operations include recording log information, changing the learning rate, and stopping training early.
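The sketch below shows where callbacks plug into model.fit on a toy regression problem. It passes two callbacks: a LambdaCallback that records which hooks fire, and an EarlyStopping callback. The data, the model, and the events list are made up for illustration. Only the tf.keras callback classes and their arguments are real API.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: the target is the sum of the three features.
rng = np.random.default_rng(0)
x = rng.random((64, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

events = []  # hook names in the order they fire

# LambdaCallback turns plain functions into hooks at the start/end of training and of each epoch.
trace = tf.keras.callbacks.LambdaCallback(
    on_train_begin=lambda logs: events.append("train_begin"),
    on_epoch_end=lambda epoch, logs: events.append(f"epoch_end {epoch}"),
    on_train_end=lambda logs: events.append("train_end"),
)
# EarlyStopping ends training once "loss" has not improved for `patience` epochs.
stopper = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=2)

history = model.fit(x, y, epochs=5, batch_size=16, verbose=0,
                    callbacks=[trace, stopper])
print(events)
```

EarlyStopping may stop before the fifth epoch, so the list can contain fewer than five epoch_end entries.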
Likewise, the callbacks parameter can also be passed to model.evaluate or model.predict, to perform extra operations at the start or end of evaluation or prediction, or of each batch. This usage is much less common, however.
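For evaluate and predict, the corresponding hooks are named on_test_* and on_predict_*. The sketch below counts how many batches each call processes. It assumes a small toy model, and BatchCounter is a hypothetical callback written only for this illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model and data: 64 samples, evaluated/predicted in batches of 16.
rng = np.random.default_rng(0)
x = rng.random((64, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")


class BatchCounter(tf.keras.callbacks.Callback):
    """Hypothetical callback: counts batches seen during evaluate() and predict()."""

    def __init__(self):
        super().__init__()
        self.counts = {"test": 0, "predict": 0}

    def on_test_batch_end(self, batch, logs=None):
        self.counts["test"] += 1

    def on_predict_batch_end(self, batch, logs=None):
        self.counts["predict"] += 1


counter = BatchCounter()
model.evaluate(x, y, batch_size=16, verbose=0, callbacks=[counter])
model.predict(x, batch_size=16, verbose=0, callbacks=[counter])
print(counter.counts)  # expected: 4 batches for each call (64 samples / batch_size 16)
```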
In most cases, the predefined callbacks in the submodule keras.callbacks are sufficient. If necessary, you can also write a custom callback by subclassing keras.callbacks.Callback.
All callback classes inherit from keras.callbacks.Callback, which has two attributes: params and model. params is a dictionary of training parameters (e.g. verbosity, number of epochs, number of steps per epoch; the exact keys vary between TensorFlow versions). model is a reference to the model currently being trained.
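To see what params and model actually hold, a callback can inspect them when training starts. In this sketch, the Inspector class and the toy model are hypothetical. The keys in params depend on the TensorFlow version, so the block relies only on "epochs".

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((32, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")


class Inspector(tf.keras.callbacks.Callback):
    """Hypothetical callback: snapshots self.params and self.model at the start of training."""

    def on_train_begin(self, logs=None):
        self.seen_params = dict(self.params)  # e.g. verbose, epochs, steps
        self.seen_model = self.model          # the model being trained


inspector = Inspector()
model.fit(x, y, epochs=3, batch_size=8, verbose=0, callbacks=[inspector])
print(inspector.seen_params)
print(inspector.seen_model is model)
```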
In addition, some callback methods, such as on_epoch_begin and on_batch_end, take an extra argument, logs. This dictionary carries information about the current epoch or batch (for example, the loss and metric values at the end of an epoch) and can also be used to store computed results. The same logs dictionary is handed to the same-named method of every callback in turn. As a result, a value that one callback writes into it can be read by the callbacks that run after it.
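Two hypothetical callbacks can demonstrate this sharing. The first writes a derived value, doubled_loss, into logs at the end of each epoch. The second, placed later in the callbacks list, reads it back. fit appends the History callback after the user-supplied callbacks, so the extra key should also appear in history.history. That ordering is an implementation detail, so it is worth confirming on your TensorFlow version.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((32, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")


class AddDoubledLoss(tf.keras.callbacks.Callback):
    """Hypothetical writer: stores a derived value in the shared logs dict."""

    def on_epoch_end(self, epoch, logs=None):
        logs["doubled_loss"] = 2.0 * float(logs["loss"])


class ReadDoubledLoss(tf.keras.callbacks.Callback):
    """Hypothetical reader: sees the value because it runs after AddDoubledLoss."""

    def __init__(self):
        super().__init__()
        self.values = []

    def on_epoch_end(self, epoch, logs=None):
        self.values.append(logs.get("doubled_loss"))


reader = ReadDoubledLoss()
history = model.fit(x, y, epochs=3, batch_size=8, verbose=0,
                    callbacks=[AddDoubledLoss(), reader])
print(reader.values)
print(history.history["doubled_loss"])
```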
- BaseLogger: calculates the mean of each metric over all batches of an epoch. For stateful metrics (those listed in stateful_metrics), it uses the final value instead of averaging over batches. The resulting epoch values are added to logs. This callback is automatically applied to every Keras model, and it is applied first.
- History: records, in a dictionary, the per-epoch metrics calculated by BaseLogger; the History object is returned by model.fit. This callback is automatically applied to every Keras model, after BaseLogger.
- EarlyStopping: stops training when the monitored metric has not improved significantly (by at least min_delta) for a given number of epochs (patience).
- TensorBoard: saves logs for visualization in TensorBoard, which can display the metrics, the graph, and the parameters of the model.
- ModelCheckpoint: saves the model after each epoch.
- ReduceLROnPlateau: reduces the learning rate by a given factor when the monitored metric has not improved significantly for a given number of epochs (patience).
- TerminateOnNaN: stops training as soon as the loss becomes NaN.
- LearningRateScheduler: sets the learning rate before each epoch, using a given function that maps the epoch index (and the current learning rate lr) to a new learning rate.
- CSVLogger: saves the logs of each epoch to a CSV file.
- ProgbarLogger: prints the logs of each epoch to standard output.
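Several of these predefined callbacks can be combined in a single fit call. The sketch below uses hypothetical toy data and a temporary output directory. It combines EarlyStopping, ReduceLROnPlateau, TerminateOnNaN, and CSVLogger, then reads back both the History object returned by model.fit and the CSV file. The monitor, patience, and factor values are arbitrary choices for illustration.

```python
import csv
import os
import tempfile

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((64, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

csv_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")
cbs = [
    tf.keras.callbacks.EarlyStopping(monitor="loss", patience=3),
    tf.keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.5, patience=1),
    tf.keras.callbacks.TerminateOnNaN(),
    tf.keras.callbacks.CSVLogger(csv_path),  # writes one row per epoch
]
history = model.fit(x, y, epochs=5, batch_size=16, verbose=0, callbacks=cbs)

# History is added automatically, and it is the object model.fit returns.
print(type(history).__name__, list(history.history))

with open(csv_path, newline="") as f:
    rows = list(csv.DictReader(f))
print(len(rows), "rows; columns:", list(rows[0]))
```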
A simple callback can be written with callbacks.LambdaCallback, and a more complex one can be written by subclassing the base class callbacks.Callback.
Don't hesitate to read the source code to learn more about the callbacks in tf.keras.
```python
import json

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, losses, metrics, callbacks
import tensorflow.keras.backend as K

# Example of a simple callback using LambdaCallback:
# append the logs of each epoch to a JSON-lines file.
json_log = open('../data/keras_log.json', mode='wt', buffering=1)
json_logging_callback = callbacks.LambdaCallback(
    # default=float converts NumPy scalars (e.g. np.float32) that some callbacks put into logs
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps(dict(epoch=epoch, **logs), default=float) + '\n'),
    on_train_end=lambda logs: json_log.close()
)

# Example of a complex callback written by subclassing the base class.
# This is the source code of LearningRateScheduler.
class LearningRateScheduler(callbacks.Callback):
    def __init__(self, schedule, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        try:
            # Current API: schedule(epoch, current_lr) -> new_lr
            lr = float(K.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, lr)
        except TypeError:  # Support for old API for backward compatibility: schedule(epoch)
            lr = self.schedule(epoch)
        if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        if isinstance(lr, tf.Tensor) and not lr.dtype.is_floating:
            raise ValueError('The dtype of Tensor should be float')
        K.set_value(self.model.optimizer.lr, K.get_value(lr))
        if self.verbose > 0:
            print('\nEpoch %05d: LearningRateScheduler reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        # Record the current learning rate in logs so that History and later callbacks see it.
        logs = logs or {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)
```
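The sketch below shows one possible way to use callbacks like the two above. To keep the block self-contained, it rebuilds the JSON-lines logger with a temporary file path instead of '../data/keras_log.json'. It also uses the built-in tf.keras.callbacks.LearningRateScheduler rather than the class defined above. The toy data and the schedule function (with its 0.1 / (epoch + 1) rule) are hypothetical.

```python
import json
import os
import tempfile

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((64, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# JSON-lines logger as above, but writing to a temporary file.
log_path = os.path.join(tempfile.mkdtemp(), "keras_log.json")
json_log = open(log_path, mode="wt", buffering=1)
json_logging_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps(dict(epoch=epoch, **logs), default=float) + "\n"),
    on_train_end=lambda logs: json_log.close(),
)

scheduled_epochs = []  # epochs for which the schedule was queried


def schedule(epoch, lr):
    """Hypothetical schedule: ignore the current lr and use 0.1 / (epoch + 1)."""
    scheduled_epochs.append(epoch)
    return 0.1 / (epoch + 1)


lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule)
model.fit(x, y, epochs=3, batch_size=16, verbose=0,
          callbacks=[lr_callback, json_logging_callback])

with open(log_path) as f:
    records = [json.loads(line) for line in f]
print(scheduled_epochs)
print(records[0].keys())
```

Here the scheduler is listed before the JSON logger, so in some TensorFlow versions the logger also sees a learning-rate entry that the scheduler added to logs as a NumPy scalar. That is why json.dumps is given default=float.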
Please leave a comment on the WeChat official account "Python与算法之美" (Elegance of Python and Algorithms) if you would like to discuss the content with the author. The author will try their best to reply, given the limited time available.
You are also welcome to join the group chat with other readers by replying 加群 (join group) on the WeChat official account.