From c20c871e9f963151fa410dd616c6b23d001ecdd2 Mon Sep 17 00:00:00 2001
From: Xian Shi <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 04 七月 2023 19:57:04 +0800
Subject: [PATCH] Merge pull request #673 from alibaba-damo-academy/dev_clas

---
 funasr/bin/asr_infer.py                                                                                     |   14 +++++--
 funasr/datasets/large_datasets/dataset.py                                                                   |   18 +++++----
 funasr/models/decoder/contextual_decoder.py                                                                 |    3 +
 funasr/models/e2e_asr_contextual_paraformer.py                                                              |    4 +-
 funasr/bin/build_trainer.py                                                                                 |    4 +
 funasr/datasets/large_datasets/utils/hotword_utils.py                                                       |    3 +
 funasr/datasets/large_datasets/utils/tokenize.py                                                            |   12 +++++
 egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/demo.py |    4 ++
 funasr/bin/asr_inference_launch.py                                                                          |    2 +
 9 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/demo.py b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/demo.py
index bec6f05..e5e9097 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/demo.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/demo.py
@@ -3,6 +3,10 @@
 
 param_dict = dict()
 param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
+param_dict['clas_scale'] = 1.00  # 1.50 # set it larger if you want high recall (sacrifice general accuracy)
+# 13% relative recall raise over internal hotword test set (45%->51%)
+# CER might raise when utterance contains no hotword
+
 inference_pipeline = pipeline(
     task=Tasks.auto_speech_recognition,
     model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 259a286..02ca63d 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -280,6 +280,7 @@
             nbest: int = 1,
             frontend_conf: dict = None,
             hotword_list_or_file: str = None,
+            clas_scale: float = 1.0,
             decoding_ind: int = 0,
             **kwargs,
     ):
@@ -376,6 +377,7 @@
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
         self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+        self.clas_scale = clas_scale
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -439,16 +441,20 @@
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
-                                                                                   NeatContextualParaformer):
+        if not isinstance(self.asr_model, ContextualParaformer) and \
+            not isinstance(self.asr_model, NeatContextualParaformer):
             if self.hotword_list:
                 logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                      pre_token_length)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
-                                                                     pre_token_length, hw_list=self.hotword_list)
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, 
+                                                                     enc_len, 
+                                                                     pre_acoustic_embeds,
+                                                                     pre_token_length, 
+                                                                     hw_list=self.hotword_list,
+                                                                     clas_scale=self.clas_scale)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 81513ae..a752f29 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -257,6 +257,7 @@
         export_mode = param_dict.get("export_mode", False)
     else:
         hotword_list_or_file = None
+    clas_scale = param_dict.get('clas_scale', 1.0)
 
     if kwargs.get("device", None) == "cpu":
         ngpu = 0
@@ -289,6 +290,7 @@
         penalty=penalty,
         nbest=nbest,
         hotword_list_or_file=hotword_list_or_file,
+        clas_scale=clas_scale,
     )
 
     speech2text = Speech2TextParaformer(**speech2text_kwargs)
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 267e405..891139a 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -85,7 +85,9 @@
         finetune_configs = yaml.safe_load(f)
         # set data_types
         if dataset_type == "large":
-            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
+            # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
+            if 'data_types' not in finetune_configs['dataset_conf']:
+                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
     finetune_configs = update_dct(configs, finetune_configs)
     for key, value in finetune_configs.items():
         if hasattr(args, key):
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 5f2c2c6..1e9bb26 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -202,14 +202,7 @@
     data_types = conf.get("data_types", "kaldi_ark,text")
 
     pre_hwfile = conf.get("pre_hwlist", None)
-    pre_prob = conf.get("pre_prob", 0)  # unused yet
-
-    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
-                 "double_rate": conf.get("double_rate", 0.1),
-                 "hotword_min_length": conf.get("hotword_min_length", 2),
-                 "hotword_max_length": conf.get("hotword_max_length", 8),
-                 "pre_prob": conf.get("pre_prob", 0.0)}
-
+    # pre_prob = conf.get("pre_prob", 0)  # unused yet
     if pre_hwfile is not None:
         pre_hwlist = []
         with open(pre_hwfile, 'r') as fin:
@@ -218,6 +211,15 @@
     else:
         pre_hwlist = None
 
+    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
+                 "double_rate": conf.get("double_rate", 0.1),
+                 "hotword_min_length": conf.get("hotword_min_length", 2),
+                 "hotword_max_length": conf.get("hotword_max_length", 8),
+                 "pre_prob": conf.get("pre_prob", 0.0),
+                 "pre_hwlist": pre_hwlist}
+
+    
+
     dataset = AudioDataset(scp_lists, 
                            data_names, 
                            data_types, 
diff --git a/funasr/datasets/large_datasets/utils/hotword_utils.py b/funasr/datasets/large_datasets/utils/hotword_utils.py
index fccfea6..73f8bdd 100644
--- a/funasr/datasets/large_datasets/utils/hotword_utils.py
+++ b/funasr/datasets/large_datasets/utils/hotword_utils.py
@@ -6,7 +6,8 @@
                    sample_rate,
                    double_rate,
                    pre_prob,
-                   pre_index=None):
+                   pre_index=None,
+                   pre_hwlist=None):
         if length < hotword_min_length:
             return [-1]
         if random.random() < sample_rate:
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index a7eb6d2..c16e1dc 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -54,7 +54,17 @@
 
     length = len(text)
     if 'hw_tag' in data:
-        hotword_indxs = sample_hotword(length, **hw_config)
+        if hw_config['pre_hwlist'] is not None and hw_config['pre_prob'] > 0:
+            # enable preset hotword detect in sampling
+            pre_index = None
+            for hw in hw_config['pre_hwlist']:
+                hw = " ".join(seg_tokenize(hw, seg_dict))
+                _find = " ".join(text).find(hw)
+                if _find != -1:
+                    # _find = text[:_find].count(" ")  # bpe sometimes
+                    pre_index = [_find, _find + max(hw.count(" "), 1)]
+                    break
+        hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
         data['hotword_indxs'] = hotword_indxs
         del data['hw_tag']
     for i in range(length):
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
index 0e69c44..c94e179 100644
--- a/funasr/models/decoder/contextual_decoder.py
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -244,6 +244,7 @@
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
         contextual_info: torch.Tensor,
+        clas_scale: float = 1.0,
         return_hidden: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
@@ -283,7 +284,7 @@
         cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
 
         if self.bias_output is not None:
-            x = torch.cat([x_src_attn, cx], dim=2)
+            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
             x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
             x = x_self_attn + self.dropout(x)
 
diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py
index 4836663..d27fd8d 100644
--- a/funasr/models/e2e_asr_contextual_paraformer.py
+++ b/funasr/models/e2e_asr_contextual_paraformer.py
@@ -341,7 +341,7 @@
             input_mask_expand_dim, 0)
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask
 
-    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
         if hw_list is None:
             hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
             hw_list_pad = pad_list(hw_list, 0)
@@ -363,7 +363,7 @@
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
         
         decoder_outs = self.decoder(
-            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)

--
Gitblit v1.9.1