| | |
| | | |
class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
    """Tri-stage learning-rate scheduler: linear warmup, hold, exponential decay.

    The ``max_update`` total steps are split into three phases by
    ``phase_ratio``:

      1. warmup: lr rises linearly from ``init_lr_scale * peak_lr`` to ``peak_lr``
      2. hold:   lr stays at ``peak_lr``
      3. decay:  lr decays exponentially toward ``final_lr_scale * peak_lr``

    NOTE(review): the source for this block was corrupted — the parameter
    list appeared twice, the warmup-rate expression was duplicated, and the
    attribute setup that the phase computations below rely on
    (``max_update``, ``phase_ratio``, ``peak_lr``, ``init_lr``,
    ``warmup_steps``) was missing. It is reconstructed here following the
    fairseq ``tri_stage`` scheduler that the surviving lines match verbatim;
    confirm the intended signature against the project's config/callers.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        phase_ratio: Optional[List[float]] = None,
        init_lr_scale: float = 0.01,
        final_lr_scale: float = 0.01,
        max_update: int = 25000,
    ):
        self.optimizer = optimizer
        self.last_epoch = last_epoch

        self.max_update = max_update
        self.init_lr_scale = init_lr_scale
        self.final_lr_scale = final_lr_scale

        # Default split: 10% warmup, 40% hold, 50% decay (fairseq default).
        self.phase_ratio = [0.1, 0.4, 0.5] if phase_ratio is None else list(phase_ratio)
        if len(self.phase_ratio) != 3:
            raise ValueError(
                f"phase_ratio must have exactly 3 elements, got {self.phase_ratio}"
            )
        if not math.isclose(sum(self.phase_ratio), 1.0, rel_tol=1e-6):
            raise ValueError("phase_ratio must sum to 1")

        # Peak lr is whatever the optimizer was configured with; the initial
        # and final lrs are scaled relative to it — TODO confirm this matches
        # how the out-of-view part of the class derives peak_lr.
        self.peak_lr = optimizer.param_groups[0]["lr"]
        self.init_lr = self.init_lr_scale * self.peak_lr
        self.final_lr = self.final_lr_scale * self.peak_lr

        # Phase lengths in optimizer steps.
        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
        self.hold_steps = int(self.max_update * self.phase_ratio[1])
        self.decay_steps = int(self.max_update * self.phase_ratio[2])

        # Per-step lr increment during warmup (0 when there is no warmup phase).
        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps
            if self.warmup_steps != 0
            else 0
        )
        # Decay constant chosen so lr reaches final_lr after decay_steps:
        # peak_lr * exp(-decay_factor * decay_steps) == final_lr_scale * peak_lr.
        self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps

        # NOTE(review): the original chunk never calls
        # super().__init__(optimizer, last_epoch); if the rest of the class
        # relies on _LRScheduler machinery (base_lrs, step counting) it must
        # be invoked somewhere — cannot confirm from this chunk.