From c4aa1b3fec43bbd7836eda48ad29a261ecc883d9 Mon Sep 17 00:00:00 2001 From: Theo Heimel Date: Thu, 16 Mar 2023 15:07:18 +0100 Subject: [PATCH] added reduce on plateau scheduler --- README.md | 5 +++-- src/train.py | 10 +++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index aff00d4..b6d3573 100644 --- a/README.md +++ b/README.md @@ -59,10 +59,11 @@ Parameter | Description `betas` | Adam optimizer betas `eps` | Adam optimizer eps `weight_decay` | L2 weight decay -`lr_scheduler` | Type of LR scheduler: `one_cycle` or `step` +`lr_scheduler` | Type of LR scheduler: `one_cycle`, `step`, `reduce_on_plateau` `max_lr` | One Cycle scheduler: maximum LR `lr_decay_epochs` | Step scheduler: Epochs after which to reduce the LR -`lr_decay_factor` | Step scheduler: Decay factor +`lr_decay_factor` | Step and reduce on plateau schedulers: Decay factor +`lr_patience` | Reduce on plateau scheduler: Number of epochs without improvement for reduction. `epochs` | Number of epochs `train_samples` | Total number of samples used for training (alternative to number of epochs) `checkpoint_interval` | If value n set, save the model after every n epochs diff --git a/src/train.py b/src/train.py index db78bef..e259b38 100644 --- a/src/train.py +++ b/src/train.py @@ -118,6 +118,12 @@ def init_scheduler(self): epochs = self.epochs, steps_per_epoch=self.train_batches ) + elif self.scheduler_type == "reduce_on_plateau": + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + factor = self.params["lr_decay_factor"], + patience = self.params["lr_patience"] + ) else: raise ValueError(f"Unknown LR scheduler '{self.scheduler_type}'") @@ -169,10 +175,12 @@ def train(self): self.optimizer.step() if self.scheduler_type == "one_cycle": self.scheduler.step() + val_loss, val_bce_loss, val_kl_loss = self.val_loss() if self.scheduler_type == "step": self.scheduler.step() + elif self.scheduler_type == "reduce_on_plateau": + 
self.scheduler.step(val_loss) - val_loss, val_bce_loss, val_kl_loss = self.val_loss() train_loss = torch.stack(epoch_losses).mean() self.losses["train_loss"].append(train_loss.item()) self.losses["val_loss"].append(val_loss.item())