From 35fc1108343e28d6f504cf5923658b0fb79ab443 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 06 二月 2023 16:22:24 +0800
Subject: [PATCH] update data2vec pretrain: tri_stage

---
 funasr/schedulers/tri_stage_scheduler.py |  108 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 funasr/tasks/abs_task.py                 |    2 +
 2 files changed, 110 insertions(+), 0 deletions(-)

diff --git a/funasr/schedulers/tri_stage_scheduler.py b/funasr/schedulers/tri_stage_scheduler.py
new file mode 100644
index 0000000..8dc71b4
--- /dev/null
+++ b/funasr/schedulers/tri_stage_scheduler.py
@@ -0,0 +1,108 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Optional, List
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+from typeguard import check_argument_types
+
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+
+
+class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
+    def __init__(
+            self,
+            optimizer: torch.optim.Optimizer,
+            last_epoch: int = -1,
+            phase_ratio: Optional[List[float]] = None,
+            init_lr_scale: float = 0.01,
+            final_lr_scale: float = 0.01,
+    ):
+        assert check_argument_types()
+        self.optimizer = optimizer
+        self.last_epoch = last_epoch
+        self.phase_ratio = phase_ratio
+        self.init_lr_scale = init_lr_scale
+        self.final_lr_scale = final_lr_scale
+        self.optimizer_lr = self.optimizer.defaults["lr"]
+
+    def init_tri_stage_scheudler(self, max_update):
+        self.max_update = max_update
+        self.peak_lr = self.optimizer_lr
+        self.init_lr = self.init_lr_scale * self.optimizer_lr
+        self.final_lr = self.final_lr_scale * self.optimizer_lr
+
+        assert self.max_update > 0
+        assert sum(self.phase_ratio) == 1, "phase ratios must add up to 1"
+        assert len(self.phase_ratio) == 3
+        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
+        self.hold_steps = int(self.max_update * self.phase_ratio[1])
+        self.decay_steps = int(self.max_update * self.phase_ratio[2])
+
+        self.warmup_rate = (
+            (self.peak_lr - self.init_lr) / self.warmup_steps
+            if self.warmup_steps != 0
+            else 0
+        )
+        self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps
+
+        # initial learning rate
+        self.lr = self.init_lr
+
+        # __init__() must be invoked before setting field
+        # because step() is also invoked in __init__()
+        self.set_optimizer_lr(self.lr)
+        super().__init__(self.optimizer, self.last_epoch)
+
+    def _decide_stage(self, update_step):
+        """
+        return stage, and the corresponding steps within the current stage
+        """
+        if update_step < self.warmup_steps:
+            # warmup state
+            return 0, update_step
+
+        offset = self.warmup_steps
+
+        if update_step < offset + self.hold_steps:
+            # hold stage
+            return 1, update_step - offset
+
+        offset += self.hold_steps
+
+        if update_step <= offset + self.decay_steps:
+            # decay stage
+            return 2, update_step - offset
+
+        offset += self.decay_steps
+
+        # still here ? constant lr stage
+        return 3, update_step - offset
+
+    def step_update(self, num_updates):
+        """Update the learning rate after each update."""
+        stage, steps_in_stage = self._decide_stage(num_updates)
+        if stage == 0:
+            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
+        elif stage == 1:
+            self.lr = self.peak_lr
+        elif stage == 2:
+            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
+        elif stage == 3:
+            self.lr = self.final_lr
+        else:
+            raise ValueError("Undefined stage")
+        self.set_optimizer_lr(self.lr)
+
+    def set_optimizer_lr(self, lr):
+        for param_group in self.optimizer.param_groups:
+            param_group["lr"] = lr
+
+    def get_lr(self):
+        step_num = self.last_epoch + 1
+        self.step_update(step_num)
+        return [self.lr]
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 83926f4..1c8e640 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -50,6 +50,7 @@
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
 from funasr.schedulers.warmup_lr import WarmupLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -151,6 +152,7 @@
     CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
     noamlr=NoamLR,
     warmuplr=WarmupLR,
+    tri_stage=TriStageLR,
     cycliclr=torch.optim.lr_scheduler.CyclicLR,
     onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
     CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,

--
Gitblit v1.9.1