-
Notifications
You must be signed in to change notification settings - Fork 1
/
optimizers.py
36 lines (27 loc) · 915 Bytes
/
optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import logging
from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop
logger = logging.getLogger("ptsemseg")

# Registry of the optimizer names accepted in a config's "name" field,
# mapped to the torch.optim classes they select. Lookup is by exact key,
# so names must be given in lower case as listed here.
key2opt = dict(
    sgd=SGD,
    adam=Adam,
    asgd=ASGD,
    adamax=Adamax,
    adadelta=Adadelta,
    adagrad=Adagrad,
    rmsprop=RMSprop,
)
def get_optimizer(opt_dict, model_params):
    """Construct an optimizer over *model_params* from a config mapping.

    Args:
        opt_dict: Either ``None`` or a mapping with a ``"name"`` key that
            selects the optimizer class. Every other key/value pair is
            forwarded unchanged as a keyword argument to that class's
            constructor. ``None`` selects SGD with no extra keyword
            arguments. NOTE(review): some torch releases require ``lr``
            for SGD, so a ``None`` config may raise there; confirm
            against the torch version in use.
        model_params: Passed as the first positional argument to the
            optimizer constructor (presumably ``model.parameters()`` or a
            list of parameter-group dicts; verify against callers).

    Returns:
        The constructed optimizer instance.

    Raises:
        KeyError: if *opt_dict* is a dict without a ``"name"`` entry.
        NotImplementedError: if ``opt_dict["name"]`` is not a registered
            optimizer name.
    """
    optimizer_cls = _get_optimizer_instance(opt_dict)

    # _get_optimizer_instance accepts None and falls back to SGD, so the
    # kwargs extraction must accept it too. Calling .items() on None used
    # to raise AttributeError here, making that fallback unreachable.
    if opt_dict is None:
        params = {}
    else:
        # "name" selects the class; it is not a constructor argument.
        params = {k: v for k, v in opt_dict.items() if k != "name"}

    return optimizer_cls(model_params, **params)
def _get_optimizer_instance(opt_dict):
    """Resolve the optimizer *class* selected by *opt_dict*.

    Despite the name, this returns the class itself, not an instance;
    the caller is responsible for constructing it.

    Args:
        opt_dict: ``None`` to fall back to SGD, otherwise a mapping whose
            ``"name"`` entry is looked up in ``key2opt``.

    Returns:
        The selected optimizer class.

    Raises:
        KeyError: if *opt_dict* is a dict without a ``"name"`` entry.
        NotImplementedError: if the requested name is not in ``key2opt``.
    """
    # No configuration supplied: default to plain SGD.
    if opt_dict is None:
        logger.info("Using SGD optimizer")
        return SGD

    requested = opt_dict["name"]
    is_registered = requested in key2opt
    if not is_registered:
        raise NotImplementedError("Optimizer {} not implemented".format(requested))

    logger.info("Using {} optimizer".format(requested))
    return key2opt[requested]