From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/schedulers/lambdalr_cus.py | 42 +++++++++++++++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 9 deletions(-)
diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 0123cc2..e3bb1fb 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -1,17 +1,41 @@
-
import torch
from torch.optim.lr_scheduler import _LRScheduler
+
+# class CustomLambdaLR(_LRScheduler):
+# def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+# self.warmup_steps = warmup_steps
+# super().__init__(optimizer, last_epoch)
+#
+# def get_lr(self):
+# if self.last_epoch < self.warmup_steps:
+# return [
+# base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
+# ]
+# else:
+# return [base_lr for base_lr in self.base_lrs]
+
+
class CustomLambdaLR(_LRScheduler):
- def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+ def __init__(
+ self,
+ optimizer,
+ warmup_steps: int = 25000,
+ total_steps: int = 500000,
+ last_epoch=-1,
+ verbose=False,
+ ):
self.warmup_steps = warmup_steps
- super().__init__(optimizer, last_epoch)
+ self.total_steps = total_steps
+ super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
- if self.last_epoch < self.warmup_steps:
- return [
- base_lr * min(self.last_epoch / self.warmup_steps, 1)
- for base_lr in self.base_lrs
- ]
+
+ step = self.last_epoch + 1
+ if step < self.warmup_steps:
+ lr_scale = step / self.warmup_steps
else:
- return [base_lr for base_lr in self.base_lrs]
\ No newline at end of file
+ lr_scale = max(
+ 0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+ )
+ return [base_lr * lr_scale for base_lr in self.base_lrs]
--
Gitblit v1.9.1