shixian.shi
2023-06-28 438c4663d2094ed7ce1762fa4f16cf89401b8bec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.multiprocessing
import torch.nn
import torch.optim
 
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR
 
 
def build_scheduler(args, optimizers):
    """Build one learning-rate scheduler per optimizer from ``args``.

    The first optimizer reads ``args.scheduler`` / ``args.scheduler_conf``.
    The i-th optimizer (i >= 2) reads ``args.scheduler{i}`` /
    ``args.scheduler{i}_conf``.

    Args:
        args: Namespace-like object that carries, for every optimizer, a
            scheduler name (a key of the registry below, or ``None``) and a
            mapping of keyword arguments for that scheduler's constructor
            (``None`` is treated as an empty mapping).
        optimizers: Iterable of optimizers, passed positionally as the first
            argument of each scheduler constructor.

    Returns:
        A list with one entry per optimizer, in order. Each entry is the
        constructed scheduler, or ``None`` when that optimizer's scheduler
        name is ``None``.

    Raises:
        ValueError: If a scheduler name is not one of the registered keys.
        AttributeError: If ``args`` lacks a ``scheduler{suf}`` or
            ``scheduler{suf}_conf`` attribute for some optimizer.
    """
    # Registry of accepted scheduler names. Key casing is inconsistent but is
    # part of the config interface, so it is kept as-is.
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        steplr=torch.optim.lr_scheduler.StepLR,
        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
        noamlr=NoamLR,
        warmuplr=WarmupLR,
        tri_stage=TriStageLR,
        cycliclr=torch.optim.lr_scheduler.CyclicLR,
        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )

    schedulers = []
    for i, optim in enumerate(optimizers, 1):
        # The first optimizer uses unsuffixed attribute names; later ones are
        # suffixed with their 1-based index.
        suf = "" if i == 1 else str(i)
        name = getattr(args, f"scheduler{suf}")
        conf = getattr(args, f"scheduler{suf}_conf")
        if name is None:
            schedulers.append(None)
            continue

        cls_ = scheduler_classes.get(name)
        if cls_ is None:
            raise ValueError(
                f"must be one of {list(scheduler_classes)}: {name}"
            )
        # A None conf means "use the scheduler's defaults"; unpacking
        # ``**None`` would otherwise raise an opaque TypeError.
        schedulers.append(cls_(optim, **(conf or {})))

    return schedulers