Skip to content

Commit

Permalink
save best, final, checkpoint models
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed Mar 16, 2023
1 parent 7df06c4 commit 28e7901
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ Loading a trained model:
python -m src --load_model 20230303_100000_run_name
```

Loading a specific trained model (e.g. best model, final model, checkpoint):
```
python -m src --load_model --model_name=final 20230303_100000_run_name
```

Loading a trained model, but compute weights again:
```
python -m src --load_model --load_weights 20230303_100000_run_name
Expand Down Expand Up @@ -60,6 +65,7 @@ Parameter | Description
`lr_decay_factor` | Step scheduler: Decay factor
`epochs` | Number of epochs
`train_samples` | Total number of samples used for training (alternative to number of epochs)
`checkpoint_interval` | If value n set, save the model after every n epochs

### Evaluation

Expand Down
12 changes: 6 additions & 6 deletions src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("paramfile")
parser.add_argument("--load_model", action="store_true")
parser.add_argument("--model_name", type=str, default="best")
parser.add_argument("--load_weights", action="store_true")
args = parser.parse_args()

Expand All @@ -44,17 +45,16 @@ def main():
print(f" Val points: {len(data.val_true)} truth, {len(data.val_fake)} generated")

print(" Building model")
training = DiscriminatorTraining(params, device, data)
model_dir = doc.get_file(f"model_{data.suffix}.pth")
os.makedirs(model_dir, exist_ok=True)
training = DiscriminatorTraining(params, device, data, model_dir)

if args.load_model:
print(" Loading model")
training.load(doc.get_file(f"model_{data.suffix}.pth"))

if not args.load_model:
training.load(args.model_name)
else:
print(" Running training")
training.train()
print(" Saving model")
training.save(doc.get_file(f"model_{data.suffix}.pth"))

if args.load_model and args.load_weights:
print(" Loading weights")
Expand Down
27 changes: 22 additions & 5 deletions src/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch
import torch.nn as nn
import numpy as np
Expand All @@ -15,7 +16,8 @@ def __init__(
self,
params: dict,
device: torch.device,
data: DiscriminatorData
data: DiscriminatorData,
model_path: str
):
"""
Build the network and data loaders.
Expand All @@ -24,10 +26,12 @@ def __init__(
params: Dict with architecture and training hyperparameters
device: Pytorch device used for training and evaluation
data: DiscriminatorData object containing the training and evaluation data
model_path: Path to a directory where models are saved
"""
self.params = params
self.device = device
self.data = data
self.model_path = model_path

self.init_data_loaders()
self.model = Discriminator(data.dim, params)
Expand Down Expand Up @@ -150,6 +154,8 @@ def train(self):
"""
Main training loop
"""
best_val_loss = 1e20
checkpoint_interval = self.params.get("checkpoint_interval")
for epoch in range(self.epochs):
self.model.train()
epoch_losses, epoch_bce_losses, epoch_kl_losses = [], [], []
Expand Down Expand Up @@ -179,6 +185,15 @@ def train(self):
print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " +
f"val loss {val_loss:.6f}")

if val_loss < best_val_loss:
best_val_loss = val_loss
self.save("best")

if checkpoint_interval is not None and (epoch+1) % checkpoint_interval == 0:
self.save(f"epoch_{epoch}")

self.save("final")


def val_loss(self) -> tuple[torch.Tensor, ...]:
"""
Expand Down Expand Up @@ -267,27 +282,29 @@ def predict_single(self):
return w_true.cpu().numpy(), w_fake.cpu().numpy(), clf_score.cpu().numpy()


def save(self, file: str):
def save(self, name: str):
"""
Saves the model, optimizer and losses.
Args:
file: Output file name
name: File name for the model (without path and extension)
"""
file = os.path.join(self.model_path, f"{name}.pth")
torch.save({
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"losses": self.losses
}, file)


def load(self, file: str):
def load(self, name: str):
"""
Loads the model, optimizer and losses.
Args:
file: Input file name
name: File name for the model (without path and extension)
"""
file = os.path.join(self.model_path, f"{name}.pth")
state_dicts = torch.load(file, map_location=self.device)
self.optimizer.load_state_dict(state_dicts["optimizer"])
self.model.load_state_dict(state_dicts["model"])
Expand Down

0 comments on commit 28e7901

Please sign in to comment.