-
Notifications
You must be signed in to change notification settings - Fork 61
/
train_model.py
78 lines (60 loc) · 2.56 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import uuid
import numpy as np
import tensorflow as tf
import valohai
def log_metadata(epoch, logs):
"""Helper function to log training metrics"""
with valohai.logger() as logger:
logger.log('epoch', epoch)
logger.log('accuracy', logs['accuracy'])
logger.log('loss', logs['loss'])
def main():
# valohai.prepare enables us to update the valohai.yaml configuration file with
# the Valohai command-line client by running `valohai yaml step train_model.py`
valohai.prepare(
step='train-model',
image='tensorflow/tensorflow:2.6.0',
default_inputs={
'dataset': 'https://valohaidemo.blob.core.windows.net/mnist/preprocessed_mnist.npz',
},
default_parameters={
'learning_rate': 0.001,
'epochs': 5,
},
)
# Read input files from Valohai inputs directory
# This enables Valohai to version your training data
# and cache the data for quick experimentation
input_path = valohai.inputs('dataset').path()
with np.load(input_path, allow_pickle=True) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=valohai.parameters('learning_rate').value)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer,
loss=loss_fn,
metrics=['accuracy'])
# Print metrics out as JSON
# This enables Valohai to version your metadata
# and for you to use it to compare experiments
callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_metadata)
model.fit(x_train, y_train, epochs=valohai.parameters('epochs').value, callbacks=[callback])
# Evaluate the model and print out the test metrics as JSON
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=2)
with valohai.logger() as logger:
logger.log('test_accuracy', test_accuracy)
logger.log('test_loss', test_loss)
# Write output files to Valohai outputs directory
# This enables Valohai to version your data
# and upload output it to the default data store
suffix = uuid.uuid4()
output_path = valohai.outputs().path(f'model-{suffix}.h5')
model.save(output_path)
if __name__ == '__main__':
main()