From 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 29 Apr 2024 14:32:43 +0800
Subject: [PATCH] schedulers: consolidate duplicate CustomLambdaLR into a single warmup + linear-decay scheduler
---
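Notes:
    For reviewers: a minimal usage sketch of the revised scheduler, not part
    of the patch itself. It assumes the module path from the diff and a
    throwaway optimizer; the 1e-3 base lr and the tiny warmup_steps=4 /
    total_steps=10 are illustrative stand-ins for the 25000/500000 defaults.

        import torch
        from funasr.schedulers.lambdalr_cus import CustomLambdaLR

        param = torch.nn.Parameter(torch.zeros(1))
        optimizer = torch.optim.AdamW([param], lr=1e-3)
        scheduler = CustomLambdaLR(optimizer, warmup_steps=4, total_steps=10)

        for step in range(10):
            optimizer.step()
            scheduler.step()
            # lr climbs linearly to 1e-3 over the first 4 steps, then
            # decays linearly, hitting 0 at total_steps.
            print(step, scheduler.get_last_lr())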
funasr/schedulers/lambdalr_cus.py | 38 ++++++++++++++++++--------------------
1 file changed, 18 insertions(+), 20 deletions(-)
diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 5aad049..e3bb1fb 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -1,32 +1,30 @@
-
import torch
from torch.optim.lr_scheduler import _LRScheduler
+
+
class CustomLambdaLR(_LRScheduler):
- def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+ def __init__(
+ self,
+ optimizer,
+ warmup_steps: int = 25000,
+ total_steps: int = 500000,
+ last_epoch=-1,
+ verbose=False,
+ ):
self.warmup_steps = warmup_steps
- super().__init__(optimizer, last_epoch)
+ self.total_steps = total_steps
+ super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
- if self.last_epoch < self.warmup_steps:
- return [
- base_lr * min(self.last_epoch / self.warmup_steps, 1)
- for base_lr in self.base_lrs
- ]
- else:
- return [base_lr for base_lr in self.base_lrs]
-
-
-class CustomLambdaLR(_LRScheduler):
- def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
- self.warmup_steps = train_config.warmup_steps
- self.total_steps = train_config.total_steps
- super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- step = self._step_count
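+        # last_epoch is -1 before the first step(), so steps here count from 1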
+ step = self.last_epoch + 1
if step < self.warmup_steps:
lr_scale = step / self.warmup_steps
else:
- lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
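+            # linear decay from base_lr toward 0, reaching 0 at total_steps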
+ lr_scale = max(
+ 0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+ )
return [base_lr * lr_scale for base_lr in self.base_lrs]
--
Gitblit v1.9.1