Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed Mar 16, 2023
1 parent c4aa1b3 commit 48db8f6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
11 changes: 6 additions & 5 deletions src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import pickle
import numpy as np
import os

from .documenter import Documenter
from .train import DiscriminatorTraining
Expand Down Expand Up @@ -45,17 +46,17 @@ def main():
print(f" Val points: {len(data.val_true)} truth, {len(data.val_fake)} generated")

print(" Building model")
model_dir = doc.get_file(f"model_{data.suffix}.pth")
model_dir = doc.get_file(f"model_{data.suffix}")
os.makedirs(model_dir, exist_ok=True)
training = DiscriminatorTraining(params, device, data, model_dir)

if args.load_model:
print(" Loading model")
training.load(args.model_name)
else:
if not args.load_model:
print(" Running training")
training.train()

print(f" Loading model {args.model_name}")
training.load(args.model_name)

if args.load_model and args.load_weights:
print(" Loading weights")
with open(doc.get_file(f"weights_{data.suffix}.pkl"), "rb") as f:
Expand Down
9 changes: 7 additions & 2 deletions src/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def plot_losses(self, file: str):
pdf,
"learning rate",
(self.losses["lr"], ),
(None, )
(None, ),
"log"
)


Expand All @@ -97,7 +98,8 @@ def plot_single_loss(
pdf: PdfPages,
ylabel: str,
curves: tuple[np.ndarray],
labels: tuple[str]
labels: tuple[str],
yscale: str = "linear"
):
"""
Makes single loss plot.
Expand All @@ -107,6 +109,7 @@ def plot_single_loss(
ylabel: Y axis label
curves: List of numpy arrays with the loss curves to be plotted
labels: Labels of the loss curves
yscale: Y axis scale, "linear" or "log"
"""
fig, ax = plt.subplots(figsize=(4,3.5))
for i, (curve, label) in enumerate(zip(curves, labels)):
Expand All @@ -122,6 +125,7 @@ def plot_single_loss(
verticalalignment = "top",
transform = ax.transAxes
)
ax.set_yscale(yscale)
if any(label is not None for label in labels):
ax.legend(loc="center right", frameon=False)
plt.savefig(pdf, format="pdf", bbox_inches="tight")
Expand Down Expand Up @@ -309,6 +313,7 @@ def plot_single_weight_hist(
ax.set_xscale(xscale)
ax.set_yscale(yscale)
ax.set_xlim(bins[0], bins[-1])
ax.legend(frameon=False)
plt.savefig(pdf, format="pdf", bbox_inches="tight")
plt.close()

Expand Down
5 changes: 3 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,15 @@ def train(self):
train_loss = torch.stack(epoch_losses).mean()
self.losses["train_loss"].append(train_loss.item())
self.losses["val_loss"].append(val_loss.item())
self.losses["lr"].append(self.optimizer.param_groups[0]["lr"])
epoch_lr = self.optimizer.param_groups[0]["lr"]
self.losses["lr"].append(epoch_lr)
if self.bayesian:
self.losses["train_bce_loss"].append(torch.stack(epoch_bce_losses).mean().item())
self.losses["train_kl_loss"].append(torch.stack(epoch_kl_losses).mean().item())
self.losses["val_bce_loss"].append(val_bce_loss.item())
self.losses["val_kl_loss"].append(val_kl_loss.item())
print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " +
f"val loss {val_loss:.6f}")
f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True)

if val_loss < best_val_loss:
best_val_loss = val_loss
Expand Down

0 comments on commit 48db8f6

Please sign in to comment.