嘉渊
2023-04-23 c652f6814ac62eebb5fd1a55a303ee9110c87b58
funasr/build_utils/build_scheduler.py
@@ -8,7 +8,7 @@
from funasr.schedulers.warmup_lr import WarmupLR
def build_scheduler(args, optimizer):
def build_scheduler(args, optimizers):
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
@@ -24,8 +24,21 @@
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )
    scheduler_class = scheduler_classes.get(args.scheduler)
    if scheduler_class is None:
        raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
    scheduler = scheduler_class(optimizer, **args.scheduler_conf)
    return scheduler
    schedulers = []
    for i, optim in enumerate(optimizers, 1):
        suf = "" if i == 1 else str(i)
        name = getattr(args, f"scheduler{suf}")
        conf = getattr(args, f"scheduler{suf}_conf")
        if name is not None:
            cls_ = scheduler_classes.get(name)
            if cls_ is None:
                raise ValueError(
                    f"must be one of {list(scheduler_classes)}: {name}"
                )
            scheduler = cls_(optim, **conf)
        else:
            scheduler = None
        schedulers.append(scheduler)
    return schedulers