From fce4e1d1b48f23cd8332e60afce3df8d6209a6a7 Mon Sep 17 00:00:00 2001
From: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Date: 星期四, 11 四月 2024 14:59:22 +0800
Subject: [PATCH] SenseVoice对富文本解码的参数 (#1608)

---
 funasr/schedulers/lambdalr_cus.py |   17 ++++++++++++++++-
 1 files changed, 16 insertions(+), 1 deletions(-)

diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 0123cc2..5aad049 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -14,4 +14,19 @@
                 for base_lr in self.base_lrs
             ]
         else:
-            return [base_lr for base_lr in self.base_lrs]
\ No newline at end of file
+            return [base_lr for base_lr in self.base_lrs]
+        
+        
+class CustomLambdaLR(_LRScheduler):
+    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
+        self.warmup_steps = train_config.warmup_steps
+        self.total_steps = train_config.total_steps
+        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        step = self._step_count
+        if step < self.warmup_steps:
+            lr_scale = step / self.warmup_steps
+        else:
+            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
+        return [base_lr * lr_scale for base_lr in self.base_lrs]

--
Gitblit v1.9.1