feat(torch): SWA for RANGER/torch (https://arxiv.org/abs/1803.05407)
fantes authored and mergify[bot] committed Mar 26, 2021
1 parent efbd1f9 commit 74cf54c
Showing 6 changed files with 87 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/api.md
@@ -773,6 +773,7 @@ adabelief | bool | yes | false for RANGER, true for RANGER_PLUS | f
gradient_centralization | bool | yes | false for RANGER, true for RANGER_PLUS| for RANGER* : enable/disable gradient centralization
sam | bool | yes | false | Sharpness Aware Minimization (https://arxiv.org/abs/2010.01412)
sam_rho | real | yes | 0.05 | neighborhood size for SAM (see above)
swa | bool | yes | false | Stochastic Weight Averaging (https://arxiv.org/abs/1803.05407), implemented only for the RANGER / RANGER_PLUS solver types
test_interval | int | yes | N/A | Number of iterations between testing phases
base_lr | real | yes | N/A | Initial learning rate
iter_size | int | yes | 1 | Number of passes (iter_size * batch_size) at every iteration
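For reference, SWA maintains an equal-weight running mean of the weight iterates; written incrementally,

$$\bar{w}_{n+1} = \bar{w}_n + \frac{w_{n+1} - \bar{w}_n}{n+1}, \qquad\text{so that}\qquad \bar{w}_n = \frac{1}{n}\sum_{i=1}^{n} w_i,$$

which is the form the `ranger.cc` change below implements with `swa_decay = 1.0 / (state.step() + 1)` (up to how the zero-initialised buffer and the step counter interact on the first few updates).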
43 changes: 41 additions & 2 deletions src/backends/torch/optim/ranger.cc
@@ -52,7 +52,8 @@ namespace dd
&& (lhs.lookahead() == rhs.lookahead())
&& (lhs.adabelief() == rhs.adabelief())
&& (lhs.gradient_centralization() == rhs.gradient_centralization())
- && (lhs.lsteps() == rhs.lsteps()) && (lhs.lalpha() == rhs.lalpha());
+ && (lhs.lsteps() == rhs.lsteps()) && (lhs.lalpha() == rhs.lalpha())
+ && (lhs.swa() == rhs.swa());
}

void RangerOptions::serialize(torch::serialize::OutputArchive &archive) const
@@ -68,6 +69,7 @@ namespace dd
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(gradient_centralization);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lsteps);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lalpha);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(swa);
}

void RangerOptions::serialize(torch::serialize::InputArchive &archive)
@@ -83,14 +85,16 @@ namespace dd
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, gradient_centralization);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int, lsteps);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lalpha);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, swa);
}

bool operator==(const RangerParamState &lhs, const RangerParamState &rhs)
{
return ((lhs.step() == rhs.step())
&& torch::equal(lhs.exp_avg(), rhs.exp_avg())
&& torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq())
- && torch::equal(lhs.slow_buffer(), rhs.slow_buffer()));
+ && torch::equal(lhs.slow_buffer(), rhs.slow_buffer())
+ && torch::equal(lhs.swa_buffer(), rhs.swa_buffer()));
}

void
@@ -100,6 +104,7 @@ namespace dd
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(slow_buffer);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(swa_buffer);
}

void RangerParamState::serialize(torch::serialize::InputArchive &archive)
@@ -108,12 +113,14 @@ namespace dd
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, exp_avg);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, exp_avg_sq);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, slow_buffer);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, swa_buffer);
}

torch::Tensor Ranger::step(LossClosure closure)
{
torch::NoGradGuard no_grad;
torch::Tensor loss = {};

if (closure != nullptr)
{
at::AutoGradMode enable_grad(true);
@@ -151,6 +158,9 @@ namespace dd
state->slow_buffer().copy_(p.data());
// initialize the SWA buffer before the state is moved into the map
if (options.swa())
  state->swa_buffer(torch::zeros_like(
      p.data(), torch::MemoryFormat::Preserve));
state_[c10::guts::to_string(p.unsafeGetTensorImpl())]
    = std::move(state);
}

auto &state = static_cast<RangerParamState &>(
@@ -227,11 +237,40 @@ namespace dd
slow_p.add_(p.data() - slow_p, options.lalpha());
p.data().copy_(slow_p);
}

if (options.swa())
{
auto &swa_buf = state.swa_buffer();
double swa_decay = 1.0 / (state.step() + 1);
torch::Tensor diff = (p.data() - swa_buf) * swa_decay;
swa_buf.add_(diff);
}
}
}
return loss;
}

void Ranger::swap_swa_sgd()
{
for (auto &group : param_groups_)
{
auto &options = static_cast<RangerOptions &>(group.options());
if (!options.swa())
continue;
for (auto &p : group.params())
{
auto &state = static_cast<RangerParamState &>(
*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
auto &swa_buf = state.swa_buffer();

auto tmp = torch::empty_like(p.data());
tmp.copy_(p.data());
p.data().copy_(swa_buf);
swa_buf.copy_(tmp);
}
}
}

void Ranger::save(torch::serialize::OutputArchive &archive) const
{
serialize(*this, archive);
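As a rough usage sketch (not part of the commit): assuming `Ranger` exposes the same `(params, options)` constructor and closure-less `step()` as the stock libtorch optimizers (which is how `torchsolver.cc` constructs it below), the new `swa` option and `swap_swa_sgd()` would be driven roughly like this:

#include <torch/torch.h>
#include "optim/ranger.h" // include path as used by torchsolver.h below

int main()
{
  // a single trainable tensor standing in for model parameters (toy example)
  std::vector<torch::Tensor> params
      = { torch::randn({ 10 }, torch::requires_grad()) };

  dd::Ranger ranger(params, dd::RangerOptions(1e-3).swa(true));

  for (int i = 0; i < 100; ++i)
    {
      ranger.zero_grad();
      auto loss = params[0].pow(2).sum(); // toy quadratic loss
      loss.backward();
      ranger.step(); // updates the weights and folds them into the SWA buffer
    }

  ranger.swap_swa_sgd(); // params now hold the SWA average: test / snapshot here
  ranger.swap_swa_sgd(); // swap the live weights back before resuming training
  return 0;
}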
10 changes: 10 additions & 0 deletions src/backends/torch/optim/ranger.h
@@ -19,6 +19,9 @@
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef RANGER_H
#define RANGER_H

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include <torch/arg.h>
@@ -57,6 +60,7 @@ namespace dd
TORCH_ARG(bool, gradient_centralization) = false;
TORCH_ARG(int, lsteps) = 6;
TORCH_ARG(double, lalpha) = 0.5;
TORCH_ARG(bool, swa) = false;

public:
void serialize(torch::serialize::InputArchive &archive) override;
@@ -73,6 +77,7 @@ namespace dd
TORCH_ARG(torch::Tensor, exp_avg);
TORCH_ARG(torch::Tensor, exp_avg_sq);
TORCH_ARG(torch::Tensor, slow_buffer);
TORCH_ARG(torch::Tensor, swa_buffer);

public:
void serialize(torch::serialize::InputArchive &archive) override;
@@ -118,11 +123,16 @@ namespace dd
void save(torch::serialize::OutputArchive &archive) const override;
void load(torch::serialize::InputArchive &archive) override;

void swap_swa_sgd();

private:
template <typename Self, typename Archive>
static void serialize(Self &self, Archive &archive)
{
_TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Ranger);
}
bool swa_in_params = false;
};
} // namespace dd

#endif
11 changes: 9 additions & 2 deletions src/backends/torch/torchlib.cc
@@ -506,10 +506,13 @@
TMLModel>::snapshot(int64_t elapsed_it, TorchSolver &tsolver)
{
this->_logger->info("Saving checkpoint after {} iterations", elapsed_it);
// the solver is allowed to modify the net during eval()/train()
// => call eval() before saving the net itself
tsolver.eval();
this->_module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it));
// Save optimizer
tsolver.save(this->_mlmodel._repo + "/solver-" + std::to_string(elapsed_it)
+ ".pt");
tsolver.train();
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -810,7 +813,9 @@
APIData meas_out;
this->_logger->info("Start test");
tstart = steady_clock::now();
tsolver.eval();
test(ad, inputc, eval_dataset, test_batch_size, meas_out);
tsolver.train();
last_test_time = duration_cast<milliseconds>(
steady_clock::now() - tstart)
.count();
@@ -891,7 +896,9 @@
}
}
if (!snapshotted)
- snapshot(elapsed_it, tsolver);
+ {
+ snapshot(elapsed_it, tsolver);
+ }
}
++it;

7 changes: 7 additions & 0 deletions src/backends/torch/torchsolver.cc
@@ -70,12 +70,16 @@ namespace dd
_sam = ad_solver.get("sam").get<bool>();
if (ad_solver.has("sam_rho"))
_sam_rho = ad_solver.get("sam_rho").get<double>();
if (ad_solver.has("swa"))
_swa = ad_solver.get("swa").get<bool>();
create();
}

void TorchSolver::create()
{

// remember whether SWA was requested via the API; it is enabled below
// only for solver types that implement it (RANGER / RANGER_PLUS)
bool want_swa = _swa;
_swa = false;
this->_logger->info("Selected solver type: {}", _solver_type);

_params = _module.parameters();
@@ -107,6 +111,8 @@ namespace dd
}
else if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS")
{
if (want_swa)
_swa = true;
_optimizer = std::unique_ptr<torch::optim::Optimizer>(
new Ranger(_params, RangerOptions(_base_lr)
.betas(std::make_tuple(_beta1, _beta2))
@@ -251,6 +257,7 @@ namespace dd
try
{
torch::load(*_optimizer, sstate, device);
// snapshots are written with the SWA weights swapped into the net (see
// snapshot() in torchlib.cc), so swap the live weights back before resuming
this->train();
}
catch (std::exception &e)
{
19 changes: 19 additions & 0 deletions src/backends/torch/torchsolver.h
@@ -30,6 +30,7 @@
#include "apidata.h"
#include "torchmodule.h"
#include "torchloss.h"
#include "optim/ranger.h"

#define DEFAULT_CLIP_VALUE 5.0
#define DEFAULT_CLIP_NORM 100.0
@@ -105,6 +106,16 @@ namespace dd
return _base_lr;
}

// if SWA is enabled, swap the averaged weights into the net
// (e.g. before testing or saving a checkpoint)
void eval()
{
swap_swa_sgd();
}

// swap the live SGD weights back in before resuming optimization
void train()
{
swap_swa_sgd();
}

protected:
/**
* \brief allocates solver for real
@@ -115,6 +126,12 @@ namespace dd
void sam_first_step();
void sam_second_step();

void swap_swa_sgd()
{
if (_swa)
static_cast<Ranger *>(_optimizer.get())->swap_swa_sgd();
}

std::vector<torch::Tensor> _sam_ew;

std::vector<at::Tensor> _params; /**< list of parameter to optimize,
@@ -141,6 +158,8 @@ namespace dd
bool _sam = false;
double _sam_rho = DEFAULT_SAM_RHO;

bool _swa = false; /**< stochastic weight averaging (https://arxiv.org/abs/1803.05407) */

TorchModule &_module;
TorchLoss &_tloss;
std::vector<torch::Device>
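Design note: `eval()` and `train()` above are plain buffer swaps (no optimizer state is recomputed), so they are meant to be called in strict alternation around any code that should see the averaged weights, which is exactly how `torchlib.cc` brackets testing and snapshotting. Schematically (hypothetical calling code, not part of the commit):

// tsolver: a TorchSolver configured with solver_type RANGER and swa = true
tsolver.eval();  // swap the SWA running average into the net
// ... run validation or write a checkpoint on the averaged weights ...
tsolver.train(); // swap the live SGD weights back before the next step()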
