from funasr.utils.build_dataloader import build_dataloader
from funasr.utils.build_distributed import build_distributed
from funasr.utils.prepare_data import prepare_data
from funasr.utils.build_optimizer import build_optimizer
from funasr.utils.build_scheduler import build_scheduler
from funasr.utils.types import str2bool


        distributed_option.dist_rank,
        distributed_option.local_rank))

# Old inline optimizer/scheduler construction, superseded by the
# build_optimizer/build_scheduler helpers called below:
# optimizers = cls.build_optimizers(args, model=model)
# schedulers = []
# for i, optim in enumerate(optimizers, 1):
#     suf = "" if i == 1 else str(i)
#     name = getattr(args, f"scheduler{suf}")
#     conf = getattr(args, f"scheduler{suf}_conf")
#     if name is not None:
#         cls_ = scheduler_classes.get(name)
#         if cls_ is None:
#             raise ValueError(
#                 f"must be one of {list(scheduler_classes)}: {name}"
#             )
#         scheduler = cls_(optim, **conf)
#     else:
#         scheduler = None
#
#     schedulers.append(scheduler)
model = build_model(args)
optimizers = build_optimizer(args, model=model)
scheduler = build_scheduler(args, optimizers)
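
# Illustrative sketch (not from the diff) of how the objects built above
# typically interact in a training loop; the dataloader and the assumption
# that the model's forward returns the loss are hypothetical:
#
#     for batch in dataloader:
#         loss = model(**batch)[0]
#         loss.backward()
#         optimizers.step()
#         optimizers.zero_grad()
#         scheduler.step()
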
def build_model(args):
    if args.token_list is not None:
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        args.token_list = list(token_list)
        vocab_size = len(token_list)
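
# Illustrative note (not part of the diff): the token list consumed above is a
# plain-text file with one token per line, so vocab_size is simply the line
# count. For example, a hypothetical file containing
#   <blank>
#   <s>
#   </s>
#   a
# yields token_list = ["<blank>", "<s>", "</s>", "a"] and vocab_size = 4.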
New file: funasr/utils/build_optimizer.py

import torch

from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD


def build_optimizer(args, model):
    optim_classes = dict(
        adam=torch.optim.Adam,
        fairseq_adam=FairseqAdam,
        adamw=torch.optim.AdamW,
        sgd=SGD,
        adadelta=torch.optim.Adadelta,
        adagrad=torch.optim.Adagrad,
        adamax=torch.optim.Adamax,
        asgd=torch.optim.ASGD,
        lbfgs=torch.optim.LBFGS,
        rmsprop=torch.optim.RMSprop,
        rprop=torch.optim.Rprop,
    )

    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
    optimizer = optim_class(model.parameters(), **args.optim_conf)
    return optimizer
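

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: `args.optim` selects
    # a key of optim_classes and `args.optim_conf` is unpacked into the chosen
    # optimizer's constructor. The Namespace and the toy linear model here are
    # illustrative stand-ins.
    import argparse

    demo_args = argparse.Namespace(optim="adam", optim_conf={"lr": 1e-3})
    demo_model = torch.nn.Linear(4, 2)
    print(build_optimizer(demo_args, model=demo_model))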
New file: funasr/utils/build_scheduler.py

import torch
import torch.optim

from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR


def build_scheduler(args, optimizer):
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        steplr=torch.optim.lr_scheduler.StepLR,
        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
        noamlr=NoamLR,
        warmuplr=WarmupLR,
        tri_stage=TriStageLR,
        cycliclr=torch.optim.lr_scheduler.CyclicLR,
        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )

    scheduler_class = scheduler_classes.get(args.scheduler)
    if scheduler_class is None:
        raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
    scheduler = scheduler_class(optimizer, **args.scheduler_conf)
    return scheduler
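

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: `args.scheduler`
    # selects a key of scheduler_classes and `args.scheduler_conf` is unpacked
    # into its constructor. A torch built-in (StepLR) is used here to avoid
    # assuming FunASR-specific scheduler options; the Namespace and the dummy
    # parameter are illustrative.
    import argparse

    demo_args = argparse.Namespace(
        scheduler="steplr", scheduler_conf={"step_size": 10, "gamma": 0.5}
    )
    demo_optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    print(build_scheduler(demo_args, demo_optimizer))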