From 1d1ef01b4e23630a99a3be7e9d1dce9550a793e9 Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: Thu, 11 May 2023 16:26:24 +0800
Subject: [PATCH] Merge branch 'main' into dev_smohan

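Add a bias_grad_times training option that scales the gradients of the
contextual-biasing parameters (bias_encoder, bias_embed,
decoder.bias_decoder, decoder.bias_output) before gradient clipping, so
the contextual modules can train with a different effective learning
rate than the rest of the model. A value of 1.0 disables the scaling; a
warning is logged whenever scaling is active.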
---
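Note: a minimal standalone sketch of the gradient-scaling step this patch
adds, assuming a plain PyTorch model; the helper scale_bias_grads and the
toy module names below are illustrative, not FunASR APIs.

    import torch

    def scale_bias_grads(model, patterns, times):
        # Multiply the gradient of every parameter whose name contains one
        # of the given substrings; call after backward() and before
        # clip_grad_norm_() / optimizer.step().
        if times == 1.0:
            return
        for name, param in model.named_parameters():
            if param.grad is not None and any(p in name for p in patterns):
                param.grad *= times

    model = torch.nn.Sequential()
    model.add_module("bias_encoder", torch.nn.Linear(4, 4))
    model.add_module("encoder", torch.nn.Linear(4, 4))
    model(torch.randn(2, 4)).sum().backward()
    scale_bias_grads(model, ["bias_encoder"], times=2.0)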
 funasr/train/trainer.py |   14 ++++++++++++++
 1 file changed, 14 insertions(+), 0 deletions(-)

diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 7c187e9..a40f031 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -95,6 +95,7 @@
     use_pai: bool
     oss_bucket: Union[oss2.Bucket, None]
     batch_interval: int
+    bias_grad_times: float
 
 class Trainer:
     """Trainer having a optimizer.
@@ -546,8 +547,11 @@
         no_forward_run = options.no_forward_run
         ngpu = options.ngpu
         use_wandb = options.use_wandb
+        bias_grad_times = options.bias_grad_times
         distributed = distributed_option.distributed
 
+        if bias_grad_times != 1.0:
+            logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
         if log_interval is None:
             try:
                 log_interval = max(len(iterator) // 20, 10)
@@ -690,6 +694,16 @@
                         scale_factor=0.55,
                     )
 
+                # for contextual training: scale the gradients of the
+                # contextual-bias parameters before clipping
+                if bias_grad_times != 1.0:
+                    cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
+                    for name, param in model.named_parameters():
+                        if param.grad is None:
+                            continue
+                        if any(cr_pname in name for cr_pname in cr_pnames):
+                            param.grad *= bias_grad_times
+
                 # compute the gradient norm to check if it is normal or not
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(),

--
Gitblit v1.9.1