optimizers.py
import logging

from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop

logger = logging.getLogger("ptsemseg")

# Mapping from config-file optimizer names to torch.optim classes.
key2opt = {
    "sgd": SGD,
    "adam": Adam,
    "asgd": ASGD,
    "adamax": Adamax,
    "adadelta": Adadelta,
    "adagrad": Adagrad,
    "rmsprop": RMSprop,
}


def get_optimizer(opt_dict, model_params):
    """Build an optimizer for `model_params` from a config dict.

    All keys other than "name" are forwarded to the optimizer constructor.
    The second return value is a placeholder (no scheduler is created here).
    """
    opt_dict = opt_dict.copy()
    optimizer_cls = _get_optimizer_instance(opt_dict)
    opt_dict.pop("name")
    optimizer = optimizer_cls(model_params, **opt_dict)
    return optimizer, None


def _get_optimizer_instance(opt_dict):
    """Look up the optimizer class named in `opt_dict`; default to SGD."""
    if opt_dict is None:
        logger.info("Using SGD optimizer")
        return SGD
    opt_name = opt_dict["name"]
    if opt_name not in key2opt:
        raise NotImplementedError("Optimizer {} not implemented".format(opt_name))
    logger.info("Using {} optimizer".format(opt_name))
    return key2opt[opt_name]
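

# Minimal usage sketch (illustrative; not part of the upstream module). It
# assumes a config dict in the shape expected above, e.g. parsed from a YAML
# config; the toy model, learning rate, and weight_decay are made-up values.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 2)  # stand-in for a real network
    opt_cfg = {"name": "adam", "lr": 1e-3, "weight_decay": 1e-5}

    # Every key except "name" ("lr", "weight_decay") is passed straight to
    # torch.optim.Adam; the second return value is always None (no scheduler).
    optimizer, scheduler = get_optimizer(opt_cfg, model.parameters())
    print(optimizer)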