diff --git a/modules.py b/modules.py index 3d2c260..b9e6c7c 100644 --- a/modules.py +++ b/modules.py @@ -349,9 +349,6 @@ def forward(self, h_t): # bound between [-1, 1] l_t = F.tanh(l_t) - # prevent gradient flow - l_t = l_t.detach() - return mu, l_t diff --git a/trainer.py b/trainer.py index 4d44de7..dfad1d8 100644 --- a/trainer.py +++ b/trainer.py @@ -170,8 +170,8 @@ def train(self): # evaluate on validation set valid_loss, valid_acc = self.validate(epoch) - # reduce lr if validation loss plateaus - self.scheduler.step(valid_loss) + # # reduce lr if validation loss plateaus + # self.scheduler.step(valid_loss) is_best = valid_acc > self.best_valid_acc msg1 = "train loss: {:.3f} - train acc: {:.3f} " @@ -293,7 +293,7 @@ def train_one_epoch(self, epoch): ( "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( (toc-tic), loss.data[0], acc.data[0] - ) + ) ) ) pbar.update(self.batch_size)