| | |
| | | |
| | | import torch |
| | | from torch.optim.lr_scheduler import _LRScheduler |
| | | |
| | | |
| | | # class CustomLambdaLR(_LRScheduler): |
| | | # def __init__(self, optimizer, warmup_steps, last_epoch=-1): |
| | | # self.warmup_steps = warmup_steps |
| | | # super().__init__(optimizer, last_epoch) |
| | | # |
| | | # def get_lr(self): |
| | | # if self.last_epoch < self.warmup_steps: |
| | | # return [ |
| | | # base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs |
| | | # ] |
| | | # else: |
| | | # return [base_lr for base_lr in self.base_lrs] |
| | | |
| | | |
class CustomLambdaLR(_LRScheduler):
    """Scale each base learning rate with linear warmup, then linear decay.

    With ``step = last_epoch + 1``:

    * while ``step < warmup_steps`` the scale is ``step / warmup_steps``;
    * afterwards the scale falls linearly from 1 at ``step == warmup_steps``
      to 0 at ``step == total_steps`` and is clamped at 0 from then on.

    Args:
        optimizer: Optimizer whose parameter-group learning rates are scheduled.
        warmup_steps: Number of steps over which the scale ramps up.
        total_steps: Step at which the scale reaches 0.
        last_epoch: Index of the last completed step, passed to the base class.
        verbose: Passed to the base class only when truthy (see ``__init__``).
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int = 25000,
        total_steps: int = 500000,
        last_epoch=-1,
        verbose=False,
    ):
        # Set the schedule parameters before calling the base constructor, as
        # the original code does; the base class is expected to call get_lr()
        # during construction, which reads both attributes.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        # Forward `verbose` only when it was actually requested.
        # NOTE(review): newer PyTorch releases deprecate this argument (and may
        # reject it); confirm against the torch version this project pins.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scaled learning rate for every entry in ``base_lrs``."""
        step = self.last_epoch + 1
        if step < self.warmup_steps:
            lr_scale = step / self.warmup_steps
        else:
            # Clamp the decay span to at least one step so that
            # total_steps <= warmup_steps neither divides by zero nor makes
            # the scale grow past 1.
            decay_span = max(1, self.total_steps - self.warmup_steps)
            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / decay_span)
        return [base_lr * lr_scale for base_lr in self.base_lrs]