To submit an ElasticDL job, a user needs to provide a model file, such as mnist_functional_api.py used in this example. The model file contains a model built with the TensorFlow Keras API and the other components required by ElasticDL: feed, loss, optimizer, and eval_metrics_fn.
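Before going into each component, here is a minimal sketch of how such a model file is laid out, with the function bodies elided (the names match the sections below):

import tensorflow as tf

# A Keras model built with the functional API or model subclassing.
model = ...

def feed(dataset, mode):
    ...  # pre-process the RecordIO dataset into (model_inputs, labels) pairs

def loss(labels, predictions):
    ...  # return a scalar loss tensor

def optimizer(lr=0.1):
    ...  # return a TensorFlow optimizer

def eval_metrics_fn():
    ...  # return a dict mapping metric names to metric functions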
model

model is a Keras model built using either the TensorFlow Keras functional API or model subclassing. The following example shows a model built with the functional API; it has one input with shape (28, 28) and one output with shape (10,):
inputs = tf.keras.Input(shape=(28, 28), name='image')
x = tf.keras.layers.Reshape((28, 28, 1))(inputs)
x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
x = tf.keras.layers.Dropout(0.25)(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')
Another example using model subclassing:
class MnistModel(tf.keras.Model):
    def __init__(self):
        super(MnistModel, self).__init__(name='mnist_model')
        self._reshape = tf.keras.layers.Reshape((28, 28, 1))
        self._conv1 = tf.keras.layers.Conv2D(
            32, kernel_size=(3, 3), activation='relu')
        self._conv2 = tf.keras.layers.Conv2D(
            64, kernel_size=(3, 3), activation='relu')
        self._batch_norm = tf.keras.layers.BatchNormalization()
        self._maxpooling = tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2))
        self._dropout = tf.keras.layers.Dropout(0.25)
        self._flatten = tf.keras.layers.Flatten()
        self._dense = tf.keras.layers.Dense(10)

    def call(self, inputs, training=False):
        x = self._reshape(inputs)
        x = self._conv1(x)
        x = self._conv2(x)
        x = self._batch_norm(x, training=training)
        x = self._maxpooling(x)
        if training:
            x = self._dropout(x, training=training)
        x = self._flatten(x)
        x = self._dense(x)
        return x

model = MnistModel()
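Either way, the model maps a batch of (28, 28) images to a batch of 10 logits, one per digit class. As a quick local shape check (not part of the ElasticDL API, just plain Keras):

dummy_images = tf.zeros([4, 28, 28])         # a fake batch of 4 MNIST-sized images
logits = model(dummy_images, training=False)
print(logits.shape)                          # expected: (4, 10)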
feed(dataset, mode)

feed is a function that takes a RecordIO dataset as input, pre-processes the data as needed, and returns a dataset containing (model_inputs, labels) pairs.
Arguments:
- dataset: a RecordIO dataset generated by ElasticDL. ElasticDL creates this dataset by iterating over the records in the RecordIO files.
- mode: one of the values defined in elasticdl.python.common.constants.Mode, representing the current phase: training, evaluation, or prediction. For example, if mode == Mode.PREDICTION, we don't need to return labels inside _parse_data().
Output: a dataset in which each element is a tuple (model_inputs, labels). model_inputs is a dictionary of tensors that will be used as the model input; labels will be used as an input argument in loss.
Example:
# Mode is defined in elasticdl.python.common.constants (see the mode argument above).
from elasticdl.python.common.constants import Mode


def feed(dataset, mode):
    def _parse_data(record):
        if mode == Mode.PREDICTION:
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32)
            }
        else:
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32),
                "label": tf.io.FixedLenFeature([1], tf.int64),
            }
        r = tf.io.parse_single_example(record, feature_description)
        features = {
            "image": tf.math.divide(tf.cast(r["image"], tf.float32), 255.0)
        }
        if mode == Mode.PREDICTION:
            return features
        else:
            return features, tf.cast(r["label"], tf.int32)

    dataset = dataset.map(_parse_data)
    if mode != Mode.PREDICTION:
        dataset = dataset.shuffle(buffer_size=1024)
    return dataset
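Since _parse_data() calls tf.io.parse_single_example, each record in the RecordIO dataset is a serialized tf.train.Example. Outside of ElasticDL, feed can therefore be smoke-tested on any tf.data.Dataset of such records. The snippet below is a hypothetical local check, assuming Mode.TRAINING is defined alongside Mode.PREDICTION:

# Build one serialized tf.train.Example matching feature_description above.
example = tf.train.Example(features=tf.train.Features(feature={
    "image": tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.0] * 28 * 28)),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
}))
raw_dataset = tf.data.Dataset.from_tensor_slices([example.SerializeToString()])

train_dataset = feed(raw_dataset, Mode.TRAINING)
for features, label in train_dataset.batch(1).take(1):
    print(features["image"].shape, label.shape)  # (1, 28, 28) and (1, 1)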
loss(labels, predictions)

loss is the loss function used in ElasticDL training.

Arguments:
- labels: the labels produced by feed.
- predictions: the model output (here, the logits from the final Dense layer).

Example:
def loss(labels, predictions):
    return tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=predictions, labels=tf.reshape(labels, [-1])
        )
    )
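As a rough local sanity check (again, not part of the ElasticDL API), loss can be evaluated on a dummy batch of labels and logits:

labels = tf.constant([[3], [8]], dtype=tf.int32)  # shape (2, 1), as produced by feed
predictions = tf.random.normal([2, 10])           # logits for a batch of 2 examples
print(loss(labels, predictions).numpy())          # a scalar cross-entropy value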
optimizer()

optimizer is a function that returns a TensorFlow optimizer, such as an instance of tf.optimizers.Optimizer.
Example:
def optimizer(lr=0.1):
    return tf.optimizers.SGD(lr)
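ElasticDL uses the optimizer returned by this function during training, so hyperparameters such as the learning rate are adjusted here. The returned object behaves like any other TensorFlow optimizer; a hypothetical local check:

opt = optimizer(lr=0.01)
var = tf.Variable(1.0)
with tf.GradientTape() as tape:
    y = var * var                       # gradient of y w.r.t. var at 1.0 is 2.0
grads = tape.gradient(y, [var])
opt.apply_gradients(zip(grads, [var]))
print(var.numpy())                      # 0.98 after one SGD step with lr=0.01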
eval_metrics_fn()

eval_metrics_fn is a function that returns a dictionary in which each key is the name of an evaluation metric and each value is a function that computes the metric from predictions and labels using the TensorFlow API.
Example:
def eval_metrics_fn():
    return {
        "accuracy": lambda labels, predictions: tf.equal(
            tf.argmax(predictions, 1, output_type=tf.int32),
            tf.cast(tf.reshape(labels, [-1]), tf.int32),
        )
    }
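The accuracy entry above returns a boolean tensor of per-example correctness rather than a single number. A hypothetical local check of the metric function, averaging the booleans into an accuracy value ourselves:

metrics = eval_metrics_fn()
labels = tf.constant([[1], [2]], dtype=tf.int64)
predictions = tf.one_hot([1, 0], depth=10)  # stand-in logits for 2 examples
correct = metrics["accuracy"](labels, predictions)           # [True, False]
print(tf.reduce_mean(tf.cast(correct, tf.float32)).numpy())  # 0.5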