Skip to content

Commit

Permalink
Merge pull request #112 from andreped/batch-norm-fix
Browse files Browse the repository at this point in the history
Added method to replace BN layers [no ci]
  • Loading branch information
andreped committed Sep 8, 2023
2 parents c90109b + 068a507 commit d3fce33
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
--cov=gradient_accumulator tests/test_expected_result.py \
--cov=gradient_accumulator tests/test_mp_batch_norm.py \
--cov=gradient_accumulator tests/test_bn_convnd.py \
--cov=gradient_accumulator tests/test_bn_pretrained_swap.py \
--cov=gradient_accumulator tests/test_model_distribute.py
- name: Lint with flake8
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
- name: Install tensorflow-datasets
run: |
if [[ ${{ matrix.tf-version }} == 2.12 ]]; then
pip install tensorflow-datasets --upgrade
pip install "tensorflow-datasets<=4.9.2"
else
pip install tensorflow==${{ matrix.tf-version }} "tensorflow-datasets<=4.8.2"
pip install "protobuf<=3.20" --force-reinstall
Expand Down Expand Up @@ -96,6 +96,7 @@ jobs:
pytest -v tests/test_adaptive_gradient_clipping.py
pytest -v tests/test_batch_norm.py
pytest -v tests/test_bn_convnd.py
pytest -v tests/test_bn_pretrained_swap.py
pytest -v tests/test_mp_batch_norm.py
pytest -v tests/test_optimizer_distribute.py
pytest -v tests/test_model_distribute.py
Expand Down
22 changes: 22 additions & 0 deletions docs/examples/batch_normalization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ the *vanilla* batch normalization layer is the most used.

.. code-block:: python
import tensorflow as tf
from gradient_accumulator import GradientAccumulateModel, AccumBatchNormalization
# sets it here as we will set it for both the layer and model wrapper
Expand All @@ -32,6 +33,27 @@ the *vanilla* batch normalization layer is the most used.
model = GradientAccumulateModel(accum_steps=accum_steps, inputs=model.input, outputs=model.output)
You can also easily replace the existing Batch Normalization layers in a
pretrained model, e.g., MobileNetV2. Below is an example of how to do that:


.. code-block:: python
import tensorflow as tf
from gradient_accumulator import GradientAccumulateModel
from gradient_accumulator.layers import AccumBatchNormalization
from gradient_accumulator.utils import replace_batchnorm_layers
accum_steps = 4
# replace BN layer with AccumBatchNormalization
model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), include_top=False)
model = replace_batchnorm_layers(model, accum_steps=accum_steps)
# add gradient accumulation to existing model
model = GradientAccumulateModel(accum_steps=accum_steps, inputs=model.input, outputs=model.output)
Note that Batch Normalization is a unique layer in Keras.
It has two sets of variables. The first two `mean` and
`variance` are updated during the *forward pass*, whereas
Expand Down
75 changes: 75 additions & 0 deletions gradient_accumulator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import tensorflow as tf

from .layers import AccumBatchNormalization


def replace_batchnorm_layers(model, accum_steps, position="replace"):
    """Rebuild ``model`` with its BatchNormalization layers swapped out.

    The functional graph of ``model`` is walked layer by layer and re-wired
    into a new ``tf.keras.Model``. Every ``tf.keras.layers.BatchNormalization``
    layer is replaced by a freshly built ``AccumBatchNormalization`` layer;
    all other layers are reused as-is.

    Args:
        model: Functional Keras model to convert. The first entry of
            ``model.layers`` is treated as the input layer and is fed
            ``model.input``.
        accum_steps: Forwarded to each ``AccumBatchNormalization`` layer.
        position: Only ``"replace"`` is supported.

    Returns:
        A new ``tf.keras.Model`` with the same inputs as ``model`` and with
        outputs taken from the layers named in ``model.output_names``.

    Raises:
        ValueError: If ``position`` is not ``"replace"``.
    """
    # Validate up front so an invalid argument fails before any graph work.
    if position != "replace":
        raise ValueError("position must be: replace")

    # For each layer name, the names of the layers feeding into it.
    # NOTE: relies on the private Keras attribute ``_outbound_nodes``.
    input_layers_of = {}
    for layer in model.layers:
        for node in layer._outbound_nodes:
            input_layers_of.setdefault(node.outbound_layer.name, []).append(
                layer.name
            )

    # Tensor produced by each already re-wired layer in the new graph.
    new_output_tensor_of = {model.layers[0].name: model.input}

    bn_count = 0
    x = model.input
    for layer in model.layers[1:]:
        # Gather the (new) tensors that feed this layer.
        layer_input = [
            new_output_tensor_of[source_name]
            for source_name in input_layers_of[layer.name]
        ]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        if isinstance(layer, tf.keras.layers.BatchNormalization):
            new_layer = AccumBatchNormalization(
                accum_steps=accum_steps,
                name="AccumBatchNormalization_" + str(bn_count),
            )
            new_layer.build(input_shape=layer.input_shape)
            bn_count += 1

            # These assignments rebind the attributes to the original layer's
            # variable objects (shared, not copied).
            # NOTE(review): gamma/beta, momentum and epsilon of the original
            # layer are not transferred here - confirm whether that is
            # intended for pretrained models.
            new_layer.accum_mean = layer.moving_mean
            new_layer.moving_mean = layer.moving_mean

            new_layer.accum_variance = layer.moving_variance
            new_layer.moving_variance = layer.moving_variance

            x = new_layer(layer_input)
        else:
            x = layer(layer_input)

        # Downstream layers look up this layer's output under its old name.
        new_output_tensor_of[layer.name] = x

    # Collect outputs in the order declared by the original model, keyed on
    # each layer's own name (the original compared a stale loop variable).
    model_outputs = [
        new_output_tensor_of[output_name]
        for output_name in model.output_names
        if output_name in new_output_tensor_of
    ]
    if not model_outputs:
        # No name matched: fall back to the last tensor produced.
        outputs = x
    elif len(model_outputs) == 1:
        # Keep a bare tensor for single-output models.
        outputs = model_outputs[0]
    else:
        outputs = model_outputs

    return tf.keras.Model(inputs=model.inputs, outputs=outputs)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="gradient-accumulator",
version="0.5.1",
version="0.5.2",
author="André Pedersen and David Bouget and Javier Pérez de Frutos and Tor-Arne Schmidt Nordmo",
author_email="[email protected]",
description="Package for gradient accumulation in TensorFlow",
Expand Down
85 changes: 85 additions & 0 deletions tests/test_bn_pretrained_swap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import random as python_random

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import load_model

from gradient_accumulator import GradientAccumulateModel
from gradient_accumulator.layers import AccumBatchNormalization
from gradient_accumulator.utils import replace_batchnorm_layers

from .utils import gray2rgb
from .utils import normalize_img
from .utils import reset
from .utils import resizeImage


def test_swap_layer(
    custom_bn: bool = True, bs: int = 100, accum_steps: int = 1, epochs: int = 1
):
    """Smoke-test BN replacement on a pretrained MobileNetV2 using MNIST.

    Replaces the batch normalization layers of an ImageNet-pretrained
    MobileNetV2, adds a 10-class softmax head, trains for a few steps,
    saves the model, reloads it and evaluates it on the test split.

    Args:
        custom_bn: Unused; kept so the signature stays unchanged.
        bs: Batch size for both pipelines.
        accum_steps: Accumulation steps passed to the BN replacement; the
            model is wrapped in ``GradientAccumulateModel`` only when > 1.
        epochs: Number of training epochs.

    Returns:
        The result of ``evaluate`` on the reloaded model.
    """
    # load dataset
    (ds_train, ds_test), ds_info = tfds.load(
        "mnist",
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )

    # build train pipeline
    ds_train = ds_train.map(normalize_img)
    ds_train = ds_train.map(gray2rgb)
    ds_train = ds_train.map(resizeImage)
    ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
    ds_train = ds_train.batch(bs)
    ds_train = ds_train.prefetch(1)

    # build test pipeline
    ds_test = ds_test.map(normalize_img)
    ds_test = ds_test.map(gray2rgb)
    ds_test = ds_test.map(resizeImage)
    ds_test = ds_test.batch(bs)
    ds_test = ds_test.prefetch(1)

    # create model: pretrained backbone with its BN layers swapped
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(32, 32, 3), weights="imagenet", include_top=False
    )
    base_model = replace_batchnorm_layers(base_model, accum_steps=accum_steps)

    input_ = tf.keras.layers.Input(shape=(32, 32, 3))
    x = base_model(input_)
    x = tf.keras.layers.Dense(10, activation="softmax")(x)
    model = tf.keras.Model(inputs=input_, outputs=x)

    # wrap model to use gradient accumulation
    if accum_steps > 1:
        model = GradientAccumulateModel(
            accum_steps=accum_steps, inputs=model.input, outputs=model.output
        )

    # compile model
    # The head already applies softmax, so the loss receives probabilities
    # and must not treat them as logits.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(1e-2),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    # train model
    model.fit(
        ds_train,
        epochs=epochs,
        validation_data=ds_test,
        steps_per_epoch=4,
        validation_steps=4,
    )

    model.save("./trained_model")

    # load trained model and test
    del model
    trained_model = load_model("./trained_model", compile=True)

    result = trained_model.evaluate(ds_test, verbose=1)
    print(result)
    return result
10 changes: 10 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ def normalize_img(image, label):
return tf.cast(image, tf.float32) / 255.0, label


def gray2rgb(image, label):
    """Concatenate the image with itself three times along the last axis.

    The label is passed through untouched.
    """
    channels = [image] * 3
    rgb_image = tf.concat(channels, axis=-1)
    return rgb_image, label


def resizeImage(image, label, output_shape=(32, 32)):
    """Resize the image to ``output_shape`` with nearest-neighbour sampling.

    The label is passed through untouched.
    """
    resized = tf.image.resize(image, output_shape, method="nearest")
    return resized, label


def run_experiment(bs=50, accum_steps=2, epochs=1, modeloropt="opt"):
# load dataset
(ds_train, ds_test), ds_info = tfds.load(
Expand Down

0 comments on commit d3fce33

Please sign in to comment.