From 11cf10e433c173efd892766b669e0bba57253fed Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 14:52:20 +0800
Subject: [PATCH] Dev gzf exp (#1678)
---
funasr/schedulers/lambdalr_cus.py | 42 +++++++++++++++++++++++++-----------------
1 files changed, 25 insertions(+), 17 deletions(-)
diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 19ad7a8..e3bb1fb 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -2,28 +2,36 @@
from torch.optim.lr_scheduler import _LRScheduler
+# class CustomLambdaLR(_LRScheduler):
+# def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+# self.warmup_steps = warmup_steps
+# super().__init__(optimizer, last_epoch)
+#
+# def get_lr(self):
+# if self.last_epoch < self.warmup_steps:
+# return [
+# base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
+# ]
+# else:
+# return [base_lr for base_lr in self.base_lrs]
+
+
class CustomLambdaLR(_LRScheduler):
- def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+ def __init__(
+ self,
+ optimizer,
+ warmup_steps: int = 25000,
+ total_steps: int = 500000,
+ last_epoch=-1,
+ verbose=False,
+ ):
self.warmup_steps = warmup_steps
- super().__init__(optimizer, last_epoch)
+ self.total_steps = total_steps
+ super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
- if self.last_epoch < self.warmup_steps:
- return [
- base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
- ]
- else:
- return [base_lr for base_lr in self.base_lrs]
-
-class CustomLambdaLR(_LRScheduler):
- def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
- self.warmup_steps = train_config.warmup_steps
- self.total_steps = train_config.total_steps
- super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- step = self._step_count
+ step = self.last_epoch + 1
if step < self.warmup_steps:
lr_scale = step / self.warmup_steps
else:
--
Gitblit v1.9.1