From cc2c1d1d53dea5d2c45f858d1baa5bd279f47987 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期三, 31 五月 2023 14:39:25 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/build_utils/build_scheduler.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 44 insertions(+), 0 deletions(-)
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
new file mode 100644
index 0000000..4b9990e
--- /dev/null
+++ b/funasr/build_utils/build_scheduler.py
@@ -0,0 +1,44 @@
+import torch
+import torch.multiprocessing
+import torch.nn
+import torch.optim
+
+from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
+
+
+def build_scheduler(args, optimizers):
+ scheduler_classes = dict(
+ ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+ lambdalr=torch.optim.lr_scheduler.LambdaLR,
+ steplr=torch.optim.lr_scheduler.StepLR,
+ multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+ exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+ CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+ noamlr=NoamLR,
+ warmuplr=WarmupLR,
+ tri_stage=TriStageLR,
+ cycliclr=torch.optim.lr_scheduler.CyclicLR,
+ onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+ CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+ )
+
+ schedulers = []
+ for i, optim in enumerate(optimizers, 1):
+ suf = "" if i == 1 else str(i)
+ name = getattr(args, f"scheduler{suf}")
+ conf = getattr(args, f"scheduler{suf}_conf")
+ if name is not None:
+ cls_ = scheduler_classes.get(name)
+ if cls_ is None:
+ raise ValueError(
+ f"must be one of {list(scheduler_classes)}: {name}"
+ )
+ scheduler = cls_(optim, **conf)
+ else:
+ scheduler = None
+
+ schedulers.append(scheduler)
+
+ return schedulers
\ No newline at end of file
--
Gitblit v1.9.1