游雁
2023-12-27 f6b611de44c3a535befa96da552d07b0ed1b073c
funasr/bin/inference.py
@@ -18,6 +18,7 @@
from funasr.register import tables
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.timestamp_tools import time_stamp_sentence
def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
   """
@@ -46,7 +47,7 @@
                  data = lines["source"]
                  key = data["key"] if "key" in data else key
               else: # filelist, wav.scp, text.txt: id \t data or data
                  lines = line.strip().split()
                  lines = line.strip().split(maxsplit=1)
                  data = lines[1] if len(lines)>1 else lines[0]
                  key = lines[0] if len(lines)>1 else key
               
@@ -227,6 +228,7 @@
      # step.1: compute the vad model
      model = self.vad_model
      kwargs = self.vad_kwargs
      kwargs.update(cfg)
      beg_vad = time.time()
      res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
      end_vad = time.time()
@@ -322,6 +324,23 @@
         result["key"] = key
         results_ret_list.append(result)
         pbar_total.update(1)
      # step.3 compute punc model
      model = self.punc_model
      kwargs = self.punc_kwargs
      kwargs.update(cfg)
      for i, result in enumerate(results_ret_list):
         beg_punc = time.time()
         res = self.generate(result["text"], model=model, kwargs=kwargs, **cfg)
         end_punc = time.time()
         print(f"time punc: {end_punc - beg_punc:0.3f}")
         # sentences = time_stamp_sentence(model.punc_list, model.sentence_end_id, results_ret_list[i]["timestamp"], res[i]["text"])
         # results_ret_list[i]["time_stamp"] = res[0]["text_postprocessed_punc"]
         # results_ret_list[i]["sentences"] = sentences
         # results_ret_list[i]["text_with_punc"] = res[i]["text"]
      pbar_total.update(1)
      end_total = time.time()
      time_escape_total_all_samples = end_total - beg_total