From b46c2545dd67c3080e927d25cae59d8800afd33f Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Tue, 21 May 2024 17:09:32 +0200 Subject: [PATCH] added activations --- configs/new_train_config.yaml | 8 +++++++- src/models/new_train_model.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/configs/new_train_config.yaml b/configs/new_train_config.yaml index f958b16..6b85334 100644 --- a/configs/new_train_config.yaml +++ b/configs/new_train_config.yaml @@ -27,6 +27,7 @@ model_config: n_dense_layers: 2 dense_size: 32 output_shape: 6 + output_activation: null rnn_config: n_rnn_layers: 5 rnn_units: 512 @@ -34,6 +35,7 @@ model_config: dense_size: 512 output_shape: 6 input_shape: [7, 1] + output_activation: null attn_config: input_shape: [7, 1] output_shape: 6 @@ -54,6 +56,7 @@ model_config: n_dense_layers: 1 dense_size: 32 output_shape: 1 + output_activation: "sigmoid" rnn_config: n_rnn_layers: 3 rnn_units: 416 @@ -61,6 +64,7 @@ model_config: dense_size: 192 output_shape: 1 input_shape: [7, 1] + output_activation: "sigmoid" attn_config: input_shape: [7, 1] output_shape: 1 @@ -81,10 +85,12 @@ model_config: n_dense_layers: 4 dense_size: 128 output_shape: 3 + output_activation: "softmax" rnn_config: n_rnn_layers: 4 rnn_units: 512 n_dense_layers: 2 dense_size: 512 output_shape: 3 - input_shape: [7, 1] \ No newline at end of file + input_shape: [7, 1] + output_activation: "softmax" \ No newline at end of file diff --git a/src/models/new_train_model.py b/src/models/new_train_model.py index a41d894..69821a9 100644 --- a/src/models/new_train_model.py +++ b/src/models/new_train_model.py @@ -137,7 +137,7 @@ def build_model(model_config, model_name): ) # Add output layer - model.add(Dense(cnn_config["output_shape"])) + model.add(Dense(cnn_config["output_shape"], activation=cnn_config["output_activation"])) print("CNN model built:", "\n") model.summary() @@ -161,7 +161,7 @@ def build_model(model_config, model_name): model.add(Dense(rnn_config["dense_size"], activation="relu")) # Add output layer - model.add(Dense(rnn_config["output_shape"])) + model.add(Dense(rnn_config["output_shape"], activation=rnn_config["output_activation"])) print("RNN model done") model.summary() elif model_name == "ar_rnn":