# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List, Optional

import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types

from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler


class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
    """Tri-stage learning rate scheduler.

    The learning rate warms up linearly from ``init_lr`` to the optimizer's
    peak lr, holds at the peak, then decays exponentially to ``final_lr``.
    The lengths of the three phases are given as fractions of ``max_update``
    via ``phase_ratio``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        phase_ratio: Optional[List[float]] = None,
        init_lr_scale: float = 0.01,
        final_lr_scale: float = 0.01,
    ):
        assert check_argument_types()
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.phase_ratio = phase_ratio
        self.init_lr_scale = init_lr_scale
        self.final_lr_scale = final_lr_scale
        self.optimizer_lr = self.optimizer.defaults["lr"]

    def init_tri_stage_scheudler(self, max_update):
        self.max_update = max_update
        self.peak_lr = self.optimizer_lr
        self.init_lr = self.init_lr_scale * self.optimizer_lr
        self.final_lr = self.final_lr_scale * self.optimizer_lr

        assert self.max_update > 0
        assert self.phase_ratio is not None, "phase_ratio must be provided"
        assert len(self.phase_ratio) == 3, "phase_ratio must have three entries"
        # Use isclose() rather than == so that ratios such as
        # [0.1, 0.2, 0.7] are not rejected by floating-point rounding.
        assert math.isclose(sum(self.phase_ratio), 1.0), "phase ratios must add up to 1"
        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
        self.hold_steps = int(self.max_update * self.phase_ratio[1])
        self.decay_steps = int(self.max_update * self.phase_ratio[2])

        # Slope of the linear warmup; guard against a zero-length phase.
        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps
            if self.warmup_steps != 0
            else 0
        )
        # Exponential decay rate chosen so that the lr reaches final_lr
        # after decay_steps updates; guard against a zero-length phase.
        self.decay_factor = (
            -math.log(self.final_lr_scale) / self.decay_steps
            if self.decay_steps != 0
            else 0
        )

        # initial learning rate
        self.lr = self.init_lr
        # The fields above must be set before super().__init__() is invoked,
        # because _LRScheduler.__init__() calls step(), which in turn calls
        # get_lr() and needs these attributes.
        self.set_optimizer_lr(self.lr)
        super().__init__(self.optimizer, self.last_epoch)

    def _decide_stage(self, update_step):
        """Return the current stage index (0-3) and the step offset within it."""
        if update_step < self.warmup_steps:
            # warmup stage
            return 0, update_step

        offset = self.warmup_steps

        if update_step < offset + self.hold_steps:
            # hold stage
            return 1, update_step - offset

        offset += self.hold_steps

        if update_step <= offset + self.decay_steps:
            # decay stage
            return 2, update_step - offset

        offset += self.decay_steps

        # still here? constant lr stage
        return 3, update_step - offset

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        stage, steps_in_stage = self._decide_stage(num_updates)
        if stage == 0:
            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
        elif stage == 1:
            self.lr = self.peak_lr
        elif stage == 2:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 3:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")
        self.set_optimizer_lr(self.lr)

    def set_optimizer_lr(self, lr):
        """Write the current lr into every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        step_num = self.last_epoch + 1
        self.step_update(step_num)
        return [self.lr]
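

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The toy model, optimizer
# hyper-parameters, and max_update value below are assumptions made for the
# demo, not part of the FunASR training pipeline. It shows the intended call
# order: construct the scheduler, initialize it with the total number of
# updates, then call step_update() once per optimizer step.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)  # hypothetical toy model
    # The optimizer's lr acts as the peak lr of the schedule.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = TriStageLR(
        optimizer,
        phase_ratio=[0.1, 0.4, 0.5],  # warmup / hold / decay fractions
        init_lr_scale=0.01,
        final_lr_scale=0.05,
    )
    scheduler.init_tri_stage_scheudler(max_update=1000)  # assumed total updates
    for step in range(1000):
        optimizer.step()
        scheduler.step_update(step)
    # By the end of training the lr has decayed to roughly
    # final_lr = final_lr_scale * peak_lr = 5e-5.
    print(optimizer.param_groups[0]["lr"])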