New file

import torch

from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD


def build_optimizer(args, model):
    # Registry mapping optimizer names to classes: FunASR's own
    # optimizers (FairseqAdam, SGD) alongside the torch.optim built-ins.
    optim_classes = dict(
        adam=torch.optim.Adam,
        fairseq_adam=FairseqAdam,
        adamw=torch.optim.AdamW,
        sgd=SGD,
        adadelta=torch.optim.Adadelta,
        adagrad=torch.optim.Adagrad,
        adamax=torch.optim.Adamax,
        asgd=torch.optim.ASGD,
        lbfgs=torch.optim.LBFGS,
        rmsprop=torch.optim.RMSprop,
        rprop=torch.optim.Rprop,
    )

    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"args.optim must be one of {list(optim_classes)}, but got: {args.optim}")
    # Instantiate over all model parameters, forwarding any
    # optimizer-specific keyword arguments from args.optim_conf.
    optimizer = optim_class(model.parameters(), **args.optim_conf)

    optimizers = [optimizer]
    return optimizers
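
For context, a minimal usage sketch of the new builder; the SimpleNamespace config and the Linear model below are hypothetical stand-ins for FunASR's parsed training config and acoustic model:

from types import SimpleNamespace

import torch

# Hypothetical config; in FunASR these fields would come from the YAML/CLI config.
args = SimpleNamespace(optim="adamw", optim_conf={"lr": 1e-3, "weight_decay": 0.01})
model = torch.nn.Linear(80, 4)  # placeholder model for illustration

optimizers = build_optimizer(args, model)
print(type(optimizers[0]).__name__)  # AdamW

Note that build_optimizer returns a one-element list rather than a bare optimizer, presumably to leave the interface open to multi-optimizer training schemes without changing callers.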