liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/schedulers/tri_stage_scheduler.py
@@ -14,12 +14,12 @@
class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            last_epoch: int = -1,
            phase_ratio: Optional[List[float]] = None,
            init_lr_scale: float = 0.01,
            final_lr_scale: float = 0.01,
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        phase_ratio: Optional[List[float]] = None,
        init_lr_scale: float = 0.01,
        final_lr_scale: float = 0.01,
    ):
        self.optimizer = optimizer
        self.last_epoch = last_epoch
@@ -42,9 +42,7 @@
        self.decay_steps = int(self.max_update * self.phase_ratio[2])
        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps
            if self.warmup_steps != 0
            else 0
            (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
        )
        self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps