-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimizer.py
295 lines (259 loc) · 11.8 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import torch
import math
import warnings
from typing import List
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer, required
import re
EETA_DEFAULT = 0.001


#### LARS optimizer ####
class LARS(Optimizer):
    """
    Layer-wise Adaptive Rate Scaling for large batch training.

    Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

    Every parameter group must carry a ``"param_names"`` list aligned
    element-by-element with its ``"params"``; those names are matched against
    the exclusion lists to decide which parameters skip weight decay and/or
    layer-wise adaptation.
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum=0.9,
        use_nesterov=False,
        weight_decay=0.0,
        exclude_from_weight_decay=None,
        exclude_from_layer_adaptation=None,
        classic_momentum=True,
        eeta=EETA_DEFAULT,
    ):
        """Constructs a LARS optimizer.

        Args:
            params: Iterable of parameter-group dicts. Each group must contain
                a ``"param_names"`` list aligned with its ``"params"``, e.g.
                ``[name for name, p in model.named_parameters() if p.requires_grad]``.
            lr: A `float` for learning rate.
            momentum: A `float` for momentum.
            use_nesterov: A `bool` for whether to use Nesterov momentum.
            weight_decay: A `float` for L2 weight decay (added to the gradient).
            exclude_from_weight_decay: A list of strings. Each entry is first
                translated through ``param_name_map`` ('batch_normalization'
                -> 'bn', 'bias' -> 'bias'; any other entry is used unchanged)
                and then searched for as a regular expression in the parameter
                name; matching parameters get no weight decay. For example,
                ``['bn', 'bias']`` excludes BN and bias parameters.
            exclude_from_layer_adaptation: Same format as
                exclude_from_weight_decay, but for layer adaptation (trust-ratio
                scaling). If falsy, it defaults to exclude_from_weight_decay.
            classic_momentum: A `bool` for whether to use classic (or popular)
                momentum. The learning rate is applied during the momentum
                update in classic momentum, but after momentum for popular
                momentum.
            eeta: A `float` scaling the trust ratio
                ``eeta * ||w|| / ||g||``.
        """
        # Counts calls to step(); kept as a public attribute for callers.
        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            exclude_from_weight_decay=exclude_from_weight_decay,
            exclude_from_layer_adaptation=exclude_from_layer_adaptation,
            classic_momentum=classic_momentum,
            eeta=eeta,
        )
        super().__init__(params, defaults)
        # Instance-level copies of the constructor arguments. The numerical
        # hyperparameters used by step() come from each param group instead.
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation falls back to exclude_from_weight_decay
        # when the argument is None (or otherwise falsy).
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay
        # Translates TF-style exclusion keywords into substrings of PyTorch
        # parameter names; unknown keywords are used as-is.
        self.param_name_map = {"batch_normalization": "bn", "bias": "bias"}

    def step(self, epoch=None, closure=None):
        """Perform a single optimization step.

        Args:
            epoch: Accepted for backward compatibility; it does not influence
                the update. ``self.epoch`` is incremented on every call.
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            KeyError: If a parameter group has no ``"param_names"`` entry.
            ValueError: If a group's ``"param_names"`` and ``"params"`` differ
                in length.
        """
        loss = None
        if closure is not None:
            loss = closure()
        self.epoch += 1

        for group in self.param_groups:
            # Read hyperparameters per group so param-group overrides and
            # schedulers that edit group values take effect.
            lr = group["lr"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]
            eeta = group["eeta"]
            use_nesterov = group["use_nesterov"]
            classic_momentum = group["classic_momentum"]

            names = group["param_names"]
            params = group["params"]
            if len(names) != len(params):
                raise ValueError(
                    "param_names and params must have the same length, got "
                    f"{len(names)} and {len(params)}"
                )

            for p_name, p in zip(names, params):
                if p.grad is None:
                    continue
                param = p.data
                grad = p.grad.data
                if self._use_weight_decay(p_name, weight_decay):
                    # Out-of-place: the decay term must not leak back into p.grad.
                    grad = grad.add(param, alpha=weight_decay)
                adapt = self._do_layer_adaptation(p_name)

                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    param_state["momentum_buffer"] = torch.zeros_like(param)
                next_v = param_state["momentum_buffer"]

                if classic_momentum:
                    # Classic: the (trust-scaled) lr is folded into the momentum buffer.
                    trust_ratio = self._trust_ratio(param, grad, eeta) if adapt else 1.0
                    scaled_lr = lr * trust_ratio
                    next_v.mul_(momentum).add_(grad, alpha=scaled_lr)
                    if use_nesterov:
                        update = momentum * next_v + scaled_lr * grad
                    else:
                        update = next_v
                    p.data.sub_(update)
                else:
                    # Popular: momentum on raw gradients; lr applied afterwards,
                    # with the trust ratio measured against the update itself.
                    next_v.mul_(momentum).add_(grad)
                    if use_nesterov:
                        update = momentum * next_v + grad
                    else:
                        update = next_v
                    trust_ratio = self._trust_ratio(param, update, eeta) if adapt else 1.0
                    p.data.add_(update, alpha=-(lr * trust_ratio))
        return loss

    @staticmethod
    def _trust_ratio(param, direction, eeta):
        """Return ``eeta * ||param|| / ||direction||`` as a float, or 1.0 if either norm is 0."""
        w_norm = torch.norm(param)
        d_norm = torch.norm(direction)
        # torch.where keeps everything on the tensors' own device (CPU or GPU);
        # the unselected inf/nan branch for zero norms is discarded.
        ratio = torch.where(
            (w_norm > 0) & (d_norm > 0),
            eeta * w_norm / d_norm,
            torch.ones_like(w_norm),
        )
        return ratio.item()

    def _matches_any(self, patterns, param_name):
        """Return True if any (mapped) pattern is found in ``param_name`` via ``re.search``."""
        if not patterns:
            return False
        return any(
            re.search(self.param_name_map.get(r, r), param_name) is not None
            for r in patterns
        )

    def _use_weight_decay(self, param_name, weight_decay=None):
        """Whether to use L2 weight decay for `param_name`.

        ``weight_decay`` defaults to ``self.weight_decay``; step() passes the
        current group's value.
        """
        if weight_decay is None:
            weight_decay = self.weight_decay
        if not weight_decay:
            return False
        return not self._matches_any(self.exclude_from_weight_decay, param_name)

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        return not self._matches_any(self.exclude_from_layer_adaptation, param_name)
#### cosine annealing LR scheduler ####
class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """
    Sets the learning rate of each parameter group to follow a linear warmup schedule
    between warmup_start_lr and base_lr followed by a cosine annealing schedule between
    base_lr and eta_min.

    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
        after each iteration as calling it after each epoch will keep the starting lr at
        warmup_start_lr for the first epoch which is 0 in most cases.

    .. warning::
        Passing epoch to :func:`.step()` is deprecated by PyTorch and comes with an
        EPOCH_DEPRECATION_WARNING. It calls the :func:`_get_closed_form_lr()` method of
        this scheduler instead of :func:`get_lr()`. Though this does not change the
        behavior of the scheduler, when passing the epoch param to :func:`.step()`, the
        user should call :func:`.step()` before calling the train and validation methods.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
        verbose=True,
    ) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Number of iterations of linear warmup. 0 disables warmup.
            max_epochs (int): Total number of iterations (warmup + cosine phase).
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
            verbose: Forwarded to the base scheduler only if the installed PyTorch
                scheduler still accepts a ``verbose`` argument; otherwise ignored.
        """
        import inspect

        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        # The base-class ``verbose`` argument is deprecated and may be absent in
        # some PyTorch versions; forwarding it unconditionally would raise TypeError.
        if "verbose" in inspect.signature(_LRScheduler.__init__).parameters:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """
        Compute learning rate using the chainable (recursive) form of the scheduler:
        each value is derived from the current ``group["lr"]``.
        """
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )

        if self.last_epoch == self.warmup_epochs:
            # Peak of the schedule: end of warmup, or step 0 when warmup_epochs == 0
            # (otherwise the chainable cosine form would stay at warmup_start_lr).
            return list(self.base_lrs)
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        if self.last_epoch < self.warmup_epochs:
            # Linear ramp: constant increment per step from warmup_start_lr to base_lr.
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        cosine_span = self.max_epochs - self.warmup_epochs
        if (self.last_epoch - 1 - self.max_epochs) % (2 * cosine_span) == 0:
            # At these steps the ratio form below would divide by 1 + cos(pi) == 0,
            # so step up from the cosine minimum with an explicit increment.
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / cosine_span)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Ratio of consecutive cosine factors applied to the current lr.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / cosine_span))
            / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / cosine_span))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """
        Called when epoch is passed as a param to the `step` function of the scheduler.
        """
        if self.last_epoch < self.warmup_epochs:
            # max(1, ...) avoids ZeroDivisionError when warmup_epochs == 1
            # (last_epoch is then 0, so the result is warmup_start_lr either way).
            return [
                self.warmup_start_lr
                + self.last_epoch * (base_lr - self.warmup_start_lr) / max(1, self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]
        return [
            self.eta_min + 0.5 * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]