From 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 11 九月 2023 17:40:03 +0800
Subject: [PATCH] Merge pull request #932 from alibaba-damo-academy/dev_lhn
---
funasr/build_utils/build_scheduler.py | 25 +++++++++++++++++++------
1 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
index f0e6d1f..4b9990e 100644
--- a/funasr/build_utils/build_scheduler.py
+++ b/funasr/build_utils/build_scheduler.py
@@ -8,7 +8,7 @@
from funasr.schedulers.warmup_lr import WarmupLR
-def build_scheduler(args, optimizer):
+def build_scheduler(args, optimizers):
scheduler_classes = dict(
ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
lambdalr=torch.optim.lr_scheduler.LambdaLR,
@@ -24,8 +24,21 @@
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
- scheduler_class = scheduler_classes.get(args.scheduler)
- if scheduler_class is None:
- raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
- scheduler = scheduler_class(optimizer, **args.scheduler_conf)
- return scheduler
\ No newline at end of file
+ schedulers = []
+ for i, optim in enumerate(optimizers, 1):
+ suf = "" if i == 1 else str(i)
+ name = getattr(args, f"scheduler{suf}")
+ conf = getattr(args, f"scheduler{suf}_conf")
+ if name is not None:
+ cls_ = scheduler_classes.get(name)
+ if cls_ is None:
+ raise ValueError(
+ f"must be one of {list(scheduler_classes)}: {name}"
+ )
+ scheduler = cls_(optim, **conf)
+ else:
+ scheduler = None
+
+ schedulers.append(scheduler)
+
+ return schedulers
\ No newline at end of file
--
Gitblit v1.9.1