From ec3ccbea9ff1d869becaa2b13255d0da1e4bf3ca Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 02 三月 2023 20:23:39 +0800
Subject: [PATCH] torchscripts

---
 funasr/bin/asr_inference_paraformer.py |  120 ++++++++++++++++++++++++++++++++++--------------------------
 1 files changed, 68 insertions(+), 52 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 055a17f..b807a34 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -169,56 +169,8 @@
         self.tokenizer = tokenizer
 
         # 6. [Optional] Build hotword list from str, local file or url
-        # for None 
-        if hotword_list_or_file is None:
-            self.hotword_list = None
-        # for text str input
-        elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
-            logging.info("Attempting to parse hotwords as str...")
-            self.hotword_list = []
-            hotword_str_list = []
-            for hw in hotword_list_or_file.strip().split():
-                hotword_str_list.append(hw)
-                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Hotword list: {}.".format(hotword_str_list))
-        # for local txt inputs
-        elif os.path.exists(hotword_list_or_file):
-            logging.info("Attempting to parse hotwords from local txt...")
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-                self.hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                .format(hotword_list_or_file, hotword_str_list))
-        # for url, download and generate txt
-        else:
-            logging.info("Attempting to parse hotwords from url...")
-            work_dir = tempfile.TemporaryDirectory().name
-            if not os.path.exists(work_dir):
-                os.makedirs(work_dir)
-            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
-            local_file = requests.get(hotword_list_or_file)
-            open(text_file_path, "wb").write(local_file.content)
-            hotword_list_or_file = text_file_path
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-                self.hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                .format(hotword_list_or_file, hotword_str_list))
-
+        self.hotword_list = None
+        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -337,7 +289,61 @@
         # assert check_return_type(results)
         return results
 
-class Speech2TextExport(torch.nn.Module):
+    def generate_hotwords_list(self, hotword_list_or_file):
+        # for None
+        if hotword_list_or_file is None:
+            hotword_list = None
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords from local txt...")
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for text str input
+        elif not hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords as str...")
+            hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        else:
+            hotword_list = None
+        return hotword_list
+
+class Speech2TextExport:
     """Speech2TextExport class
 
     """
@@ -416,7 +422,7 @@
         self.asr_model = model
         
     @torch.no_grad()
-    def forward(
+    def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
     ):
         """Inference
@@ -648,7 +654,17 @@
             output_dir_v2: Optional[str] = None,
             fs: dict = None,
             param_dict: dict = None,
+            **kwargs,
     ):
+
+        hotword_list_or_file = None
+        if param_dict is not None:
+            hotword_list_or_file = param_dict.get('hotword')
+        if 'hotword' in kwargs:
+            hotword_list_or_file = kwargs['hotword']
+        if hotword_list_or_file is not None or 'hotword' in kwargs:
+            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):

--
Gitblit v1.9.1