Skip to content

Commit

Permalink
various changes
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed Mar 16, 2023
1 parent 384497f commit 7df06c4
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Parameter | Description
### Training

Parameter | Description
------------------|--------------------------------------------------------
------------------|---------------------------------------------------------------------------
`bayesian` | Train as a Bayesian network
`batch_size` | Batch size
`lr` | Initial learning rate
Expand All @@ -59,6 +59,7 @@ Parameter | Description
`lr_decay_epochs` | Step scheduler: Epochs after which to reduce the LR
`lr_decay_factor` | Step scheduler: Decay factor
`epochs` | Number of epochs
`train_samples` | Total number of samples used for training (alternative to number of epochs)

### Evaluation

Expand Down
2 changes: 1 addition & 1 deletion src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, input_dim: int, params: dict):
self.bayesian_layers.append(layer)
layers.append(layer)
if dropout > 0:
self.layers.append(nn.Dropout(p=dropout))
layers.append(nn.Dropout(p=dropout))
layers.append(activation())
layer_size = hidden_size
layer = layer_class(layer_size, 1, **layer_kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,23 +276,23 @@ def plot_single_weight_hist(
ax,
bins,
y_combined / np.sum(y_combined),
y_combined_err / np.sum(y_combined),
y_combined_err / np.sum(y_combined) if y_combined_err is not None else None,
label = "Comb",
color = self.colors[0]
)
self.hist_line(
ax,
bins,
y_true / np.sum(y_true),
y_true_err / np.sum(y_true),
y_true_err / np.sum(y_true) if y_true_err is not None else None,
label = "Truth",
color = self.colors[1]
)
self.hist_line(
ax,
bins,
y_fake / np.sum(y_fake),
y_fake_err / np.sum(y_fake),
y_fake_err / np.sum(y_fake) if y_fake_err is not None else None,
label = "Gen",
color = self.colors[2]
)
Expand Down
10 changes: 8 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def init_scheduler(self):
"""
        Initialize the LR scheduler. Currently, one-cycle and step schedulers are supported.
"""
if "epochs" in self.params:
self.epochs = self.params["epochs"]
else:
self.epochs = int(
self.params["train_samples"] / self.params["batch_size"] / self.train_batches
)
self.scheduler_type = self.params.get("lr_scheduler", "one_cycle")
if self.scheduler_type == "step":
self.scheduler = torch.optim.lr_scheduler.StepLR(
Expand All @@ -105,7 +111,7 @@ def init_scheduler(self):
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
self.optimizer,
self.params.get("max_lr", self.params["lr"]*10),
epochs = self.params["epochs"],
epochs = self.epochs,
steps_per_epoch=self.train_batches
)
else:
Expand Down Expand Up @@ -144,7 +150,7 @@ def train(self):
"""
Main training loop
"""
for epoch in range(self.params["epochs"]):
for epoch in range(self.epochs):
self.model.train()
epoch_losses, epoch_bce_losses, epoch_kl_losses = [], [], []
for (x_true, ), (x_fake, ) in zip(self.train_loader_true, self.train_loader_fake):
Expand Down

0 comments on commit 7df06c4

Please sign in to comment.