| import torch
|
| from funasr.optimizers.fairseq_adam import FairseqAdam
| from funasr.optimizers.sgd import SGD
|
|
def build_optimizer(args, model):
    """Create the optimizer selected by ``args.optim`` for *model*.

    Args:
        args: Namespace-like config providing ``optim`` (registry key, e.g.
            ``"adam"``) and ``optim_conf`` (dict of keyword arguments passed
            straight to the optimizer constructor).
        model: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        A single-element list containing the constructed optimizer (the list
        shape matches what downstream training code expects).

    Raises:
        ValueError: if ``args.optim`` names no registered optimizer class.
    """
    # Registry of supported optimizers: torch built-ins plus the two
    # funasr-local implementations imported at the top of the file.
    registry = {
        "adam": torch.optim.Adam,
        "fairseq_adam": FairseqAdam,
        "adamw": torch.optim.AdamW,
        "sgd": SGD,
        "adadelta": torch.optim.Adadelta,
        "adagrad": torch.optim.Adagrad,
        "adamax": torch.optim.Adamax,
        "asgd": torch.optim.ASGD,
        "lbfgs": torch.optim.LBFGS,
        "rmsprop": torch.optim.RMSprop,
        "rprop": torch.optim.Rprop,
    }

    chosen = registry.get(args.optim)
    if chosen is None:
        raise ValueError(f"must be one of {list(registry)}: {args.optim}")

    # Return as a list so callers that iterate over optimizers keep working.
    return [chosen(model.parameters(), **args.optim_conf)]
|
|