zhifu gao
2024-03-11 a7d7a0f3a2e7cd44a337ced34e3536b12ccb534e
funasr/models/fsmn_vad_streaming/model.py
@@ -15,7 +15,7 @@
from typing import List, Tuple, Dict, Any, Optional
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
class VadStateMachine(Enum):
@@ -23,10 +23,12 @@
   kVadInStateInSpeechSegment = 2
   kVadInStateEndPointDetected = 3
class FrameState(Enum):
   kFrameStateInvalid = -1
   kFrameStateSpeech = 1
   kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
@@ -37,9 +39,11 @@
   kChangeStateNoBegin = 4
   kChangeStateInvalid = 5
class VadDetectMode(Enum):
   kVadSingleUtteranceDetectMode = 0
   kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
   """
@@ -47,6 +51,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(
      self,
      sample_rate: int = 16000,
@@ -117,6 +122,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self):
      self.start_ms = 0
      self.end_ms = 0
@@ -140,6 +146,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self):
      self.noise_prob = 0.0
      self.speech_prob = 0.0
@@ -154,6 +161,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self, window_size_ms: int,
                sil_to_speech_time: int,
                speech_to_sil_time: int,
@@ -190,7 +198,7 @@
   def GetWinSize(self) -> int:
      return int(self.win_size_frame)
   
   def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
   def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
      cur_frame_state = FrameState.kFrameStateSil
      if frameState == FrameState.kFrameStateSpeech:
         cur_frame_state = 1
@@ -220,13 +228,13 @@
   def FrameSizeMs(self) -> int:
      return int(self.frame_size_ms)
class Stats(object):
   def __init__(self,
                sil_pdf_ids,
                max_end_sil_frame_cnt_thresh,
                speech_noise_thres,
                ):
      self.data_buf_start_frame = 0
      self.frm_cnt = 0
      self.latest_confirmed_speech_frame = 0
@@ -263,6 +271,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self,
                encoder: str = None,
                encoder_conf: Optional[Dict] = None,
@@ -275,7 +284,7 @@
      encoder_class = tables.encoder_classes.get(encoder)
      encoder = encoder_class(**encoder_conf)
      self.encoder = encoder
      self.encoder_conf = encoder_conf
   
   def ResetDetection(self, cache: dict = {}):
      cache["stats"].continous_silence_frame_count = 0
@@ -293,7 +302,8 @@
         drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
         real_drop_frames = drop_frames - cache["stats"].last_drop_frames
         cache["stats"].last_drop_frames = drop_frames
         cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
         cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
            self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
         cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
         cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
   
@@ -301,7 +311,8 @@
      frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
      frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
      if cache["stats"].data_buf_all is None:
         cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
         cache["stats"].data_buf_all = cache["stats"].waveform[
            0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
         cache["stats"].data_buf = cache["stats"].data_buf_all
      else:
         cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
@@ -320,15 +331,16 @@
      else:
         cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
   
   def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
   def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
      while cache["stats"].data_buf_start_frame < frame_idx:
         if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
            cache["stats"].data_buf_start_frame += 1
            cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
               self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
            cache["stats"].data_buf = cache["stats"].data_buf_all[
                                      (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
                                         self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
   
   def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                          last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
                          last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
      self.PopDataBufTillFrame(start_frm, cache=cache)
      expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
      if last_frm_is_end_point:
@@ -380,14 +392,15 @@
      cache["stats"].lastest_confirmed_silence_frame = valid_frame
      if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
         self.PopDataBufTillFrame(valid_frame, cache=cache)
      # silence_detected_callback_
      # pass
   
   def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
   # silence_detected_callback_
   # pass
   def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
      cache["stats"].latest_confirmed_speech_frame = valid_frame
      self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
   
   def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
   def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
      if self.vad_opts.do_start_point_detection:
         pass
      if cache["stats"].confirmed_start_frame != -1:
@@ -398,7 +411,7 @@
      if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
         self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
   
   def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
   def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
      for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
         self.OnVoiceDetected(t, cache=cache)
      if self.vad_opts.do_end_point_detection:
@@ -470,13 +483,17 @@
      
      return frame_state
   
   def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
               is_final: bool = False
   def forward(self, feats: torch.Tensor,
               waveform: torch.tensor,
               cache: dict = {},
               is_final: bool = False,
               **kwargs,
               ):
      # if len(cache) == 0:
      #     self.AllResetDetection()
      # self.waveform = waveform  # compute decibel for each frame
      cache["stats"].waveform = waveform
      is_streaming_input = kwargs.get("is_streaming_input", True)
      self.ComputeDecibel(cache=cache)
      self.ComputeScores(feats, cache=cache)
      if not is_final:
@@ -488,12 +505,32 @@
         segment_batch = []
         if len(cache["stats"].output_data_buf) > 0:
            for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
               if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
                  i].contain_seg_end_point):
                  continue
               segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
               if is_streaming_input: # in this case, return [beg, -1], [], [-1, end], [beg, end]
                  if not cache["stats"].output_data_buf[i].contain_seg_start_point:
                     continue
                  if not cache["stats"].next_seg and not cache["stats"].output_data_buf[i].contain_seg_end_point:
                     continue
                  start_ms = cache["stats"].output_data_buf[i].start_ms if cache["stats"].next_seg else -1
                  if cache["stats"].output_data_buf[i].contain_seg_end_point:
                     end_ms = cache["stats"].output_data_buf[i].end_ms
                     cache["stats"].next_seg = True
                     cache["stats"].output_data_buf_offset += 1
                  else:
                     end_ms = -1
                     cache["stats"].next_seg = False
                  segment = [start_ms, end_ms]
               else: # in this case, return [beg, end]
                  if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
                  cache["stats"].output_data_buf[
                     i].contain_seg_end_point):
                     continue
                  segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
                  cache["stats"].output_data_buf_offset += 1  # need update this parameter
               segment_batch.append(segment)
               cache["stats"].output_data_buf_offset += 1  # need update this parameter
         if segment_batch:
            segments.append(segment_batch)
      # if is_final:
@@ -502,6 +539,7 @@
      return segments
   
   def init_cache(self, cache: dict = {}, **kwargs):
      cache["frontend"] = {}
      cache["prev_samples"] = torch.empty(0)
      cache["encoder"] = {}
@@ -533,11 +571,13 @@
         self.init_cache(cache, **kwargs)
      
      meta_data = {}
      chunk_size = kwargs.get("chunk_size", 60000) # 50ms
      chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
      chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
      
      time1 = time.perf_counter()
      cfg = {"is_final": kwargs.get("is_final", False)}
      is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
      is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
      cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
      audio_sample_list = load_audio_text_image_video(data_in,
                                                      fs=frontend.fs,
                                                      audio_fs=kwargs.get("fs", 16000),
@@ -546,7 +586,7 @@
                                                      cache=cfg,
                                                      )
      _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
      is_streaming_input = cfg["is_streaming_input"]
      time2 = time.perf_counter()
      meta_data["load_data"] = f"{time2 - time1:0.3f}"
      assert len(audio_sample_list) == 1, "batch_size must be set 1"
@@ -574,21 +614,22 @@
            "feats": speech,
            "waveform": cache["frontend"]["waveforms"],
            "is_final": kwargs["is_final"],
            "cache": cache
            "cache": cache,
            "is_streaming_input": is_streaming_input
         }
         segments_i = self.forward(**batch)
         if len(segments_i) > 0:
            segments.extend(*segments_i)
      
      cache["prev_samples"] = audio_sample[:-m]
      if _is_final:
         cache = {}
         self.init_cache(cache)
      
      ibest_writer = None
      if ibest_writer is None and kwargs.get("output_dir") is not None:
         writer = DatadirWriter(kwargs.get("output_dir"))
         ibest_writer = writer[f"{1}best_recog"]
      if kwargs.get("output_dir") is not None:
         if not hasattr(self, "writer"):
            self.writer = DatadirWriter(kwargs.get("output_dir"))
         ibest_writer = self.writer[f"{1}best_recog"]
      
      results = []
      result_i = {"key": key[0], "value": segments}
@@ -600,16 +641,59 @@
      if ibest_writer is not None:
         ibest_writer["text"][key[0]] = segments
      
      return results, meta_data
   
   def export(self, **kwargs):
      is_onnx = kwargs.get("type", "onnx") == "onnx"
      encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
      self.encoder = encoder_class(self.encoder, onnx=is_onnx)
      self.forward = self._export_forward
      return self
   def export_forward(self, feats: torch.Tensor, *args, **kwargs):
      scores, out_caches = self.encoder(feats, *args)
      return scores, out_caches
   def export_dummy_inputs(self, data_in=None, frame=30):
      if data_in is None:
         speech = torch.randn(1, frame, self.encoder_conf.get("input_dim"))
      else:
         speech = None # Undo
      cache_frames = self.encoder_conf.get("lorder") + self.encoder_conf.get("rorder") - 1
      in_cache0 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
      in_cache1 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
      in_cache2 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
      in_cache3 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
      return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
   def export_input_names(self):
      return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
   def export_output_names(self):
      return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
   def export_dynamic_axes(self):
      return {
         'speech': {
            1: 'feats_length'
         },
      }
   def export_name(self, ):
      return "model.onnx"
   
   def DetectCommonFrames(self, cache: dict = {}) -> int:
      if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
         return 0
      for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
         frame_state = FrameState.kFrameStateInvalid
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                          cache=cache)
         self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
      
      return 0
@@ -619,7 +703,8 @@
         return 0
      for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
         frame_state = FrameState.kFrameStateInvalid
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                          cache=cache)
         if i != 0:
            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
         else:
@@ -627,7 +712,8 @@
      
      return 0
   
   def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
   def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
                      cache: dict = {}) -> None:
      tmp_cur_frm_state = FrameState.kFrameStateInvalid
      if cur_frm_state == FrameState.kFrameStateSpeech:
         if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +730,8 @@
         cache["stats"].pre_end_silence_detected = False
         start_frame = 0
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
            start_frame = max(cache["stats"].data_buf_start_frame,
                              cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
            self.OnVoiceStart(start_frame, cache=cache)
            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
            for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +783,8 @@
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            # silence timeout, return zero length decision
            if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
               cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
               cache[
                  "stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
               or (is_final_frame and cache["stats"].number_end_time_detected == 0):
               for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                  self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +795,8 @@
               if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                  self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
         elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
               "stats"].max_end_sil_frame_cnt_thresh:
               lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
               if self.vad_opts.do_extend:
                  lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -731,6 +820,5 @@
      if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
         self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
         self.ResetDetection(cache=cache)