added reduce on plateau scheduler
theoheimel committed Mar 16, 2023
1 parent 28e7901 commit c4aa1b3
Showing 2 changed files with 12 additions and 3 deletions.
README.md (5 changes: 3 additions & 2 deletions)

```diff
@@ -59,10 +59,11 @@ Parameter | Description
 `betas` | Adam optimizer betas
 `eps` | Adam optimizer eps
 `weight_decay` | L2 weight decay
-`lr_scheduler` | Type of LR scheduler: `one_cycle` or `step`
+`lr_scheduler` | Type of LR scheduler: `one_cycle`, `step`, `reduce_on_plateau`
 `max_lr` | One Cycle scheduler: maximum LR
 `lr_decay_epochs` | Step scheduler: Epochs after which to reduce the LR
-`lr_decay_factor` | Step scheduler: Decay factor
+`lr_decay_factor` | Step and reduce on plateau schedulers: Decay factor
+`lr_patience` | Reduce on plateau scheduler: Number of epochs without improvement before the LR is reduced
 `epochs` | Number of epochs
 `train_samples` | Total number of samples used for training (alternative to number of epochs)
 `checkpoint_interval` | If set to a value n, save the model after every n epochs
```
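To make the table concrete, here is a hypothetical parameter set selecting the new scheduler. The keys are the parameter names documented above; the values are illustrative choices, not defaults from this repository:

```python
# Hypothetical scheduler settings: keys follow the README table,
# values are illustrative and not repository defaults.
params = {
    "lr_scheduler": "reduce_on_plateau",
    "lr_decay_factor": 0.5,  # halve the LR on each reduction
    "lr_patience": 5,        # epochs without improvement before reducing
}
```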
src/train.py (10 changes: 9 additions & 1 deletion)

```diff
@@ -118,6 +118,12 @@ def init_scheduler(self):
                 epochs = self.epochs,
                 steps_per_epoch=self.train_batches
             )
+        elif self.scheduler_type == "reduce_on_plateau":
+            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                self.optimizer,
+                factor = self.params["lr_decay_factor"],
+                patience = self.params["lr_patience"]
+            )
         else:
             raise ValueError(f"Unknown LR scheduler '{self.scheduler_type}'")
 
```
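The new branch delegates to PyTorch's built-in `torch.optim.lr_scheduler.ReduceLROnPlateau`. As a minimal, self-contained sketch of what that scheduler does (toy model and made-up loss values, independent of this repository):

```python
import torch

# Toy model and optimizer; only the learning-rate bookkeeping matters here.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,   # plays the role of lr_decay_factor
    patience=2,   # plays the role of lr_patience
)

# A validation loss that stops improving: after `patience` epochs
# with no improvement, the LR is multiplied by `factor`.
for val_loss in [1.0, 0.9, 0.9, 0.9, 0.9]:
    scheduler.step(val_loss)
    print(optimizer.param_groups[0]["lr"])  # 1e-3 four times, then 5e-4
```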
```diff
@@ -169,10 +175,12 @@ def train(self):
                 self.optimizer.step()
                 if self.scheduler_type == "one_cycle":
                     self.scheduler.step()
+            val_loss, val_bce_loss, val_kl_loss = self.val_loss()
             if self.scheduler_type == "step":
                 self.scheduler.step()
+            elif self.scheduler_type == "reduce_on_plateau":
+                self.scheduler.step(val_loss)
 
-            val_loss, val_bce_loss, val_kl_loss = self.val_loss()
             train_loss = torch.stack(epoch_losses).mean()
             self.losses["train_loss"].append(train_loss.item())
             self.losses["val_loss"].append(val_loss.item())
```
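The hunk above moves the validation-loss computation ahead of the epoch-level scheduler step, since `ReduceLROnPlateau.step` needs the monitored metric as an argument. A self-contained sketch of that ordering (model, data, and names are placeholders, not this repository's code):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=2
)
x, y = torch.randn(32, 4), torch.randn(32, 1)  # stand-in data

for epoch in range(10):
    optimizer.zero_grad()
    F.mse_loss(model(x), y).backward()
    optimizer.step()
    # Compute the validation metric first ...
    with torch.no_grad():
        val_loss = F.mse_loss(model(x), y)
    # ... then let the plateau scheduler react to it, as in the diff.
    scheduler.step(val_loss)
```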
