jmwang66
2023-02-06 35fc1108343e28d6f504cf5923658b0fb79ab443
Update data2vec pretraining: add tri_stage learning-rate scheduler
1个文件已修改
1个文件已添加
110 ■■■■■ 已修改文件
funasr/schedulers/tri_stage_scheduler.py 108 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/schedulers/tri_stage_scheduler.py
New file
@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, List
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
    """Tri-stage learning-rate schedule: warmup -> hold -> exponential decay.

    The peak learning rate is the optimizer's ``defaults["lr"]``. Given
    ``max_update`` total updates and ``phase_ratio = [warmup, hold, decay]``:

    1. warmup: lr rises linearly from ``init_lr_scale * peak_lr`` towards
       ``peak_lr`` over ``int(max_update * phase_ratio[0])`` updates;
    2. hold: lr stays at ``peak_lr`` for ``int(max_update * phase_ratio[1])``
       updates;
    3. decay: lr decays exponentially from ``peak_lr`` to
       ``final_lr_scale * peak_lr`` over ``int(max_update * phase_ratio[2])``
       updates;

    after which lr stays constant at ``final_lr_scale * peak_lr``.

    Construction is two-phase: ``__init__`` only records the configuration,
    and :meth:`init_tri_stage_scheudler` must be called with ``max_update``
    before the scheduler is stepped (that method is what runs
    ``_LRScheduler.__init__``). Every optimizer param group receives the
    same learning rate.
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            last_epoch: int = -1,
            phase_ratio: Optional[List[float]] = None,
            init_lr_scale: float = 0.01,
            final_lr_scale: float = 0.01,
    ):
        """Record the schedule configuration (no lr is changed yet).

        Args:
            optimizer: wrapped optimizer; its ``defaults["lr"]`` is the peak lr.
            last_epoch: index of the last update; forwarded to
                ``_LRScheduler.__init__`` by :meth:`init_tri_stage_scheudler`.
            phase_ratio: fractions of ``max_update`` spent in the warmup,
                hold and decay stages; must have 3 entries summing to 1.
            init_lr_scale: lr at the start of warmup, relative to the peak lr.
            final_lr_scale: lr at the end of decay, relative to the peak lr.
        """
        assert check_argument_types()
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.phase_ratio = phase_ratio
        self.init_lr_scale = init_lr_scale
        self.final_lr_scale = final_lr_scale
        self.optimizer_lr = self.optimizer.defaults["lr"]

    def init_tri_stage_scheudler(self, max_update):
        """Compute the stage lengths for ``max_update`` updates and start.

        Must be called before the scheduler is stepped.

        Args:
            max_update: total number of optimizer updates; must be positive.

        Raises:
            ValueError: if ``max_update`` is not positive, if ``phase_ratio``
                is missing, does not have 3 entries, or does not sum to 1.
                ``math.log`` also raises ValueError if ``final_lr_scale`` is
                not positive while the decay stage is non-empty.
        """
        if max_update <= 0:
            raise ValueError(f"max_update must be positive, got {max_update}")
        if self.phase_ratio is None or len(self.phase_ratio) != 3:
            raise ValueError(
                "phase_ratio must contain 3 ratios (warmup, hold, decay), "
                f"got {self.phase_ratio}"
            )
        # Compare with a tolerance: exact float equality rejects valid
        # configs, e.g. sum([0.7, 0.2, 0.1]) == 0.9999999999999999.
        if not math.isclose(sum(self.phase_ratio), 1.0, abs_tol=1e-6):
            raise ValueError("phase ratios must add up to 1")

        self.max_update = max_update
        self.peak_lr = self.optimizer_lr
        self.init_lr = self.init_lr_scale * self.optimizer_lr
        self.final_lr = self.final_lr_scale * self.optimizer_lr

        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
        self.hold_steps = int(self.max_update * self.phase_ratio[1])
        self.decay_steps = int(self.max_update * self.phase_ratio[2])

        # A zero-length stage has no slope; avoid dividing by zero.
        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps
            if self.warmup_steps != 0
            else 0
        )
        if self.decay_steps != 0:
            # Chosen so that exp(-decay_factor * decay_steps) == final_lr_scale,
            # i.e. the decay stage ends exactly at final_lr.
            self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps
        else:
            self.decay_factor = 0.0

        # Start from the warmup's initial lr. All fields above must be set
        # before _LRScheduler.__init__ runs, because it performs an initial
        # step() that calls get_lr() -> step_update(), which reads them.
        self.lr = self.init_lr
        self.set_optimizer_lr(self.lr)
        super().__init__(self.optimizer, self.last_epoch)

    # Correctly spelled alias; the original (misspelled) name is kept so
    # existing callers continue to work.
    init_tri_stage_scheduler = init_tri_stage_scheudler

    def _decide_stage(self, update_step):
        """Return ``(stage, steps_within_stage)`` for ``update_step``.

        Stages: 0 = warmup, 1 = hold, 2 = decay, 3 = constant final lr.
        """
        if update_step < self.warmup_steps:
            # warmup stage
            return 0, update_step
        offset = self.warmup_steps
        if update_step < offset + self.hold_steps:
            # hold stage
            return 1, update_step - offset
        offset += self.hold_steps
        if update_step <= offset + self.decay_steps:
            # decay stage (inclusive end, so its last step yields final_lr)
            return 2, update_step - offset
        offset += self.decay_steps
        # past all three stages: constant final lr
        return 3, update_step - offset

    def step_update(self, num_updates):
        """Set ``self.lr`` and every param group's lr for ``num_updates``."""
        stage, steps_in_stage = self._decide_stage(num_updates)
        if stage == 0:
            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
        elif stage == 1:
            self.lr = self.peak_lr
        elif stage == 2:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 3:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")
        self.set_optimizer_lr(self.lr)

    def set_optimizer_lr(self, lr):
        """Write ``lr`` into every param group of the wrapped optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Advance the schedule and return the lr for each param group."""
        # NOTE(review): the update index passed on is last_epoch + 1; kept
        # as-is so existing training schedules are unchanged.
        step_num = self.last_epoch + 1
        self.step_update(step_num)
        # _LRScheduler.step() zips this list with optimizer.param_groups, so
        # return one entry per group instead of a single-element list.
        return [self.lr for _ in self.optimizer.param_groups]
funasr/tasks/abs_task.py
@@ -50,6 +50,7 @@
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.warmup_lr import WarmupLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -151,6 +152,7 @@
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,