From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/schedulers/lambdalr_cus.py |   42 +++++++++++++++++++++++++++++++++---------
 1 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 0123cc2..e3bb1fb 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -1,17 +1,41 @@
-
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
+
+# class CustomLambdaLR(_LRScheduler):
+#     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+#         self.warmup_steps = warmup_steps
+#         super().__init__(optimizer, last_epoch)
+#
+#     def get_lr(self):
+#         if self.last_epoch < self.warmup_steps:
+#             return [
+#                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
+#             ]
+#         else:
+#             return [base_lr for base_lr in self.base_lrs]
+
+
 class CustomLambdaLR(_LRScheduler):
-    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+    def __init__(
+        self,
+        optimizer,
+        warmup_steps: int = 25000,
+        total_steps: int = 500000,
+        last_epoch=-1,
+        verbose=False,
+    ):
         self.warmup_steps = warmup_steps
-        super().__init__(optimizer, last_epoch)
+        self.total_steps = total_steps
+        super().__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
-        if self.last_epoch < self.warmup_steps:
-            return [
-                base_lr * min(self.last_epoch / self.warmup_steps, 1)
-                for base_lr in self.base_lrs
-            ]
+
+        step = self.last_epoch + 1
+        if step < self.warmup_steps:
+            lr_scale = step / self.warmup_steps
         else:
-            return [base_lr for base_lr in self.base_lrs]
\ No newline at end of file
+            lr_scale = max(
+                0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+            )
+        return [base_lr * lr_scale for base_lr in self.base_lrs]

--
Gitblit v1.9.1