From 3ac03e448b7673604eb86f619b27521fca55f34d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 三月 2024 01:36:39 +0800
Subject: [PATCH] train & finetune llm-asr (#1519)

---
 funasr/schedulers/lambdalr_cus.py |   17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 0123cc2..5aad049 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -14,4 +14,19 @@
                 for base_lr in self.base_lrs
             ]
         else:
-            return [base_lr for base_lr in self.base_lrs]
\ No newline at end of file
+            return [base_lr for base_lr in self.base_lrs]
+        
+        
+class CustomLambdaLR(_LRScheduler):
+    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
+        self.warmup_steps = train_config.warmup_steps
+        self.total_steps = train_config.total_steps
+        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        step = self._step_count
+        if step < self.warmup_steps:
+            lr_scale = step / self.warmup_steps
+        else:
+            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
+        return [base_lr * lr_scale for base_lr in self.base_lrs]

--
Gitblit v1.9.1