From c652f6814ac62eebb5fd1a55a303ee9110c87b58 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期日, 23 四月 2023 17:30:38 +0800
Subject: [PATCH] update

---
 funasr/build_utils/build_scheduler.py |   25 +++++++++++++++++++------
 1 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
index f0e6d1f..4b9990e 100644
--- a/funasr/build_utils/build_scheduler.py
+++ b/funasr/build_utils/build_scheduler.py
@@ -8,7 +8,7 @@
 from funasr.schedulers.warmup_lr import WarmupLR
 
 
-def build_scheduler(args, optimizer):
+def build_scheduler(args, optimizers):
     scheduler_classes = dict(
         ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
         lambdalr=torch.optim.lr_scheduler.LambdaLR,
@@ -24,8 +24,21 @@
         CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
     )
 
-    scheduler_class = scheduler_classes.get(args.scheduler)
-    if scheduler_class is None:
-        raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
-    scheduler = scheduler_class(optimizer, **args.scheduler_conf)
-    return scheduler
\ No newline at end of file
+    schedulers = []
+    for i, optim in enumerate(optimizers, 1):
+        suf = "" if i == 1 else str(i)
+        name = getattr(args, f"scheduler{suf}")
+        conf = getattr(args, f"scheduler{suf}_conf")
+        if name is not None:
+            cls_ = scheduler_classes.get(name)
+            if cls_ is None:
+                raise ValueError(
+                    f"must be one of {list(scheduler_classes)}: {name}"
+                )
+            scheduler = cls_(optim, **conf)
+        else:
+            scheduler = None
+
+        schedulers.append(scheduler)
+
+    return schedulers
\ No newline at end of file

--
Gitblit v1.9.1