From 5152b435c2f548dbb83e5ebfa96b3641abb2a640 Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Sun, 25 Feb 2024 16:29:43 +0100 Subject: [PATCH] optim start --- src/models/param_scan.py | 80 +++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/src/models/param_scan.py b/src/models/param_scan.py index af77c6b..3cc9727 100644 --- a/src/models/param_scan.py +++ b/src/models/param_scan.py @@ -358,47 +358,51 @@ def main(): with open("configs/train_config.yaml", "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) - models = ["cnn", "rnn", "transformer"] + # models = ["cnn", "rnn", "transformer"] targets = ["regression", "classification", "multi_classification"] - for current_model in models: - model_type = current_model - for target_type in targets: - target = target_type - # model_type = "rnn" - # target = "classification" + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", type=str, default="cnn", help="Type of model to train." + ) + args = parser.parse_args() + model_type = args.model_type + print("model_type:", model_type) + + for target_type in targets: + target = target_type - data_path = config["data_path"] - data_mean = config["data_mean"] - data_std = config["data_std"] - n_train = config["n_train"] - n_val = config["n_val"] - n_test = config["n_test"] - batch_size = config["batch_size"] - buffer_size = config["buffer_size"] - epochs = config["epochs"] - optimizer = config["optimizer"] - loss = config["loss"] - learning_rate = config["learning_rate"] - model_config = config["model_config"] - - train( - data_path, - data_mean, - data_std, - n_train, - n_val, - n_test, - batch_size, - buffer_size, - epochs, - optimizer, - loss, - learning_rate, - model_config, - model_type, - target, - ) + data_path = config["data_path"] + data_mean = config["data_mean"] + data_std = config["data_std"] + n_train = config["n_train"] + n_val = config["n_val"] + n_test = config["n_test"] + batch_size = config["batch_size"] + buffer_size = config["buffer_size"] + epochs = config["epochs"] + optimizer = config["optimizer"] + loss = config["loss"] + learning_rate = config["learning_rate"] + model_config = config["model_config"] + + train( + data_path, + data_mean, + data_std, + n_train, + n_val, + n_test, + batch_size, + buffer_size, + epochs, + optimizer, + loss, + learning_rate, + model_config, + model_type, + target, + ) if __name__ == "__main__":