From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: 星期五, 26 四月 2024 14:59:30 +0800
Subject: [PATCH] FunASR java http  client

---
 funasr/schedulers/lambdalr_cus.py |   17 ++++++++++++++++-
 1 files changed, 16 insertions(+), 1 deletions(-)

diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 0123cc2..5aad049 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -14,4 +14,19 @@
                 for base_lr in self.base_lrs
             ]
         else:
-            return [base_lr for base_lr in self.base_lrs]
\ No newline at end of file
+            return [base_lr for base_lr in self.base_lrs]
+        
+        
+class CustomLambdaLR(_LRScheduler):
+    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
+        self.warmup_steps = train_config.warmup_steps
+        self.total_steps = train_config.total_steps
+        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        step = self._step_count
+        if step < self.warmup_steps:
+            lr_scale = step / self.warmup_steps
+        else:
+            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
+        return [base_lr * lr_scale for base_lr in self.base_lrs]

--
Gitblit v1.9.1