From 54931dd4e1a099d7d6f144c4e12e5453deb3aa26 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 28 六月 2023 10:41:57 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 funasr/bin/asr_infer.py |  119 ++++++++++++++++++++++++++++++++++++++++-------------------
 1 files changed, 80 insertions(+), 39 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 288034c..a537a73 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -24,7 +24,7 @@
 from packaging.version import parse as V
 from typeguard import check_argument_types
 from typeguard import check_return_type
-
+from  funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
 from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -35,9 +35,7 @@
 from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
 from funasr.modules.scorers.ctc import CTCPrefixScorer
 from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.asr import frontend_choices
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_asr_model import frontend_choices
 from funasr.text.build_tokenizer import build_tokenizer
 from funasr.text.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
@@ -84,7 +82,7 @@
 
         # 1. Build ASR model
         scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
+        asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
@@ -92,7 +90,6 @@
             if asr_train_args.frontend == 'wav_frontend':
                 frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
-                from funasr.tasks.asr import frontend_choices
                 frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
 
@@ -112,7 +109,7 @@
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
+            lm, lm_train_args = build_model_from_file(
                 lm_train_config, lm_file, None, device
             )
             scorers["lm"] = lm.lm
@@ -295,9 +292,8 @@
 
         # 1. Build ASR model
         scorers = {}
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        asr_model, asr_train_args = build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -319,8 +315,8 @@
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+            lm, lm_train_args = build_model_from_file(
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             scorers["lm"] = lm.lm
 
@@ -381,6 +377,7 @@
         self.asr_train_args = asr_train_args
         self.converter = converter
         self.tokenizer = tokenizer
+        self.cmvn_file = cmvn_file
 
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
@@ -523,6 +520,44 @@
         return results
 
     def generate_hotwords_list(self, hotword_list_or_file):
+        def load_seg_dict(seg_dict_file):
+            seg_dict = {}
+            assert isinstance(seg_dict_file, str)
+            with open(seg_dict_file, "r", encoding="utf8") as f:
+                lines = f.readlines()
+                for line in lines:
+                    s = line.strip().split()
+                    key = s[0]
+                    value = s[1:]
+                    seg_dict[key] = " ".join(value)
+            return seg_dict
+
+        def seg_tokenize(txt, seg_dict):
+            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+            out_txt = ""
+            for word in txt:
+                word = word.lower()
+                if word in seg_dict:
+                    out_txt += seg_dict[word] + " "
+                else:
+                    if pattern.match(word):
+                        for char in word:
+                            if char in seg_dict:
+                                out_txt += seg_dict[char] + " "
+                            else:
+                                out_txt += "<unk>" + " "
+                    else:
+                        out_txt += "<unk>" + " "
+            return out_txt.strip().split()
+
+        seg_dict = None
+        if self.cmvn_file is not None:
+            model_dir = os.path.dirname(self.cmvn_file)
+            seg_dict_file = os.path.join(model_dir, 'seg_dict')
+            if os.path.exists(seg_dict_file):
+                seg_dict = load_seg_dict(seg_dict_file)
+            else:
+                seg_dict = None
         # for None
         if hotword_list_or_file is None:
             hotword_list = None
@@ -534,8 +569,11 @@
             with codecs.open(hotword_list_or_file, 'r') as fin:
                 for line in fin.readlines():
                     hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
                     hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                    hotword_list.append(self.converter.tokens2ids(hw_list))
                 hotword_list.append([self.asr_model.sos])
                 hotword_str_list.append('<s>')
             logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -555,8 +593,11 @@
             with codecs.open(hotword_list_or_file, 'r') as fin:
                 for line in fin.readlines():
                     hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
                     hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                    hotword_list.append(self.converter.tokens2ids(hw_list))
                 hotword_list.append([self.asr_model.sos])
                 hotword_str_list.append('<s>')
             logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -568,7 +609,10 @@
             hotword_str_list = []
             for hw in hotword_list_or_file.strip().split():
                 hotword_str_list.append(hw)
-                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hw_list = hw.strip().split()
+                if seg_dict is not None:
+                    hw_list = seg_tokenize(hw_list, seg_dict)
+                hotword_list.append(self.converter.tokens2ids(hw_list))
             hotword_list.append([self.asr_model.sos])
             hotword_str_list.append('<s>')
             logging.info("Hotword list: {}.".format(hotword_str_list))
@@ -616,9 +660,8 @@
 
         # 1. Build ASR model
         scorers = {}
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        asr_model, asr_train_args = build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -640,8 +683,8 @@
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+            lm, lm_train_args = build_model_from_file(
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             scorers["lm"] = lm.lm
 
@@ -873,9 +916,8 @@
 
         # 1. Build ASR model
         scorers = {}
-        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        asr_model, asr_train_args = build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -901,8 +943,8 @@
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+            lm, lm_train_args = build_model_from_file(
+                lm_train_config, lm_file, device, "lm"
             )
             scorers["lm"] = lm.lm
 
@@ -1104,9 +1146,8 @@
         assert check_argument_types()
 
         # 1. Build ASR model
-        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
         scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
+        asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
 
@@ -1126,8 +1167,8 @@
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+            lm, lm_train_args = build_model_from_file(
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             lm.to(device)
             scorers["lm"] = lm.lm
@@ -1315,8 +1356,7 @@
         super().__init__()
 
         assert check_argument_types()
-        from funasr.tasks.asr import ASRTransducerTask
-        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+        asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
 
@@ -1350,8 +1390,8 @@
             asr_model.to(dtype=getattr(torch, dtype)).eval()
 
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+            lm, lm_train_args = build_model_from_file(
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             lm_scorer = lm.lm
         else:
@@ -1638,15 +1678,16 @@
         assert check_argument_types()
 
         # 1. Build ASR model
-        from funasr.tasks.sa_asr import ASRTask
         scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
+        asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            if asr_train_args.frontend == 'wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            from funasr.tasks.sa_asr import frontend_choices
+            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
             else:
                 frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -1667,8 +1708,8 @@
 
         # 2. Build Language model
         if lm_train_config is not None:
-            lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, None, device
+            lm, lm_train_args = build_model_from_file(
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             scorers["lm"] = lm.lm
 

--
Gitblit v1.9.1