From 28e79017f6d7f9804bce478c0b628c251ddfe7c4 Mon Sep 17 00:00:00 2001 From: Theo Heimel Date: Thu, 16 Mar 2023 14:54:11 +0100 Subject: [PATCH] save best, final, checkpoint models --- README.md | 6 ++++++ src/__main__.py | 12 ++++++------ src/train.py | 27 ++++++++++++++++++++++----- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index a7a76b4..aff00d4 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,11 @@ Loading a trained model: python -m src --load_model 20230303_100000_run_name ``` +Loading a specific trained model (e.g. best model, final model, checkpoint): +``` +python -m src --load_model --model_name=final 20230303_100000_run_name +``` + Loading a trained model, but compute weights again: ``` python -m src --load_model --load_weights 20230303_100000_run_name @@ -60,6 +65,7 @@ Parameter | Description `lr_decay_factor` | Step scheduler: Decay factor `epochs` | Number of epochs `train_samples` | Total number of samples used for training (alternative to number of epochs) +`checkpoint_interval` | If value n set, save the model after every n epochs ### Evaluation diff --git a/src/__main__.py b/src/__main__.py index d39acd7..764ac7c 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -20,6 +20,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("paramfile") parser.add_argument("--load_model", action="store_true") + parser.add_argument("--model_name", type=str, default="best") parser.add_argument("--load_weights", action="store_true") args = parser.parse_args() @@ -44,17 +45,16 @@ def main(): print(f" Val points: {len(data.val_true)} truth, {len(data.val_fake)} generated") print(" Building model") - training = DiscriminatorTraining(params, device, data) + model_dir = doc.get_file(f"model_{data.suffix}.pth") + os.makedirs(model_dir, exist_ok=True) + training = DiscriminatorTraining(params, device, data, model_dir) if args.load_model: print(" Loading model") - training.load(doc.get_file(f"model_{data.suffix}.pth")) - - if not args.load_model: + training.load(args.model_name) + else: print(" Running training") training.train() - print(" Saving model") - training.save(doc.get_file(f"model_{data.suffix}.pth")) if args.load_model and args.load_weights: print(" Loading weights") diff --git a/src/train.py b/src/train.py index d9af34b..db78bef 100644 --- a/src/train.py +++ b/src/train.py @@ -1,3 +1,4 @@ +import os import torch import torch.nn as nn import numpy as np @@ -15,7 +16,8 @@ def __init__( self, params: dict, device: torch.device, - data: DiscriminatorData + data: DiscriminatorData, + model_path: str ): """ Build the network and data loaders. @@ -24,10 +26,12 @@ def __init__( params: Dict with architecture and training hyperparameters device: Pytorch device used for training and evaluation data: DiscriminatorData object containing the training and evaluation data + model_path: Path to a directory where models are saved """ self.params = params self.device = device self.data = data + self.model_path = model_path self.init_data_loaders() self.model = Discriminator(data.dim, params) @@ -150,6 +154,8 @@ def train(self): """ Main training loop """ + best_val_loss = 1e20 + checkpoint_interval = self.params.get("checkpoint_interval") for epoch in range(self.epochs): self.model.train() epoch_losses, epoch_bce_losses, epoch_kl_losses = [], [], [] @@ -179,6 +185,15 @@ def train(self): print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " + f"val loss {val_loss:.6f}") + if val_loss < best_val_loss: + best_val_loss = val_loss + self.save("best") + + if checkpoint_interval is not None and (epoch+1) % checkpoint_interval == 0: + self.save(f"epoch_{epoch}") + + self.save("final") + def val_loss(self) -> tuple[torch.Tensor, ...]: """ @@ -267,13 +282,14 @@ def predict_single(self): return w_true.cpu().numpy(), w_fake.cpu().numpy(), clf_score.cpu().numpy() - def save(self, file: str): + def save(self, name: str): """ Saves the model, optimizer and losses. Args: - file: Output file name + name: File name for the model (without path and extension) """ + file = os.path.join(self.model_path, f"{name}.pth") torch.save({ "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), @@ -281,13 +297,14 @@ def save(self, file: str): }, file) - def load(self, file: str): + def load(self, name: str): """ Loads the model, optimizer and losses. Args: - file: Input file name + name: File name for the model (without path and extension) """ + file = os.path.join(self.model_path, f"{name}.pth") state_dicts = torch.load(file, map_location=self.device) self.optimizer.load_state_dict(state_dicts["optimizer"]) self.model.load_state_dict(state_dicts["model"])