From cc2c1d1d53dea5d2c45f858d1baa5bd279f47987 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期三, 31 五月 2023 14:39:25 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR

---
 funasr/build_utils/build_scheduler.py |   44 ++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 44 insertions(+), 0 deletions(-)

diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
new file mode 100644
index 0000000..4b9990e
--- /dev/null
+++ b/funasr/build_utils/build_scheduler.py
@@ -0,0 +1,44 @@
+import torch
+import torch.multiprocessing
+import torch.nn
+import torch.optim
+
+from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
+
+
+def build_scheduler(args, optimizers):
+    scheduler_classes = dict(
+        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+        lambdalr=torch.optim.lr_scheduler.LambdaLR,
+        steplr=torch.optim.lr_scheduler.StepLR,
+        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+        noamlr=NoamLR,
+        warmuplr=WarmupLR,
+        tri_stage=TriStageLR,
+        cycliclr=torch.optim.lr_scheduler.CyclicLR,
+        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+    )
+
+    schedulers = []
+    for i, optim in enumerate(optimizers, 1):
+        suf = "" if i == 1 else str(i)
+        name = getattr(args, f"scheduler{suf}")
+        conf = getattr(args, f"scheduler{suf}_conf")
+        if name is not None:
+            cls_ = scheduler_classes.get(name)
+            if cls_ is None:
+                raise ValueError(
+                    f"must be one of {list(scheduler_classes)}: {name}"
+                )
+            scheduler = cls_(optim, **conf)
+        else:
+            scheduler = None
+
+        schedulers.append(scheduler)
+
+    return schedulers
\ No newline at end of file

--
Gitblit v1.9.1