"""Noam learning rate scheduler module."""

from typing import Union
import warnings

import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types

from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler


class NoamLR(_LRScheduler, AbsBatchStepScheduler):
    """The LR scheduler proposed by Noam.

    Ref:
        "Attention Is All You Need", https://arxiv.org/abs/1706.03762

    Note: "model_size" is a constant value here, not derived from the
    model itself.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        model_size: Union[int, float] = 320,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        assert check_argument_types()
        self.model_size = model_size
        self.warmup_steps = warmup_steps

        # _LRScheduler.__init__() invokes step(), which calls get_lr(),
        # so the fields above must be set before this call.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Noam schedule: scale each base lr by model_size**-0.5 and by
        # min(step**-0.5, step * warmup_steps**-1.5), i.e. linear warmup
        # followed by inverse-square-root decay.
        step_num = self.last_epoch + 1
        return [
            lr
            * self.model_size**-0.5
            * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
            for lr in self.base_lrs
        ]
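

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving NoamLR once per training batch, as the
# AbsBatchStepScheduler base class implies. The model, optimizer, and
# hyperparameter values below are assumptions chosen for demonstration.
if __name__ == "__main__":
    model = torch.nn.Linear(320, 320)
    # Under the Noam schedule, the optimizer's lr acts as a scale factor.
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    scheduler = NoamLR(optimizer, model_size=320, warmup_steps=25000)

    for _ in range(10):
        optimizer.step()   # one training batch (loss/backward omitted)
        scheduler.step()   # advance the schedule after every batch
    print(scheduler.get_last_lr())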