speech_asr
2023-04-20 993fdd8ecf50e9260c2885c273a279186a68d1f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
 
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD
 
 
def build_optimizer(args, model):
    optim_classes = dict(
        adam=torch.optim.Adam,
        fairseq_adam=FairseqAdam,
        adamw=torch.optim.AdamW,
        sgd=SGD,
        adadelta=torch.optim.Adadelta,
        adagrad=torch.optim.Adagrad,
        adamax=torch.optim.Adamax,
        asgd=torch.optim.ASGD,
        lbfgs=torch.optim.LBFGS,
        rmsprop=torch.optim.RMSprop,
        rprop=torch.optim.Rprop,
    )
 
    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
    optimizer = optim_class(model.parameters(), **args.optim_conf)
    return optimizer