diff --git a/tensorflow_ranking/python/keras/layers.py b/tensorflow_ranking/python/keras/layers.py index 345eed4..5722228 100644 --- a/tensorflow_ranking/python/keras/layers.py +++ b/tensorflow_ranking/python/keras/layers.py @@ -68,13 +68,13 @@ def create_tower(hidden_layer_dims: List[int], if input_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) for layer_width in hidden_layer_dims: - model.add(tf.keras.layers.Dense(units=layer_width), **kwargs) + model.add(tf.keras.layers.Dense(units=layer_width, **kwargs)) if use_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) model.add(tf.keras.layers.Activation(activation=activation)) if dropout: model.add(tf.keras.layers.Dropout(rate=dropout)) - model.add(tf.keras.layers.Dense(units=output_units), **kwargs) + model.add(tf.keras.layers.Dense(units=output_units, **kwargs)) return model