From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky

---
 funasr/bin/asr_inference_paraformer_vad_punc.py |  226 +++++++++++++++++++-------------------------------------
 1 files changed, 77 insertions(+), 149 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index a7ccfc2..9dc0b79 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -43,11 +43,11 @@
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
+from funasr.bin.vad_inference import Speech2VadSegment
+from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
 from funasr.bin.punctuation_infer import Text2Punc
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
 
-from FunASR.funasr.utils.timestamp_tools import time_stamp_sentence
 
 header_colors = '\033[95m'
 end_colors = '\033[0m'
@@ -58,7 +58,7 @@
 
     Examples:
             >>> import soundfile
-            >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+            >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
             >>> audio, rate = soundfile.read("speech.wav")
             >>> speech2text(audio)
             [(text, token, token_int, hypothesis object), ...]
@@ -178,55 +178,8 @@
         self.tokenizer = tokenizer
 
         # 6. [Optional] Build hotword list from str, local file or url
-        # for None
-        if hotword_list_or_file is None:
-            self.hotword_list = None
-        # for text str input
-        elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
-            logging.info("Attempting to parse hotwords as str...")
-            self.hotword_list = []
-            hotword_str_list = []
-            for hw in hotword_list_or_file.strip().split():
-                hotword_str_list.append(hw)
-                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Hotword list: {}.".format(hotword_str_list))
-        # for local txt inputs
-        elif os.path.exists(hotword_list_or_file):
-            logging.info("Attempting to parse hotwords from local txt...")
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-                self.hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                .format(hotword_list_or_file, hotword_str_list))
-        # for url, download and generate txt
-        else:
-            logging.info("Attempting to parse hotwords from url...")
-            work_dir = tempfile.TemporaryDirectory().name
-            if not os.path.exists(work_dir):
-                os.makedirs(work_dir)
-            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
-            local_file = requests.get(hotword_list_or_file)
-            open(text_file_path, "wb").write(local_file.content)
-            hotword_list_or_file = text_file_path
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-                self.hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                .format(hotword_list_or_file, hotword_str_list))
+        self.hotword_list = None
+        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -303,7 +256,7 @@
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
-            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                    pre_token_length)  # test no bias cif2
 
         results = []
@@ -339,6 +292,8 @@
 
                 # remove blank symbol id, which is assumed to be 0
                 token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+                if len(token_int) == 0:
+                    continue
 
                 # Change integer-ids to tokens
                 token = self.converter.ids2tokens(token_int)
@@ -349,7 +304,10 @@
                     text = None
 
                 if isinstance(self.asr_model, BiCifParaformer):
-                    timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], 
+                                                            us_peaks[i], 
+                                                            copy.copy(token), 
+                                                            vad_offset=begin_time)
                     results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
                 else:
                     results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
@@ -357,101 +315,59 @@
         # assert check_return_type(results)
         return results
 
-
-class Speech2VadSegment:
-    """Speech2VadSegment class
-
-    Examples:
-        >>> import soundfile
-        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
-        >>> speech2segment(audio)
-        [[10, 230], [245, 450], ...]
-
-    """
-
-    def __init__(
-            self,
-            vad_infer_config: Union[Path, str] = None,
-            vad_model_file: Union[Path, str] = None,
-            vad_cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            **kwargs,
-    ):
-        assert check_argument_types()
-
-        # 1. Build vad model
-        vad_model, vad_infer_args = VADTask.build_model_from_file(
-            vad_infer_config, vad_model_file, device
-        )
-        frontend = None
-        if vad_infer_args.frontend is not None:
-            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-
-        # logging.info("vad_model: {}".format(vad_model))
-        # logging.info("vad_infer_args: {}".format(vad_infer_args))
-        vad_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.vad_model = vad_model
-        self.vad_infer_args = vad_infer_args
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.batch_size = batch_size
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-        assert check_argument_types()
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            self.frontend.filter_length_max = math.inf
-            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
-            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
-            fbanks = to_device(fbanks, device=self.device)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
+    def generate_hotwords_list(self, hotword_list_or_file):
+        # for None
+        if hotword_list_or_file is None:
+            hotword_list = None
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords from local txt...")
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for text str input
+        elif not hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords as str...")
+            hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
         else:
-            raise Exception("Need to extract feats first, please configure frontend configuration")
-
-        # b. Forward Encoder streaming
-        t_offset = 0
-        step = min(feats_len, 6000)
-        segments = [[]] * self.batch_size
-        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
-            if t_offset + step >= feats_len - 1:
-                step = feats_len - t_offset
-                is_final_send = True
-            else:
-                is_final_send = False
-            batch = {
-                "feats": feats[:, t_offset:t_offset + step, :],
-                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
-            }
-            # a. To device
-            batch = to_device(batch, device=self.device)
-            segments_part = self.vad_model(**batch)
-            if segments_part:
-                for batch_num in range(0, self.batch_size):
-                    segments[batch_num] += segments_part[batch_num]
-
-        return fbanks, segments
+            hotword_list = None
+        return hotword_list
 
 
 def inference(
@@ -639,7 +555,19 @@
                  output_dir_v2: Optional[str] = None,
                  fs: dict = None,
                  param_dict: dict = None,
+                 **kwargs,
                  ):
+
+        hotword_list_or_file = None
+        if param_dict is not None:
+            hotword_list_or_file = param_dict.get('hotword')
+
+        if 'hotword' in kwargs:
+            hotword_list_or_file = kwargs['hotword']
+
+        if speech2text.hotword_list is None:
+            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
@@ -742,7 +670,7 @@
                     ibest_writer["token"][key] = " ".join(token)
                     ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                     ibest_writer["vad"][key] = "{}".format(vadsegments)
-                    ibest_writer["text"][key] = text_postprocessed
+                    ibest_writer["text"][key] = " ".join(word_lists)
                     ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                     if time_stamp_postprocessed is not None:
                         ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)

--
Gitblit v1.9.1