from torch.optim.lr_scheduler import _LRScheduler


# Earlier warmup-only draft of CustomLambdaLR, kept for reference:
# class CustomLambdaLR(_LRScheduler):
#     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
#         self.warmup_steps = warmup_steps
#         super().__init__(optimizer, last_epoch)
#
#     def get_lr(self):
#         if self.last_epoch < self.warmup_steps:
#             return [
#                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
#             ]
#         else:
#             return [base_lr for base_lr in self.base_lrs]


class CustomLambdaLR(_LRScheduler):
    """Linearly warm up to the base learning rate, then hold it constant."""

    def __init__(
        self,
        optimizer,
        warmup_steps: int = 25000,
        total_steps: int = 500000,
        last_epoch=-1,
        verbose=False,
    ):
        # Attributes must be set before super().__init__(), because the base
        # class calls get_lr() once during initialization.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Scale each base LR linearly by the fraction of warmup completed.
            return [
                base_lr * min(self.last_epoch / self.warmup_steps, 1)
                for base_lr in self.base_lrs
            ]
        else:
            # After warmup, hold the base learning rates constant
            # (total_steps is not used by this variant).
            return [base_lr for base_lr in self.base_lrs]
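

# A minimal usage sketch (hypothetical toy model and hyperparameters, for
# illustration only): the scheduler is stepped once per optimizer update,
# so last_epoch effectively counts training steps rather than epochs.
#
#     import torch
#
#     model = torch.nn.Linear(10, 10)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     scheduler = CustomLambdaLR(optimizer, warmup_steps=4, total_steps=10)
#
#     for _ in range(10):
#         optimizer.step()
#         scheduler.step()
#         print(scheduler.get_last_lr())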


class CustomLambdaLR(_LRScheduler):
    """Config-driven revision: schedule hyperparameters come from train_config."""

    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
        # Attributes must be set before super().__init__(), because the base
        # class calls get_lr() once during initialization.
        self.warmup_steps = train_config.warmup_steps
        self.total_steps = train_config.total_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # last_epoch counts completed scheduler steps, so the step about to
        # run is last_epoch + 1 (equivalent to self._step_count).
        step = self.last_epoch + 1
        if step < self.warmup_steps:
            # Linear warmup from zero toward the base learning rate.
            lr_scale = step / self.warmup_steps
        else: