| | |
| | | from typing import List, Tuple, Dict, Any, Optional |
| | | |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | |
| | | class VadStateMachine(Enum): |
| | |
| | | kVadInStateInSpeechSegment = 2 |
| | | kVadInStateEndPointDetected = 3 |
| | | |
| | | |
class FrameState(Enum):
    """Per-frame label consumed by the VAD state machine.

    The integer values are kept exactly as in the original definition:
    invalid is -1, speech is 1 and silence is 0.
    """

    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0
| | | |
| | | |
| | | # final voice/unvoice state per frame |
| | | class AudioChangeState(Enum): |
| | |
| | | kChangeStateNoBegin = 4 |
| | | kChangeStateInvalid = 5 |
| | | |
| | | |
class VadDetectMode(Enum):
    """Detection mode of the VAD: single utterance (0) or multiple utterances (1).

    Elsewhere in this file the configured ``detect_mode`` is compared against
    the ``.value`` of these members, so the integer values must not change.
    The spelling ``Mutiple`` is part of the public member name and is kept.
    """

    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1
| | | |
| | | |
| | | class VADXOptions: |
| | | """ |
| | |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | sample_rate: int = 16000, |
| | |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | |
| | | def __init__(self): |
| | | self.start_ms = 0 |
| | | self.end_ms = 0 |
| | |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | |
| | | def __init__(self): |
| | | self.noise_prob = 0.0 |
| | | self.speech_prob = 0.0 |
| | |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | |
| | | def __init__(self, window_size_ms: int, |
| | | sil_to_speech_time: int, |
| | | speech_to_sil_time: int, |
| | |
def GetWinSize(self) -> int:
    """Return ``self.win_size_frame`` converted to ``int`` (truncating)."""
    return int(self.win_size_frame)
| | | |
| | | def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState: |
| | | def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState: |
| | | cur_frame_state = FrameState.kFrameStateSil |
| | | if frameState == FrameState.kFrameStateSpeech: |
| | | cur_frame_state = 1 |
| | |
def FrameSizeMs(self) -> int:
    """Return ``self.frame_size_ms`` converted to ``int`` (truncating)."""
    return int(self.frame_size_ms)
| | | |
| | | |
| | | class Stats(object): |
| | | def __init__(self, |
| | | sil_pdf_ids, |
| | | max_end_sil_frame_cnt_thresh, |
| | | speech_noise_thres, |
| | | ): |
| | | |
| | | self.data_buf_start_frame = 0 |
| | | self.frm_cnt = 0 |
| | | self.latest_confirmed_speech_frame = 0 |
| | |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | |
| | | def __init__(self, |
| | | encoder: str = None, |
| | | encoder_conf: Optional[Dict] = None, |
| | |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(**encoder_conf) |
| | | self.encoder = encoder |
| | | |
| | | self.encoder_conf = encoder_conf |
| | | |
| | | def ResetDetection(self, cache: dict = {}): |
| | | cache["stats"].continous_silence_frame_count = 0 |
| | |
| | | drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms) |
| | | real_drop_frames = drop_frames - cache["stats"].last_drop_frames |
| | | cache["stats"].last_drop_frames = drop_frames |
| | | cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] |
| | | cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int( |
| | | self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] |
| | | cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:] |
| | | cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :] |
| | | |
| | |
| | | frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) |
| | | frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) |
| | | if cache["stats"].data_buf_all is None: |
| | | cache["stats"].data_buf_all = cache["stats"].waveform[0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0] |
| | | cache["stats"].data_buf_all = cache["stats"].waveform[ |
| | | 0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0] |
| | | cache["stats"].data_buf = cache["stats"].data_buf_all |
| | | else: |
| | | cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0])) |
| | |
| | | else: |
| | | cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1) |
| | | |
def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
    """Advance the data buffer start until it reaches ``frame_idx``.

    For every frame consumed, ``cache["stats"].data_buf_start_frame`` is
    incremented and ``cache["stats"].data_buf`` is re-sliced from
    ``cache["stats"].data_buf_all`` at the matching sample offset (relative
    to ``cache["stats"].last_drop_frames``).

    Args:
        frame_idx: Frame index the buffer start should be advanced to.
        cache: Streaming cache; must contain a ``"stats"`` entry.
    """
    stats = cache["stats"]
    # Samples per frame shift; does not change while popping, so compute once.
    frame_shift = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
    while stats.data_buf_start_frame < frame_idx:
        if len(stats.data_buf) < frame_shift:
            # Less than one frame left: nothing more can be popped. Without
            # this exit the loop would spin forever, because neither the
            # buffer nor the start frame changes in that case.
            break
        stats.data_buf_start_frame += 1
        offset = (stats.data_buf_start_frame - stats.last_drop_frames) * frame_shift
        stats.data_buf = stats.data_buf_all[offset:]
| | | |
| | | def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, |
| | | last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None: |
| | | last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None: |
| | | self.PopDataBufTillFrame(start_frm, cache=cache) |
| | | expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000) |
| | | if last_frm_is_end_point: |
| | |
| | | cache["stats"].lastest_confirmed_silence_frame = valid_frame |
| | | if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | self.PopDataBufTillFrame(valid_frame, cache=cache) |
| | | # silence_detected_callback_ |
| | | # pass |
| | | |
| | | def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None: |
| | | # silence_detected_callback_ |
| | | # pass |
| | | |
def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
    """Record ``valid_frame`` as the latest confirmed speech frame and emit it.

    Updates ``cache["stats"].latest_confirmed_speech_frame`` and then pushes
    exactly one frame, flagged as neither a start point nor an end point,
    through ``PopDataToOutputBuf``.

    Args:
        valid_frame: Index of the frame confirmed as speech.
        cache: Streaming cache; must contain a ``"stats"`` entry.
    """
    stats = cache["stats"]
    stats.latest_confirmed_speech_frame = valid_frame
    self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
| | | |
| | | def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None: |
| | | def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None: |
| | | if self.vad_opts.do_start_point_detection: |
| | | pass |
| | | if cache["stats"].confirmed_start_frame != -1: |
| | |
| | | if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache) |
| | | |
| | | def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None: |
| | | def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None: |
| | | for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame): |
| | | self.OnVoiceDetected(t, cache=cache) |
| | | if self.vad_opts.do_end_point_detection: |
| | |
| | | |
| | | return frame_state |
| | | |
| | | def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {}, |
| | | is_final: bool = False |
| | | def forward(self, feats: torch.Tensor, |
| | | waveform: torch.tensor, |
| | | cache: dict = {}, |
| | | is_final: bool = False, |
| | | **kwargs, |
| | | ): |
| | | # if len(cache) == 0: |
| | | # self.AllResetDetection() |
| | | # self.waveform = waveform # compute decibel for each frame |
| | | cache["stats"].waveform = waveform |
| | | is_streaming_input = kwargs.get("is_streaming_input", True) |
| | | self.ComputeDecibel(cache=cache) |
| | | self.ComputeScores(feats, cache=cache) |
| | | if not is_final: |
| | |
| | | segment_batch = [] |
| | | if len(cache["stats"].output_data_buf) > 0: |
| | | for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)): |
| | | if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[ |
| | | i].contain_seg_end_point): |
| | | continue |
| | | segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms] |
| | | if is_streaming_input: # in this case, return [beg, -1], [], [-1, end], [beg, end] |
| | | if not cache["stats"].output_data_buf[i].contain_seg_start_point: |
| | | continue |
| | | if not cache["stats"].next_seg and not cache["stats"].output_data_buf[i].contain_seg_end_point: |
| | | continue |
| | | start_ms = cache["stats"].output_data_buf[i].start_ms if cache["stats"].next_seg else -1 |
| | | if cache["stats"].output_data_buf[i].contain_seg_end_point: |
| | | end_ms = cache["stats"].output_data_buf[i].end_ms |
| | | cache["stats"].next_seg = True |
| | | cache["stats"].output_data_buf_offset += 1 |
| | | else: |
| | | end_ms = -1 |
| | | cache["stats"].next_seg = False |
| | | segment = [start_ms, end_ms] |
| | | |
| | | else: # in this case, return [beg, end] |
| | | |
| | | if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not |
| | | cache["stats"].output_data_buf[ |
| | | i].contain_seg_end_point): |
| | | continue |
| | | segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms] |
| | | cache["stats"].output_data_buf_offset += 1 # need update this parameter |
| | | |
| | | segment_batch.append(segment) |
| | | cache["stats"].output_data_buf_offset += 1 # need update this parameter |
| | | |
| | | if segment_batch: |
| | | segments.append(segment_batch) |
| | | # if is_final: |
| | |
| | | return segments |
| | | |
| | | def init_cache(self, cache: dict = {}, **kwargs): |
| | | |
| | | cache["frontend"] = {} |
| | | cache["prev_samples"] = torch.empty(0) |
| | | cache["encoder"] = {} |
| | |
| | | self.init_cache(cache, **kwargs) |
| | | |
| | | meta_data = {} |
| | | chunk_size = kwargs.get("chunk_size", 60000) # 50ms |
| | | chunk_size = kwargs.get("chunk_size", 60000) # 50ms |
| | | chunk_stride_samples = int(chunk_size * frontend.fs / 1000) |
| | | |
| | | time1 = time.perf_counter() |
| | | cfg = {"is_final": kwargs.get("is_final", False)} |
| | | is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True) |
| | | is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True) |
| | | cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input} |
| | | audio_sample_list = load_audio_text_image_video(data_in, |
| | | fs=frontend.fs, |
| | | audio_fs=kwargs.get("fs", 16000), |
| | |
| | | cache=cfg, |
| | | ) |
| | | _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True |
| | | |
| | | is_streaming_input = cfg["is_streaming_input"] |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | assert len(audio_sample_list) == 1, "batch_size must be set 1" |
| | |
| | | "feats": speech, |
| | | "waveform": cache["frontend"]["waveforms"], |
| | | "is_final": kwargs["is_final"], |
| | | "cache": cache |
| | | "cache": cache, |
| | | "is_streaming_input": is_streaming_input |
| | | } |
| | | segments_i = self.forward(**batch) |
| | | if len(segments_i) > 0: |
| | | segments.extend(*segments_i) |
| | | |
| | | |
| | | cache["prev_samples"] = audio_sample[:-m] |
| | | if _is_final: |
| | | cache = {} |
| | | self.init_cache(cache) |
| | | |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{1}best_recog"] |
| | | |
| | | results = [] |
| | | result_i = {"key": key[0], "value": segments} |
| | |
| | | if ibest_writer is not None: |
| | | ibest_writer["text"][key[0]] = segments |
| | | |
| | | |
| | | return results, meta_data |
| | | |
def export(self, **kwargs):
    """Switch this model into export mode and return it.

    Looks up the export variant of the encoder in ``tables.encoder_classes``
    under the key ``kwargs["encoder"] + "Export"``, wraps the current encoder
    with it, and rebinds ``self.forward`` to the export forward method.

    Args:
        **kwargs: ``encoder`` (required) is the encoder name used for the
            table lookup; ``type`` (default ``"onnx"``) selects ONNX mode
            when equal to ``"onnx"``.

    Returns:
        ``self``, with ``encoder`` and ``forward`` replaced.

    Raises:
        KeyError: If ``encoder`` is missing from ``kwargs``.
    """
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    self.encoder = encoder_class(self.encoder, onnx=is_onnx)

    # The visible export forward method is named ``export_forward``; binding
    # the underscore-prefixed name unconditionally raises AttributeError when
    # only that one exists. Prefer ``_export_forward`` if it is defined,
    # otherwise fall back to ``export_forward``.
    export_fn = getattr(self, "_export_forward", None)
    if export_fn is None:
        export_fn = self.export_forward
    self.forward = export_fn

    return self
| | | |
def export_forward(self, feats: torch.Tensor, *args, **kwargs):
    """Run the encoder for export and return its two outputs.

    Args:
        feats: Input features handed to ``self.encoder`` as first argument.
        *args: Extra positional inputs forwarded to the encoder unchanged.
        **kwargs: Accepted for signature compatibility; not forwarded.

    Returns:
        The ``(scores, out_caches)`` pair produced by ``self.encoder``.
    """
    scores, out_caches = self.encoder(feats, *args)
    return scores, out_caches
| | | |
def export_dummy_inputs(self, data_in=None, frame=30):
    """Build random placeholder inputs for exporting the model.

    Args:
        data_in: When ``None``, a random ``speech`` tensor of shape
            ``(1, frame, input_dim)`` is generated. Any other value currently
            yields ``speech = None`` (not implemented in the original code).
        frame: Size of the second dimension of the random ``speech`` tensor.

    Returns:
        A 5-tuple ``(speech, in_cache0, in_cache1, in_cache2, in_cache3)``.
        Each cache is a random tensor of shape
        ``(1, proj_dim, lorder + rorder - 1, 1)``, with all sizes read from
        ``self.encoder_conf``.
    """
    conf = self.encoder_conf
    if data_in is None:
        speech = torch.randn(1, frame, conf.get("input_dim"))
    else:
        speech = None  # Undo

    cache_frames = conf.get("lorder") + conf.get("rorder") - 1
    proj_dim = conf.get("proj_dim")
    # Four independent random caches, drawn in the same order as before.
    in_caches = tuple(torch.randn(1, proj_dim, cache_frames, 1) for _ in range(4))

    return (speech, *in_caches)
| | | |
def export_input_names(self):
    """Return the input names used for export: ``speech`` plus four ``in_cache`` entries."""
    return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
| | | |
def export_output_names(self):
    """Return the output names used for export: ``logits`` plus four ``out_cache`` entries."""
    return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
| | | |
def export_dynamic_axes(self):
    """Return the dynamic-axes mapping for export.

    Only axis 1 of ``speech`` is declared dynamic, under the name
    ``feats_length``.
    """
    return {
        'speech': {
            1: 'feats_length',
        },
    }
| | | |
def export_name(self):
    """Return the file name used for the exported model."""
    return "model.onnx"
| | | |
def DetectCommonFrames(self, cache: dict = {}) -> int:
    """Run frame-level detection over the most recent block of frames.

    Does nothing when the state machine has already reached
    ``kVadInStateEndPointDetected``. Otherwise the last
    ``self.vad_opts.nn_eval_block_size`` frames are visited oldest first;
    each frame's state comes from ``GetFrameState`` and is passed to
    ``DetectOneFrame`` with ``is_final_frame=False``.

    Args:
        cache: Streaming cache; must contain a ``"stats"`` entry.

    Returns:
        Always ``0``.
    """
    if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
        return 0

    for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
        # Re-read the stats on every iteration: DetectOneFrame may update
        # them, e.g. the drop-frame counter after a detection reset.
        stats = cache["stats"]
        cur_frm_idx = stats.frm_cnt - 1 - i
        frame_state = self.GetFrameState(cur_frm_idx - stats.last_drop_frames, cache=cache)
        self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)

    return 0
| | |
| | | return 0 |
| | | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | | frame_state = FrameState.kFrameStateInvalid |
| | | frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache) |
| | | frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, |
| | | cache=cache) |
| | | if i != 0: |
| | | self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache) |
| | | else: |
| | |
| | | |
| | | return 0 |
| | | |
| | | def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None: |
| | | def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, |
| | | cache: dict = {}) -> None: |
| | | tmp_cur_frm_state = FrameState.kFrameStateInvalid |
| | | if cur_frm_state == FrameState.kFrameStateSpeech: |
| | | if math.fabs(1.0) > self.vad_opts.fe_prior_thres: |
| | |
| | | cache["stats"].pre_end_silence_detected = False |
| | | start_frame = 0 |
| | | if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache)) |
| | | start_frame = max(cache["stats"].data_buf_start_frame, |
| | | cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache)) |
| | | self.OnVoiceStart(start_frame, cache=cache) |
| | | cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment |
| | | for t in range(start_frame + 1, cur_frm_idx + 1): |
| | |
| | | if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | # silence timeout, return zero length decision |
| | | if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and ( |
| | | cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ |
| | | cache[ |
| | | "stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ |
| | | or (is_final_frame and cache["stats"].number_end_time_detected == 0): |
| | | for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx): |
| | | self.OnSilenceDetected(t, cache=cache) |
| | |
| | | if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache): |
| | | self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache) |
| | | elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| | | if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh: |
| | | if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[ |
| | | "stats"].max_end_sil_frame_cnt_thresh: |
| | | lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms) |
| | | if self.vad_opts.do_extend: |
| | | lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms) |
| | |
| | | if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \ |
| | | self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value: |
| | | self.ResetDetection(cache=cache) |
| | | |
| | | |
| | | |