From 653fffdf29fc77ea9203d0cffdcc760f55a61dd5 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 05 五月 2023 16:14:20 +0800
Subject: [PATCH] update lr and bias_grad_times
---
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py | 2 +-
funasr/train/trainer.py | 15 +++++++++++++++
funasr/tasks/abs_task.py | 6 ++++++
3 files changed, 22 insertions(+), 1 deletions(-)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
index e4d6682..34c7cf9 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
@@ -31,6 +31,6 @@
params.dataset_type = "large" # finetune contextual paraformer妯″瀷鍙兘浣跨敤large dataset
params.batch_bins = 200000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
params.max_epoch = 20 # 鏈�澶ц缁冭疆鏁�
- params.lr = 0.00005 # 璁剧疆瀛︿範鐜�
+ params.lr = 0.0002 # 璁剧疆瀛︿範鐜�
modelscope_finetune(params)
\ No newline at end of file
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 3d2004c..31057f9 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -549,6 +549,12 @@
help="The number of gradient accumulation",
)
group.add_argument(
+ "--bias_grad_times",
+ type=float,
+ default=1.0,
+ help="To scale the gradient of contextual related params",
+ )
+ group.add_argument(
"--no_forward_run",
type=str2bool,
default=False,
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 7c187e9..405268a 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -3,6 +3,7 @@
"""Trainer module."""
import argparse
+from audioop import bias
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
@@ -95,6 +96,7 @@
use_pai: bool
oss_bucket: Union[oss2.Bucket, None]
batch_interval: int
+ bias_grad_times: float
class Trainer:
"""Trainer having a optimizer.
@@ -546,8 +548,11 @@
no_forward_run = options.no_forward_run
ngpu = options.ngpu
use_wandb = options.use_wandb
+ bias_grad_times = options.bias_grad_times
distributed = distributed_option.distributed
+ if bias_grad_times != 1.0:
+ logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
if log_interval is None:
try:
log_interval = max(len(iterator) // 20, 10)
@@ -690,6 +695,16 @@
scale_factor=0.55,
)
+ # for contextual training
+ if bias_grad_times != 1.0:
+ # contextual related parameter names
+ cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
+ for name, param in model.named_parameters():
+ for cr_pname in cr_pnames:
+ if cr_pname in name:
+ param.grad *= bias_grad_times
+ continue
+
# compute the gradient norm to check if it is normal or not
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
--
Gitblit v1.9.1