From 74599d9b567b481c52b37a19389da8dc1cbe74b0 Mon Sep 17 00:00:00 2001 From: Zhihong Duan Date: Thu, 24 Jun 2021 16:22:27 -0700 Subject: [PATCH 1/2] passing kwargs into Layer class instead of model.add --- tensorflow_ranking/python/keras/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_ranking/python/keras/layers.py b/tensorflow_ranking/python/keras/layers.py index 345eed4..a740b65 100644 --- a/tensorflow_ranking/python/keras/layers.py +++ b/tensorflow_ranking/python/keras/layers.py @@ -68,13 +68,13 @@ def create_tower(hidden_layer_dims: List[int], if input_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) for layer_width in hidden_layer_dims: - model.add(tf.keras.layers.Dense(units=layer_width), **kwargs) + model.add(tf.keras.layers.Dense(units=layer_width,**kwargs)) if use_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) model.add(tf.keras.layers.Activation(activation=activation)) if dropout: model.add(tf.keras.layers.Dropout(rate=dropout)) - model.add(tf.keras.layers.Dense(units=output_units), **kwargs) + model.add(tf.keras.layers.Dense(units=output_units, **kwargs)) return model From 9deef0fc1956538a1fc8b361ca4d4332a755d8c2 Mon Sep 17 00:00:00 2001 From: Zhihong Duan Date: Thu, 24 Jun 2021 16:34:24 -0700 Subject: [PATCH 2/2] format --- tensorflow_ranking/python/keras/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_ranking/python/keras/layers.py b/tensorflow_ranking/python/keras/layers.py index a740b65..5722228 100644 --- a/tensorflow_ranking/python/keras/layers.py +++ b/tensorflow_ranking/python/keras/layers.py @@ -68,7 +68,7 @@ def create_tower(hidden_layer_dims: List[int], if input_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) for layer_width in hidden_layer_dims: - model.add(tf.keras.layers.Dense(units=layer_width,**kwargs)) + model.add(tf.keras.layers.Dense(units=layer_width, **kwargs)) if use_batch_norm: model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment)) model.add(tf.keras.layers.Activation(activation=activation))