From 4e2fe544ae37174a3e09dfcdbbdae5abfe711e53 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 05 七月 2023 16:57:21 +0800
Subject: [PATCH] funasr sdk

---
 funasr/bin/asr_infer.py |  126 ++++++++++++++++++++++--------------------
 1 files changed, 66 insertions(+), 60 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 140b424..02ca63d 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -22,9 +22,7 @@
 import requests
 import torch
 from packaging.version import parse as V
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from  funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
 from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -78,7 +76,6 @@
             frontend_conf: dict = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -192,7 +189,6 @@
             text, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -248,7 +244,6 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-        assert check_return_type(results)
         return results
 
 
@@ -285,10 +280,10 @@
             nbest: int = 1,
             frontend_conf: dict = None,
             hotword_list_or_file: str = None,
+            clas_scale: float = 1.0,
             decoding_ind: int = 0,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -377,10 +372,12 @@
         self.asr_train_args = asr_train_args
         self.converter = converter
         self.tokenizer = tokenizer
+        self.cmvn_file = cmvn_file
 
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
         self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+        self.clas_scale = clas_scale
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -412,7 +409,6 @@
                 text, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -445,16 +441,20 @@
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
-                                                                                   NeatContextualParaformer):
+        if not isinstance(self.asr_model, ContextualParaformer) and \
+            not isinstance(self.asr_model, NeatContextualParaformer):
             if self.hotword_list:
                 logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                      pre_token_length)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
-                                                                     pre_token_length, hw_list=self.hotword_list)
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, 
+                                                                     enc_len, 
+                                                                     pre_acoustic_embeds,
+                                                                     pre_token_length, 
+                                                                     hw_list=self.hotword_list,
+                                                                     clas_scale=self.clas_scale)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
@@ -515,10 +515,47 @@
                                                                vad_offset=begin_time)
                 results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
 
-        # assert check_return_type(results)
         return results
 
     def generate_hotwords_list(self, hotword_list_or_file):
+        def load_seg_dict(seg_dict_file):
+            seg_dict = {}
+            assert isinstance(seg_dict_file, str)
+            with open(seg_dict_file, "r", encoding="utf8") as f:
+                lines = f.readlines()
+                for line in lines:
+                    s = line.strip().split()
+                    key = s[0]
+                    value = s[1:]
+                    seg_dict[key] = " ".join(value)
+            return seg_dict
+
+        def seg_tokenize(txt, seg_dict):
+            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+            out_txt = ""
+            for word in txt:
+                word = word.lower()
+                if word in seg_dict:
+                    out_txt += seg_dict[word] + " "
+                else:
+                    if pattern.match(word):
+                        for char in word:
+                            if char in seg_dict:
+                                out_txt += seg_dict[char] + " "
+                            else:
+                                out_txt += "<unk>" + " "
+                    else:
+                        out_txt += "<unk>" + " "
+            return out_txt.strip().split()
+
+        seg_dict = None
+        if self.cmvn_file is not None:
+            model_dir = os.path.dirname(self.cmvn_file)
+            seg_dict_file = os.path.join(model_dir, 'seg_dict')
+            if os.path.exists(seg_dict_file):
+                seg_dict = load_seg_dict(seg_dict_file)
+            else:
+                seg_dict = None
         # for None
         if hotword_list_or_file is None:
             hotword_list = None
@@ -530,8 +567,11 @@
             with codecs.open(hotword_list_or_file, 'r') as fin:
                 for line in fin.readlines():
                     hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
                     hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                    hotword_list.append(self.converter.tokens2ids(hw_list))
                 hotword_list.append([self.asr_model.sos])
                 hotword_str_list.append('<s>')
             logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -551,8 +591,11 @@
             with codecs.open(hotword_list_or_file, 'r') as fin:
                 for line in fin.readlines():
                     hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
                     hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                    hotword_list.append(self.converter.tokens2ids(hw_list))
                 hotword_list.append([self.asr_model.sos])
                 hotword_str_list.append('<s>')
             logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -564,7 +607,10 @@
             hotword_str_list = []
             for hw in hotword_list_or_file.strip().split():
                 hotword_str_list.append(hw)
-                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hw_list = hw.strip().split()
+                if seg_dict is not None:
+                    hw_list = seg_tokenize(hw_list, seg_dict)
+                hotword_list.append(self.converter.tokens2ids(hw_list))
             hotword_list.append([self.asr_model.sos])
             hotword_str_list.append('<s>')
             logging.info("Hotword list: {}.".format(hotword_str_list))
@@ -608,7 +654,6 @@
             hotword_list_or_file: str = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -728,7 +773,6 @@
                 text, token, token_int, hyp
 
         """
-        assert check_argument_types()
         results = []
         cache_en = cache["encoder"]
         if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -823,7 +867,6 @@
 
                 results.append(postprocessed_result)
 
-        # assert check_return_type(results)
         return results
 
 
@@ -864,7 +907,6 @@
             frontend_conf: dict = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -988,7 +1030,6 @@
             text, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -1056,7 +1097,6 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-        assert check_return_type(results)
         return results
 
 
@@ -1095,7 +1135,6 @@
             streaming: bool = False,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -1200,7 +1239,6 @@
             text, token, token_int, hyp
 
         """
-        assert check_argument_types()
         # Input as audio signal
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
@@ -1250,7 +1288,6 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-        assert check_return_type(results)
         return results
 
 
@@ -1307,7 +1344,6 @@
         """Construct a Speech2Text object."""
         super().__init__()
 
-        assert check_argument_types()
         asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
@@ -1486,7 +1522,6 @@
         Returns:
             nbest_hypothesis: N-best hypothesis.
         """
-        assert check_argument_types()
 
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
@@ -1518,7 +1553,6 @@
         Returns:
             nbest_hypothesis: N-best hypothesis.
         """
-        assert check_argument_types()
 
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
@@ -1560,35 +1594,8 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-            assert check_return_type(results)
 
         return results
-
-    @staticmethod
-    def from_pretrained(
-            model_tag: Optional[str] = None,
-            **kwargs: Optional[Any],
-    ) -> Speech2Text:
-        """Build Speech2Text instance from the pretrained model.
-        Args:
-            model_tag: Model tag of the pretrained models.
-        Return:
-            : Speech2Text instance.
-        """
-        if model_tag is not None:
-            try:
-                from espnet_model_zoo.downloader import ModelDownloader
-
-            except ImportError:
-                logging.error(
-                    "`espnet_model_zoo` is not installed. "
-                    "Please install via `pip install -U espnet_model_zoo`."
-                )
-                raise
-            d = ModelDownloader()
-            kwargs.update(**d.download_and_unpack(model_tag))
-
-        return Speech2TextTransducer(**kwargs)
 
 
 class Speech2TextSAASR:
@@ -1627,7 +1634,6 @@
             frontend_conf: dict = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -1636,8 +1642,10 @@
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            if asr_train_args.frontend == 'wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            from funasr.tasks.sa_asr import frontend_choices
+            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
             else:
                 frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -1743,7 +1751,6 @@
             text, text_id, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -1836,5 +1843,4 @@
 
             results.append((text, text_id, token, token_int, hyp))
 
-        assert check_return_type(results)
         return results

--
Gitblit v1.9.1