Skip to content

Commit

Permalink
added activations
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-vaselli committed May 21, 2024
1 parent dbc326c commit b46c254
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 7 additions & 1 deletion configs/new_train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ model_config:
n_dense_layers: 2
dense_size: 32
output_shape: 6
output_activation: null
rnn_config:
n_rnn_layers: 5
rnn_units: 512
n_dense_layers: 6
dense_size: 512
output_shape: 6
input_shape: [7, 1]
output_activation: null
attn_config:
input_shape: [7, 1]
output_shape: 6
Expand All @@ -54,13 +56,15 @@ model_config:
n_dense_layers: 1
dense_size: 32
output_shape: 1
output_activation: "sigmoid"
rnn_config:
n_rnn_layers: 3
rnn_units: 416
n_dense_layers: 6
dense_size: 192
output_shape: 1
input_shape: [7, 1]
output_activation: "sigmoid"
attn_config:
input_shape: [7, 1]
output_shape: 1
Expand All @@ -81,10 +85,12 @@ model_config:
n_dense_layers: 4
dense_size: 128
output_shape: 3
output_activation: "softmax"
rnn_config:
n_rnn_layers: 4
rnn_units: 512
n_dense_layers: 2
dense_size: 512
output_shape: 3
input_shape: [7, 1]
input_shape: [7, 1]
output_activation: "softmax"
4 changes: 2 additions & 2 deletions src/models/new_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def build_model(model_config, model_name):
)

# Add output layer
model.add(Dense(cnn_config["output_shape"]))
model.add(Dense(cnn_config["output_shape"], activation=cnn_config["output_activation"]))
print("CNN model built:", "\n")
model.summary()

Expand All @@ -161,7 +161,7 @@ def build_model(model_config, model_name):
model.add(Dense(rnn_config["dense_size"], activation="relu"))

# Add output layer
model.add(Dense(rnn_config["output_shape"]))
model.add(Dense(rnn_config["output_shape"], activation=rnn_config["output_activation"]))
print("RNN model done")
model.summary()
elif model_name == "ar_rnn":
Expand Down

0 comments on commit b46c254

Please sign in to comment.