From 4daea3711063c64485be3c00eaa9727404549f51 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 24 二月 2023 17:55:00 +0800
Subject: [PATCH] onnx

---
 funasr/bin/asr_inference_paraformer_vad_punc.py |  108 ++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 files changed, 100 insertions(+), 8 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index ee36135..96f70ef 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -5,6 +5,10 @@
 import logging
 import sys
 import time
+import os
+import codecs
+import tempfile
+import requests
 from pathlib import Path
 from typing import Optional
 from typing import Sequence
@@ -39,9 +43,11 @@
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
+from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
 from funasr.bin.punctuation_infer import Text2Punc
-from funasr.models.e2e_asr_paraformer import BiCifParaformer
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+
+from funasr.utils.timestamp_tools import time_stamp_sentence
 
 header_colors = '\033[95m'
 end_colors = '\033[0m'
@@ -79,6 +85,7 @@
             penalty: float = 0.0,
             nbest: int = 1,
             frontend_conf: dict = None,
+            hotword_list_or_file: str = None,
             **kwargs,
     ):
         assert check_argument_types()
@@ -169,6 +176,11 @@
         self.asr_train_args = asr_train_args
         self.converter = converter
         self.tokenizer = tokenizer
+
+        # 6. [Optional] Build hotword list from str, local file or url
+        self.hotword_list = None
+        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
             beam_search = None
@@ -233,8 +245,15 @@
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        if not isinstance(self.asr_model, ContextualParaformer):
+            if self.hotword_list:
+                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+        else:
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
             _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
@@ -286,13 +305,64 @@
                     timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                     results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
                 else:
-                    time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time,
-                                                 end_time)
-                    results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+                    results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
 
         # assert check_return_type(results)
         return results
 
+    def generate_hotwords_list(self, hotword_list_or_file):
+        # for None
+        if hotword_list_or_file is None:
+            hotword_list = None
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords from local txt...")
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for text str input
+        elif not hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords as str...")
+            hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        else:
+            hotword_list = None
+        return hotword_list
 
 class Speech2VadSegment:
     """Speech2VadSegment class
@@ -515,6 +585,11 @@
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
 
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
+
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
     else:
@@ -553,6 +628,7 @@
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
     )
     speech2text = Speech2Text(**speech2text_kwargs)
     text2punc = None
@@ -569,7 +645,19 @@
                  output_dir_v2: Optional[str] = None,
                  fs: dict = None,
                  param_dict: dict = None,
+                 **kwargs,
                  ):
+
+        hotword_list_or_file = None
+        if param_dict is not None:
+            hotword_list_or_file = param_dict.get('hotword')
+
+        if 'hotword' in kwargs:
+            hotword_list_or_file = kwargs['hotword']
+
+        if speech2text.hotword_list is None:
+            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
@@ -636,7 +724,8 @@
                 text, token, token_int = result[0], result[1], result[2]
                 time_stamp = None if len(result) < 4 else result[3]
 
-                if use_timestamp and time_stamp is not None:
+
+                if use_timestamp and time_stamp is not None: 
                     postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
                 else:
                     postprocessed_result = postprocess_utils.sentence_postprocess(token)
@@ -651,6 +740,7 @@
                     text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
 
                 text_postprocessed_punc = text_postprocessed
+                punc_id_list = []
                 if len(word_lists) > 0 and text2punc is not None:
                     text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
 
@@ -660,6 +750,8 @@
                 if time_stamp_postprocessed != "":
                     item['time_stamp'] = time_stamp_postprocessed
 
+                item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
+
                 asr_result_list.append(item)
                 finish_count += 1
                 # asr_utils.print_progress(finish_count / file_count)

--
Gitblit v1.9.1