游雁
2024-01-23 c892cc34a9e181e9ea7b4e59c35651a61149401f
funasr/models/fsmn_vad_streaming/model.py
@@ -15,7 +15,7 @@
from typing import List, Tuple, Dict, Any, Optional
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
class VadStateMachine(Enum):
@@ -23,10 +23,12 @@
   kVadInStateInSpeechSegment = 2
   kVadInStateEndPointDetected = 3
class FrameState(Enum):
   kFrameStateInvalid = -1
   kFrameStateSpeech = 1
   kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
@@ -37,9 +39,11 @@
   kChangeStateNoBegin = 4
   kChangeStateInvalid = 5
class VadDetectMode(Enum):
   kVadSingleUtteranceDetectMode = 0
   kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
   """
@@ -47,6 +51,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(
      self,
      sample_rate: int = 16000,
@@ -117,6 +122,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self):
      self.start_ms = 0
      self.end_ms = 0
@@ -140,6 +146,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self):
      self.noise_prob = 0.0
      self.speech_prob = 0.0
@@ -154,6 +161,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self, window_size_ms: int,
                sil_to_speech_time: int,
                speech_to_sil_time: int,
@@ -190,7 +198,7 @@
   def GetWinSize(self) -> int:
      return int(self.win_size_frame)
   
   def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
   def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
      cur_frame_state = FrameState.kFrameStateSil
      if frameState == FrameState.kFrameStateSpeech:
         cur_frame_state = 1
@@ -220,13 +228,13 @@
   def FrameSizeMs(self) -> int:
      return int(self.frame_size_ms)
class Stats(object):
   def __init__(self,
                sil_pdf_ids,
                max_end_sil_frame_cnt_thresh,
                speech_noise_thres,
                ):
      self.data_buf_start_frame = 0
      self.frm_cnt = 0
      self.latest_confirmed_speech_frame = 0
@@ -263,6 +271,7 @@
   Deep-FSMN for Large Vocabulary Continuous Speech Recognition
   https://arxiv.org/abs/1803.05030
   """
   def __init__(self,
                encoder: str = None,
                encoder_conf: Optional[Dict] = None,
@@ -275,7 +284,6 @@
      encoder_class = tables.encoder_classes.get(encoder)
      encoder = encoder_class(**encoder_conf)
      self.encoder = encoder
   
   def ResetDetection(self, cache: dict = {}):
      cache["stats"].continous_silence_frame_count = 0
@@ -293,7 +301,8 @@
         drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
         real_drop_frames = drop_frames - cache["stats"].last_drop_frames
         cache["stats"].last_drop_frames = drop_frames
         cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
         cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
            self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
         cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
         cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
   
@@ -301,7 +310,8 @@
      frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
      frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
      if cache["stats"].data_buf_all is None:
         cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
         cache["stats"].data_buf_all = cache["stats"].waveform[
            0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
         cache["stats"].data_buf = cache["stats"].data_buf_all
      else:
         cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
@@ -320,15 +330,16 @@
      else:
         cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
   
   def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
   def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
      while cache["stats"].data_buf_start_frame < frame_idx:
         if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
            cache["stats"].data_buf_start_frame += 1
            cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
               self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
            cache["stats"].data_buf = cache["stats"].data_buf_all[
                                      (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
                                         self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
   
   def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                          last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
                          last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
      self.PopDataBufTillFrame(start_frm, cache=cache)
      expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
      if last_frm_is_end_point:
@@ -380,14 +391,15 @@
      cache["stats"].lastest_confirmed_silence_frame = valid_frame
      if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
         self.PopDataBufTillFrame(valid_frame, cache=cache)
      # silence_detected_callback_
      # pass
   
   def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
   # silence_detected_callback_
   # pass
   def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
      cache["stats"].latest_confirmed_speech_frame = valid_frame
      self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
   
   def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
   def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
      if self.vad_opts.do_start_point_detection:
         pass
      if cache["stats"].confirmed_start_frame != -1:
@@ -398,7 +410,7 @@
      if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
         self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
   
   def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
   def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
      for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
         self.OnVoiceDetected(t, cache=cache)
      if self.vad_opts.do_end_point_detection:
@@ -488,7 +500,8 @@
         segment_batch = []
         if len(cache["stats"].output_data_buf) > 0:
            for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
               if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
               if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
               cache["stats"].output_data_buf[
                  i].contain_seg_end_point):
                  continue
               segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
@@ -502,6 +515,7 @@
      return segments
   
   def init_cache(self, cache: dict = {}, **kwargs):
      cache["frontend"] = {}
      cache["prev_samples"] = torch.empty(0)
      cache["encoder"] = {}
@@ -533,7 +547,7 @@
         self.init_cache(cache, **kwargs)
      
      meta_data = {}
      chunk_size = kwargs.get("chunk_size", 60000) # 50ms
      chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
      chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
      
      time1 = time.perf_counter()
@@ -580,10 +594,9 @@
         if len(segments_i) > 0:
            segments.extend(*segments_i)
      
      cache["prev_samples"] = audio_sample[:-m]
      if _is_final:
         cache = {}
         self.init_cache(cache)
      
      ibest_writer = None
      if ibest_writer is None and kwargs.get("output_dir") is not None:
@@ -600,16 +613,15 @@
      if ibest_writer is not None:
         ibest_writer["text"][key[0]] = segments
      
      return results, meta_data
   
   def DetectCommonFrames(self, cache: dict = {}) -> int:
      if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
         return 0
      for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
         frame_state = FrameState.kFrameStateInvalid
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                          cache=cache)
         self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
      
      return 0
@@ -619,7 +631,8 @@
         return 0
      for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
         frame_state = FrameState.kFrameStateInvalid
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
         frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                          cache=cache)
         if i != 0:
            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
         else:
@@ -627,7 +640,8 @@
      
      return 0
   
   def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
   def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
                      cache: dict = {}) -> None:
      tmp_cur_frm_state = FrameState.kFrameStateInvalid
      if cur_frm_state == FrameState.kFrameStateSpeech:
         if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +658,8 @@
         cache["stats"].pre_end_silence_detected = False
         start_frame = 0
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
            start_frame = max(cache["stats"].data_buf_start_frame,
                              cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
            self.OnVoiceStart(start_frame, cache=cache)
            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
            for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +711,8 @@
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            # silence timeout, return zero length decision
            if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
               cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
               cache[
                  "stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
               or (is_final_frame and cache["stats"].number_end_time_detected == 0):
               for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                  self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +723,8 @@
               if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                  self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
         elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
               "stats"].max_end_sil_frame_cnt_thresh:
               lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
               if self.vad_opts.do_extend:
                  lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -731,6 +748,5 @@
      if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
         self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
         self.ResetDetection(cache=cache)