From 11cf10e433c173efd892766b669e0bba57253fed Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 29 Apr 2024 14:52:20 +0800
Subject: [PATCH] Dev gzf exp (#1678)

---
 funasr/schedulers/lambdalr_cus.py |   42 +++++++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 17 deletions(-)

diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 19ad7a8..e3bb1fb 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -2,28 +2,36 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
 
+# class CustomLambdaLR(_LRScheduler):
+#     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+#         self.warmup_steps = warmup_steps
+#         super().__init__(optimizer, last_epoch)
+#
+#     def get_lr(self):
+#         if self.last_epoch < self.warmup_steps:
+#             return [
+#                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
+#             ]
+#         else:
+#             return [base_lr for base_lr in self.base_lrs]
+
+
 class CustomLambdaLR(_LRScheduler):
-    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+    def __init__(
+        self,
+        optimizer,
+        warmup_steps: int = 25000,
+        total_steps: int = 500000,
+        last_epoch=-1,
+        verbose=False,
+    ):
         self.warmup_steps = warmup_steps
-        super().__init__(optimizer, last_epoch)
+        self.total_steps = total_steps
+        super().__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
-        if self.last_epoch < self.warmup_steps:
-            return [
-                base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
-            ]
-        else:
-            return [base_lr for base_lr in self.base_lrs]
 
-
-class CustomLambdaLR(_LRScheduler):
-    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
-        self.warmup_steps = train_config.warmup_steps
-        self.total_steps = train_config.total_steps
-        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
-
-    def get_lr(self):
-        step = self._step_count
+        step = self.last_epoch + 1
         if step < self.warmup_steps:
             lr_scale = step / self.warmup_steps
         else:

--
Gitblit v1.9.1
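
Note: the hunk above is truncated after the `else:` line, so the body of the
post-warmup branch is not visible in this export. Below is a minimal, runnable
sketch of the scheduler as it stands after the patch. The warmup branch and the
`step = self.last_epoch + 1` fix are taken verbatim from the diff; the decay
rule in the `else` branch (linear decay to zero at total_steps) is an assumption
for illustration, not the patch's verbatim code.

    import torch
    from torch.optim.lr_scheduler import _LRScheduler


    class CustomLambdaLR(_LRScheduler):
        def __init__(
            self,
            optimizer,
            warmup_steps: int = 25000,
            total_steps: int = 500000,
            last_epoch=-1,
            verbose=False,
        ):
            self.warmup_steps = warmup_steps
            self.total_steps = total_steps
            super().__init__(optimizer, last_epoch, verbose)

        def get_lr(self):
            # last_epoch counts completed steps, so the step about to run is +1.
            # The patch switches from self._step_count to this expression.
            step = self.last_epoch + 1
            if step < self.warmup_steps:
                # Linear warmup from 0 up to each param group's base lr.
                lr_scale = step / self.warmup_steps
            else:
                # ASSUMED decay rule (truncated in the hunk): linear decay from
                # the base lr at warmup_steps down to 0 at total_steps.
                lr_scale = max(
                    0.0,
                    (self.total_steps - step) / (self.total_steps - self.warmup_steps),
                )
            return [base_lr * lr_scale for base_lr in self.base_lrs]


    # Usage sketch: step the scheduler once per optimizer step.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = CustomLambdaLR(optimizer, warmup_steps=10, total_steps=100)
    for _ in range(20):
        optimizer.step()
        scheduler.step()

Using self.last_epoch + 1 instead of self._step_count keeps get_lr() consistent
with the step counter that _LRScheduler itself advances, which is the behavior
change this patch makes to the warmup computation.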