Skip to content

Commit

Permalink
optim start
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-vaselli committed Feb 25, 2024
1 parent 7a090e4 commit 5152b43
Showing 1 changed file with 42 additions and 38 deletions.
80 changes: 42 additions & 38 deletions src/models/param_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,47 +358,51 @@ def main():
with open("configs/train_config.yaml", "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)

models = ["cnn", "rnn", "transformer"]
# models = ["cnn", "rnn", "transformer"]
targets = ["regression", "classification", "multi_classification"]

for current_model in models:
model_type = current_model
for target_type in targets:
target = target_type
# model_type = "rnn"
# target = "classification"
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type", type=str, default="cnn", help="Type of model to train."
)
args = parser.parse_args()
model_type = args.model_type
print("model_type:", model_type)

for target_type in targets:
target = target_type

data_path = config["data_path"]
data_mean = config["data_mean"]
data_std = config["data_std"]
n_train = config["n_train"]
n_val = config["n_val"]
n_test = config["n_test"]
batch_size = config["batch_size"]
buffer_size = config["buffer_size"]
epochs = config["epochs"]
optimizer = config["optimizer"]
loss = config["loss"]
learning_rate = config["learning_rate"]
model_config = config["model_config"]

train(
data_path,
data_mean,
data_std,
n_train,
n_val,
n_test,
batch_size,
buffer_size,
epochs,
optimizer,
loss,
learning_rate,
model_config,
model_type,
target,
)
data_path = config["data_path"]
data_mean = config["data_mean"]
data_std = config["data_std"]
n_train = config["n_train"]
n_val = config["n_val"]
n_test = config["n_test"]
batch_size = config["batch_size"]
buffer_size = config["buffer_size"]
epochs = config["epochs"]
optimizer = config["optimizer"]
loss = config["loss"]
learning_rate = config["learning_rate"]
model_config = config["model_config"]

train(
data_path,
data_mean,
data_std,
n_train,
n_val,
n_test,
batch_size,
buffer_size,
epochs,
optimizer,
loss,
learning_rate,
model_config,
model_type,
target,
)


if __name__ == "__main__":
Expand Down

0 comments on commit 5152b43

Please sign in to comment.