From b63e73ae4f5df9d4ed9fb0bee12ac2cc09d7f523 Mon Sep 17 00:00:00 2001
From: zhaomingwork <zhaomingwork@qq.com>
Date: 星期五, 19 五月 2023 14:30:13 +0800
Subject: [PATCH] add asr wss address input to html

---
 funasr/train/trainer.py |   18 ++++++++++++++++--
 1 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 7c187e9..4052448 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -39,7 +39,7 @@
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.recursive_op import recursive_average
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.reporter import Reporter
 from funasr.train.reporter import SubReporter
@@ -95,6 +95,7 @@
     use_pai: bool
     oss_bucket: Union[oss2.Bucket, None]
     batch_interval: int
+    bias_grad_times: float
 
 class Trainer:
     """Trainer having a optimizer.
@@ -165,7 +166,7 @@
     @classmethod
     def run(
         cls,
-        model: AbsESPnetModel,
+        model: FunASRModel,
         optimizers: Sequence[torch.optim.Optimizer],
         schedulers: Sequence[Optional[AbsScheduler]],
         train_iter_factory: AbsIterFactory,
@@ -546,8 +547,11 @@
         no_forward_run = options.no_forward_run
         ngpu = options.ngpu
         use_wandb = options.use_wandb
+        bias_grad_times = options.bias_grad_times
         distributed = distributed_option.distributed
 
+        if bias_grad_times != 1.0:
+            logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
         if log_interval is None:
             try:
                 log_interval = max(len(iterator) // 20, 10)
@@ -690,6 +694,16 @@
                         scale_factor=0.55,
                     )
 
+                # for contextual training
+                if bias_grad_times != 1.0:
+                    # contextual related parameter names
+                    cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
+                    for name, param in model.named_parameters():
+                        for cr_pname in cr_pnames:
+                            if cr_pname in name:
+                                param.grad *= bias_grad_times
+                                continue
+
                 # compute the gradient norm to check if it is normal or not
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(),

--
Gitblit v1.9.1