diff --git a/src/__main__.py b/src/__main__.py index 764ac7c..c6a906b 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -4,6 +4,7 @@ import torch import pickle import numpy as np +import os from .documenter import Documenter from .train import DiscriminatorTraining @@ -45,17 +46,17 @@ def main(): print(f" Val points: {len(data.val_true)} truth, {len(data.val_fake)} generated") print(" Building model") - model_dir = doc.get_file(f"model_{data.suffix}.pth") + model_dir = doc.get_file(f"model_{data.suffix}") os.makedirs(model_dir, exist_ok=True) training = DiscriminatorTraining(params, device, data, model_dir) - if args.load_model: - print(" Loading model") - training.load(args.model_name) - else: + if not args.load_model: print(" Running training") training.train() + print(f" Loading model {args.model_name}") + training.load(args.model_name) + if args.load_model and args.load_weights: print(" Loading weights") with open(doc.get_file(f"weights_{data.suffix}.pkl"), "rb") as f: diff --git a/src/plots.py b/src/plots.py index bd46aa2..9aa91f7 100644 --- a/src/plots.py +++ b/src/plots.py @@ -88,7 +88,8 @@ def plot_losses(self, file: str): pdf, "learning rate", (self.losses["lr"], ), - (None, ) + (None, ), + "log" ) @@ -97,7 +98,8 @@ def plot_single_loss( pdf: PdfPages, ylabel: str, curves: tuple[np.ndarray], - labels: tuple[str] + labels: tuple[str], + yscale: str = "linear" ): """ Makes single loss plot. @@ -107,6 +109,7 @@ def plot_single_loss( ylabel: Y axis label curves: List of numpy arrays with the loss curves to be plotted labels: Labels of the loss curves + yscale: Y axis scale, "linear" or "log" """ fig, ax = plt.subplots(figsize=(4,3.5)) for i, (curve, label) in enumerate(zip(curves, labels)): @@ -122,6 +125,7 @@ def plot_single_loss( verticalalignment = "top", transform = ax.transAxes ) + ax.set_yscale(yscale) if any(label is not None for label in labels): ax.legend(loc="center right", frameon=False) plt.savefig(pdf, format="pdf", bbox_inches="tight") @@ -309,6 +313,7 @@ def plot_single_weight_hist( ax.set_xscale(xscale) ax.set_yscale(yscale) ax.set_xlim(bins[0], bins[-1]) + ax.legend(frameon=False) plt.savefig(pdf, format="pdf", bbox_inches="tight") plt.close() diff --git a/src/train.py b/src/train.py index e259b38..10354aa 100644 --- a/src/train.py +++ b/src/train.py @@ -184,14 +184,15 @@ def train(self): train_loss = torch.stack(epoch_losses).mean() self.losses["train_loss"].append(train_loss.item()) self.losses["val_loss"].append(val_loss.item()) - self.losses["lr"].append(self.optimizer.param_groups[0]["lr"]) + epoch_lr = self.optimizer.param_groups[0]["lr"] + self.losses["lr"].append(epoch_lr) if self.bayesian: self.losses["train_bce_loss"].append(torch.stack(epoch_bce_losses).mean().item()) self.losses["train_kl_loss"].append(torch.stack(epoch_kl_losses).mean().item()) self.losses["val_bce_loss"].append(val_bce_loss.item()) self.losses["val_kl_loss"].append(val_kl_loss.item()) print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " + - f"val loss {val_loss:.6f}") + f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True) if val_loss < best_val_loss: best_val_loss = val_loss