import torch

from funasr.schedulers.warmup_lr import WarmupLR


def build_scheduler(args, optimizers):
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
        warmuplr=WarmupLR,  # FunASR warmup scheduler; imported above but previously unregistered
    )
    schedulers = []
    for i, optim in enumerate(optimizers, 1):
        # The first optimizer reads the plain "scheduler"/"scheduler_conf"
        # options; later optimizers read numbered variants ("scheduler2", ...).
        suf = "" if i == 1 else str(i)
        name = getattr(args, f"scheduler{suf}")
        conf = getattr(args, f"scheduler{suf}_conf")
        if name is not None:
            cls_ = scheduler_classes.get(name)
            if cls_ is None:
                raise ValueError(f"must be one of {list(scheduler_classes)}: {name}")
            scheduler = cls_(optim, **conf)
        else:
            # No scheduler requested for this optimizer: keep a None placeholder
            # so schedulers stays aligned index-for-index with optimizers.
            scheduler = None

        schedulers.append(scheduler)

    return schedulers
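

# --- A minimal usage sketch (not part of the source) ---
# Assumes "args" exposes the scheduler/scheduler_conf (and scheduler2/
# scheduler2_conf) attributes the loop above reads; the model, optimizers,
# and config values below are invented purely for illustration.
from types import SimpleNamespace

model = torch.nn.Linear(8, 2)
optimizers = [
    torch.optim.Adam(model.parameters(), lr=1e-3),
    torch.optim.SGD(model.parameters(), lr=1e-2),
]

args = SimpleNamespace(
    scheduler="warmuplr",
    scheduler_conf={"warmup_steps": 25000},  # WarmupLR warmup length (assumed value)
    scheduler2="lambdalr",
    scheduler2_conf={"lr_lambda": lambda epoch: 0.95 ** epoch},
)

schedulers = build_scheduler(args, optimizers)  # -> [WarmupLR, LambdaLR]
for scheduler in schedulers:
    if scheduler is not None:
        scheduler.step()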