                for base_lr in self.base_lrs
            ]
        else:
            return [base_lr for base_lr in self.base_lrs]


class CustomLambdaLR(_LRScheduler):
    """Linear warmup followed by linear decay to zero over the remaining steps."""

    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
        self.warmup_steps = train_config.warmup_steps
        self.total_steps = train_config.total_steps
        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        step = self._step_count
        if step < self.warmup_steps:
            # Warmup: scale the base learning rates linearly from 0 up to 1.
            lr_scale = step / self.warmup_steps
        else:
            # Decay: scale linearly down over the remaining steps, clamped at 0.
            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
        return [base_lr * lr_scale for base_lr in self.base_lrs]
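

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: the model,
    # optimizer, and SimpleNamespace-based train_config below are assumptions
    # made only to show how CustomLambdaLR is stepped once per optimization step.
    import torch
    from types import SimpleNamespace

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_config = SimpleNamespace(warmup_steps=100, total_steps=1000)
    scheduler = CustomLambdaLR(optimizer, train_config)

    for _ in range(train_config.total_steps):
        optimizer.step()   # forward/backward would normally happen before this
        scheduler.step()   # advances _step_count and rescales each param group's lr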