diff --git a/src/backends/torch/torchsolver.cc b/src/backends/torch/torchsolver.cc index 3f3e71fe7..b0a1df5dd 100644 --- a/src/backends/torch/torchsolver.cc +++ b/src/backends/torch/torchsolver.cc @@ -97,7 +97,14 @@ namespace dd _params, torch::optim::AdamOptions(_base_lr) .betas(std::make_tuple(_beta1, _beta2)) .weight_decay(_weight_decay))); - this->_logger->info("base_lr: {}", _base_lr); + } + else if (_solver_type == "ADAMW") + { + _optimizer + = std::unique_ptr(new torch::optim::AdamW( + _params, torch::optim::AdamWOptions(_base_lr) + .betas(std::make_tuple(_beta1, _beta2)) + .weight_decay(_weight_decay))); } else if (_solver_type == "RMSPROP") { @@ -105,7 +112,6 @@ namespace dd new torch::optim::RMSprop( _params, torch::optim::RMSpropOptions(_base_lr).weight_decay( _weight_decay))); - this->_logger->info("base_lr: {}", _base_lr); } else if (_solver_type == "ADAGRAD") { @@ -113,7 +119,6 @@ namespace dd new torch::optim::Adagrad( _params, torch::optim::AdagradOptions(_base_lr).weight_decay( _weight_decay))); - this->_logger->info("base_lr: {}", _base_lr); } else if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS") { @@ -131,7 +136,6 @@ namespace dd .adamp(_adamp) .lsteps(_lsteps) .lalpha(_lalpha))); - this->_logger->info("base_lr: {}", _base_lr); this->_logger->info("beta_1: {}", _beta1); this->_logger->info("beta_2: {}", _beta2); this->_logger->info("weight_decay: {}", _weight_decay); @@ -162,7 +166,6 @@ namespace dd .lookahead(_lookahead) .lsteps(_lsteps) .lalpha(_lalpha))); - this->_logger->info("base_lr: {}", _base_lr); this->_logger->info("momentum: {}", _momentum); this->_logger->info("weight_decay: {}", _weight_decay); this->_logger->info("lookahead: {}", _lookahead); @@ -180,7 +183,6 @@ namespace dd _optimizer = std::unique_ptr(new torch::optim::SGD( _params, torch::optim::SGDOptions(_base_lr))); - this->_logger->info("base_lr: {}", _base_lr); } this->_logger->info("clip: {}", _clip); if (_clip) @@ -199,6 +201,8 @@ namespace dd } if (_sam) this->_logger->info("using Sharpness Aware Minimization (SAM)"); + this->_logger->info("using optimizer " + _solver_type); + this->_logger->info("base_lr: {}", _base_lr); } void TorchSolver::sam_first_step() @@ -417,6 +421,14 @@ namespace dd options.betas(std::make_tuple(_beta1, _beta2)); options.weight_decay(_weight_decay); } + else if (_solver_type == "ADAMW") + { + auto &options = static_cast( + param_group.options()); + options.lr(_base_lr); + options.betas(std::make_tuple(_beta1, _beta2)); + options.weight_decay(_weight_decay); + } else if (_solver_type == "RMSPROP") { auto &options = static_cast(