From c652f6814ac62eebb5fd1a55a303ee9110c87b58 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期日, 23 四月 2023 17:30:38 +0800
Subject: [PATCH] update
---
funasr/build_utils/build_scheduler.py | 25 +++++++++++++++++++------
1 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
index f0e6d1f..4b9990e 100644
--- a/funasr/build_utils/build_scheduler.py
+++ b/funasr/build_utils/build_scheduler.py
@@ -8,7 +8,7 @@
from funasr.schedulers.warmup_lr import WarmupLR
-def build_scheduler(args, optimizer):
+def build_scheduler(args, optimizers):
scheduler_classes = dict(
ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
lambdalr=torch.optim.lr_scheduler.LambdaLR,
@@ -24,8 +24,21 @@
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
- scheduler_class = scheduler_classes.get(args.scheduler)
- if scheduler_class is None:
- raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
- scheduler = scheduler_class(optimizer, **args.scheduler_conf)
- return scheduler
\ No newline at end of file
+ schedulers = []
+ for i, optim in enumerate(optimizers, 1):
+ suf = "" if i == 1 else str(i)
+ name = getattr(args, f"scheduler{suf}")
+ conf = getattr(args, f"scheduler{suf}_conf")
+ if name is not None:
+ cls_ = scheduler_classes.get(name)
+ if cls_ is None:
+ raise ValueError(
+ f"must be one of {list(scheduler_classes)}: {name}"
+ )
+ scheduler = cls_(optim, **conf)
+ else:
+ scheduler = None
+
+ schedulers.append(scheduler)
+
+ return schedulers
\ No newline at end of file
--
Gitblit v1.9.1