Yu Cao
2025-10-01 c4ac64fd5d24bb3fc8ccc441d36a07c83c8b9015
funasr/models/fsmn_vad_streaming/model.py
@@ -1,18 +1,23 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import math
import torch
import numpy as np
from torch import nn
from enum import Enum
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Optional
from funasr.register import tables
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
class VadStateMachine(Enum):
    kVadInStateStartPointNotDetected = 1
@@ -40,44 +45,46 @@
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
        snr_mode: int = 0,
        max_end_silence_time: int = 800,
        max_start_silence_time: int = 3000,
        do_start_point_detection: bool = True,
        do_end_point_detection: bool = True,
        window_size_ms: int = 200,
        sil_to_speech_time_thres: int = 150,
        speech_to_sil_time_thres: int = 150,
        speech_2_noise_ratio: float = 1.0,
        do_extend: int = 1,
        lookback_time_start_point: int = 200,
        lookahead_time_end_point: int = 100,
        max_single_segment_time: int = 60000,
        nn_eval_block_size: int = 8,
        dcd_block_size: int = 4,
        snr_thres: float = -100.0,
        noise_frame_num_used_for_snr: int = 100,
        decibel_thres: float = -100.0,
        speech_noise_thres: float = 0.6,
        fe_prior_thres: float = 1e-4,
        silence_pdf_num: int = 1,
        sil_pdf_ids: List[int] = [0],
        speech_noise_thresh_low: float = -0.1,
        speech_noise_thresh_high: float = 0.3,
        output_frame_probs: bool = False,
        frame_in_ms: int = 10,
        frame_length_ms: int = 25,
        **kwargs,
    ):
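        # Note: all *_time / *_ms options are in milliseconds, and the decibel /
        # SNR thresholds are in dB; max_single_segment_time caps one speech
        # segment at 60 s by default.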
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
@@ -116,6 +123,7 @@
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0
@@ -139,6 +147,7 @@
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
@@ -153,10 +162,14 @@
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(
        self,
        window_size_ms: int,
        sil_to_speech_time: int,
        speech_to_sil_time: int,
        frame_size_ms: int,
    ):
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
@@ -189,7 +202,9 @@
    def GetWinSize(self) -> int:
        return int(self.win_size_frame)
    def DetectOneFrame(
        self, frameState: FrameState, frame_count: int, cache: dict = {}
    ) -> AudioChangeState:
        cur_frame_state = FrameState.kFrameStateSil
        if frameState == FrameState.kFrameStateSpeech:
            cur_frame_state = 1
@@ -202,11 +217,17 @@
        self.win_state[self.cur_win_pos] = cur_frame_state
        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
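        # Sliding-window vote: win_state is a ring buffer over the last
        # win_size_frame frames and win_sum counts how many were labelled
        # speech; crossing a threshold below flips the smoothed state.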
        if (
            self.pre_frame_state == FrameState.kFrameStateSil
            and self.win_sum >= self.sil_to_speech_frmcnt_thres
        ):
            self.pre_frame_state = FrameState.kFrameStateSpeech
            return AudioChangeState.kChangeStateSil2Speech
        if (
            self.pre_frame_state == FrameState.kFrameStateSpeech
            and self.win_sum <= self.speech_to_sil_frmcnt_thres
        ):
            self.pre_frame_state = FrameState.kFrameStateSil
            return AudioChangeState.kChangeStateSpeech2Sil
@@ -220,38 +241,42 @@
        return int(self.frame_size_ms)
class Stats(object):
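    # Per-stream mutable detection state. init_cache() builds one Stats per
    # cache dict, so a single FsmnVADStreaming module can serve several
    # concurrent streams without sharing state across them.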
    def __init__(
        self,
        sil_pdf_ids,
        max_end_sil_frame_cnt_thresh,
        speech_noise_thres,
    ):
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.continous_silence_frame_count = 0
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.sil_pdf_ids = sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
        self.speech_noise_thres = speech_noise_thres
        self.scores = None
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.last_drop_frames = 0
@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
    """
@@ -259,19 +284,21 @@
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(
        self,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        vad_post_args: Dict[str, Any] = None,
        **kwargs,
    ):
        super().__init__()
        self.vad_opts = VADXOptions(**kwargs)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder
        self.encoder_conf = encoder_conf
    def ResetDetection(self, cache: dict = {}):
        cache["stats"].continous_silence_frame_count = 0
@@ -289,7 +316,10 @@
            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
            cache["stats"].last_drop_frames = drop_frames
            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
            cache["stats"].data_buf_all = cache["stats"].data_buf_all[
                real_drop_frames
                * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) :
            ]
            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
@@ -297,18 +327,31 @@
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if cache["stats"].data_buf_all is None:
            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
            cache["stats"].data_buf_all = cache["stats"].waveform[
                0
            ]  # cache["stats"].data_buf points to cache["stats"].waveform[0]
            cache["stats"].data_buf = cache["stats"].data_buf_all
        else:
            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
        for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            cache["stats"].decibel.append(
                10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
            cache["stats"].data_buf_all = torch.cat(
                (cache["stats"].data_buf_all, cache["stats"].waveform[0])
            )
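        # Vectorized decibel computation: offsets[:, None] + arange(frame_sample_length)
        # forms a (num_frames, frame_sample_length) index matrix, so every frame
        # (400 samples with a 160-sample shift at the default 16 kHz) is sliced in
        # one NumPy call; the 1e-6 term guards log10 against all-zero frames. This
        # is numerically equivalent to the per-frame Python loop it replaces.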
        waveform_numpy = cache["stats"].waveform.numpy()
        offsets = np.arange(0, waveform_numpy.shape[1] - frame_sample_length + 1, frame_shift_length)
        frames = waveform_numpy[0, offsets[:, np.newaxis] + np.arange(frame_sample_length)]
        decibel_numpy = 10 * np.log10(np.sum(np.square(frames), axis=1) + 0.000001)
        decibel_numpy = decibel_numpy.tolist()
        cache["stats"].decibel.extend(decibel_numpy)
    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
        scores = self.encoder(feats, cache=cache["encoder"]).to("cpu")  # return B * T * D
        assert (
            scores.shape[1] == feats.shape[1]
        ), "the frame counts of feats and scores do not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
        if cache["stats"].scores is None:
@@ -316,25 +359,43 @@
        else:
            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
    def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
        while cache["stats"].data_buf_start_frame < frame_idx:
            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
            if len(cache["stats"].data_buf) >= int(
                self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
            ):
                cache["stats"].data_buf_start_frame += 1
                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
                cache["stats"].data_buf = cache["stats"].data_buf_all[
                    (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames)
                    * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) :
                ]
    def PopDataToOutputBuf(
        self,
        start_frm: int,
        frm_cnt: int,
        first_frm_is_start_point: bool,
        last_frm_is_end_point: bool,
        end_point_is_sent_end: bool,
        cache: dict = {},
    ) -> None:
        self.PopDataBufTillFrame(start_frm, cache=cache)
        expected_sample_number = int(
            frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
        )
        if last_frm_is_end_point:
            extra_sample = max(
                0,
                int(
                    self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
                    - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
                ),
            )
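            # A frame spans frame_length_ms but advances frame_in_ms, so the last
            # frame of a segment extends 25 - 10 = 15 ms (240 samples at 16 kHz)
            # past the final shift; include those samples when the segment ends here.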
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
        if len(cache["stats"].data_buf) < expected_sample_number:
            print("error in calling pop data_buf\n")
        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
@@ -344,27 +405,22 @@
            cache["stats"].output_data_buf[-1].doa = 0
        cur_seg = cache["stats"].output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print("warning\n")
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(
                frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
            )
        if data_to_pop > len(cache["stats"].data_buf):
            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
            data_to_pop = len(cache["stats"].data_buf)
            expected_sample_number = len(cache["stats"].data_buf)
        cur_seg.doa = 0
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print("Something wrong with the VAD algorithm\n")
        cache["stats"].data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
@@ -376,39 +432,51 @@
        cache["stats"].lastest_confirmed_silence_frame = valid_frame
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataBufTillFrame(valid_frame, cache=cache)
    # silence_detected_callback_
    # pass
    def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
        cache["stats"].latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if cache["stats"].confirmed_start_frame != -1:
            print("not reset vad properly\n")
        else:
            cache["stats"].confirmed_start_frame = start_frame
        if (
            not fake_result
            and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
        ):
            self.PopDataToOutputBuf(
                cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache
            )
    def OnVoiceEnd(
        self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}
    ) -> None:
        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t, cache=cache)
        if self.vad_opts.do_end_point_detection:
            pass
        if cache["stats"].confirmed_end_frame != -1:
            print("not reset vad properly\n")
        else:
            cache["stats"].confirmed_end_frame = end_frame
        if not fake_result:
            cache["stats"].sil_frame = 0
            self.PopDataToOutputBuf(
                cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache
            )
        cache["stats"].number_end_time_detected += 1
    def MaybeOnVoiceEndIfLastFrame(
        self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}
    ) -> None:
        if is_final_frame:
            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
@@ -437,8 +505,17 @@
        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(cache["stats"].sil_pdf_ids) > 0:
            assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
            """
            - Change type of `sum_score` to float. The reason is that `sum_score` is a tensor with single element.
              and `torch.Tensor` is slower `float` when tensor has only one element.
            - Put the iteration of `sil_pdf_ids` inside `sum()` to reduce the overhead of creating a new list.
            - The default `sil_pdf_ids` is [0], the `if` statement is used to reduce the overhead of expression
              generation, which result in a mere (~2%) performance gain.
            """
            if len(cache["stats"].sil_pdf_ids) > 1:
                sum_score = sum(
                    cache["stats"].scores[0][t][sil_pdf_id].item()
                    for sil_pdf_id in cache["stats"].sil_pdf_ids
                )
            else:
                sum_score = cache["stats"].scores[0][t][cache["stats"].sil_pdf_ids[0]].item()
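            # e.g. with the default sil_pdf_ids=[0] this is a single
            # scores[0][t][0].item() call: one tensor-to-float conversion per
            # frame instead of summing one-element tensors from a list.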
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            sum_score = total_score - sum_score
@@ -460,19 +537,27 @@
            if cache["stats"].noise_average_decibel < -99.9:
                cache["stats"].noise_average_decibel = cur_decibel
            else:
                cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
                        self.vad_opts.noise_frame_num_used_for_snr
                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
                cache["stats"].noise_average_decibel = (
                    cur_decibel
                    + cache["stats"].noise_average_decibel
                    * (self.vad_opts.noise_frame_num_used_for_snr - 1)
                ) / self.vad_opts.noise_frame_num_used_for_snr
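                # i.e. new_avg = (cur + old_avg * (N - 1)) / N with
                # N = noise_frame_num_used_for_snr: an exponential moving
                # average that weights the newest frame by 1/N.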
        return frame_state
    def forward(
        self,
        feats: torch.Tensor,
        waveform: torch.Tensor,
        cache: dict = {},
        is_final: bool = False,
        **kwargs,
    ):
        # if len(cache) == 0:
        #     self.AllResetDetection()
        # self.waveform = waveform  # compute decibel for each frame
        cache["stats"].waveform = waveform
        is_streaming_input = kwargs.get("is_streaming_input", True)
        self.ComputeDecibel(cache=cache)
        self.ComputeScores(feats, cache=cache)
        if not is_final:
@@ -483,13 +568,48 @@
        for batch_num in range(0, feats.shape[0]):  # only batch_size = 1 is supported for now
            segment_batch = []
            if len(cache["stats"].output_data_buf) > 0:
                for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
                    if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
                        i].contain_seg_end_point):
                        continue
                    segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
                for i in range(
                    cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)
                ):
                    if (
                        is_streaming_input
                    ):  # in this case, return [beg, -1], [], [-1, end], [beg, end]
                        if not cache["stats"].output_data_buf[i].contain_seg_start_point:
                            continue
                        if (
                            not cache["stats"].next_seg
                            and not cache["stats"].output_data_buf[i].contain_seg_end_point
                        ):
                            continue
                        start_ms = (
                            cache["stats"].output_data_buf[i].start_ms
                            if cache["stats"].next_seg
                            else -1
                        )
                        if cache["stats"].output_data_buf[i].contain_seg_end_point:
                            end_ms = cache["stats"].output_data_buf[i].end_ms
                            cache["stats"].next_seg = True
                            cache["stats"].output_data_buf_offset += 1
                        else:
                            end_ms = -1
                            cache["stats"].next_seg = False
                        segment = [start_ms, end_ms]
                    else:  # in this case, return [beg, end]
                        if not is_final and (
                            not cache["stats"].output_data_buf[i].contain_seg_start_point
                            or not cache["stats"].output_data_buf[i].contain_seg_end_point
                        ):
                            continue
                        segment = [
                            cache["stats"].output_data_buf[i].start_ms,
                            cache["stats"].output_data_buf[i].end_ms,
                        ]
                        cache["stats"].output_data_buf_offset += 1  # need update this parameter
                    segment_batch.append(segment)
                    cache["stats"].output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
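        # Streaming output protocol (illustrative timing): a segment that opens
        # mid-stream is first reported as [beg, -1], silent middle chunks emit
        # nothing, and the closing chunk reports [-1, end]; offline input
        # returns the single pair [beg, end] per segment instead.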
        # if is_final:
@@ -498,50 +618,72 @@
        return segments
    def init_cache(self, cache: dict = {}, **kwargs):
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}
        if kwargs.get("max_end_silence_time") is not None:
            # update the max_end_silence_time
            self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")
        windows_detector = WindowDetector(
            self.vad_opts.window_size_ms,
            self.vad_opts.sil_to_speech_time_thres,
            self.vad_opts.speech_to_sil_time_thres,
            self.vad_opts.frame_in_ms,
        )
        windows_detector.Reset()
        stats = Stats(
            sil_pdf_ids=self.vad_opts.sil_pdf_ids,
            max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
            - self.vad_opts.speech_to_sil_time_thres,
            speech_noise_thres=self.vad_opts.speech_noise_thres,
        )
        cache["windows_detector"] = windows_detector
        cache["stats"] = stats
        return cache
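    # Minimal streaming usage sketch (hypothetical wiring; in practice this is
    # usually driven through FunASR's AutoModel rather than called directly):
    #   cache = {}
    #   for i, chunk in enumerate(chunks):  # e.g. 200 ms slices of a 16 kHz stream
    #       results, _ = model.inference(
    #           chunk, key=["utt1"], frontend=frontend, cache=cache,
    #           chunk_size=200, is_final=(i == len(chunks) - 1), device="cpu",
    #       )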
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        if cache is None:
            cache = {}
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)
        meta_data = {}
        chunk_size = kwargs.get("chunk_size", 60000) # 50ms
        chunk_size = kwargs.get("chunk_size", 60000)  # chunk duration in ms; defaults to 60 s (effectively offline)
        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"),
                                                        tokenizer=tokenizer,
                                                        cache=cfg,
                                                        )
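        # Heuristic: chunks of 15 s or longer are treated as whole recordings,
        # so is_streaming_input defaults to False (offline mode); shorter chunks
        # default to streaming. Both defaults can be overridden via kwargs.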
        is_streaming_input = (
            kwargs.get("is_streaming_input", False)
            if chunk_size >= 15000
            else kwargs.get("is_streaming_input", True)
        )
        is_final = (
            kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
        )
        cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
            cache=cfg,
        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
        is_streaming_input = cfg["is_streaming_input"]
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set to 1"
@@ -553,58 +695,72 @@
        segments = []
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples : (i + 1) * chunk_stride_samples]
            # extract fbank feats
            speech, speech_lengths = extract_fbank(
                [audio_sample_i],
                data_type=kwargs.get("data_type", "sound"),
                frontend=frontend,
                cache=cache["frontend"],
                is_final=kwargs["is_final"],
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )
            speech = speech.to(device=kwargs["device"])
            speech_lengths = speech_lengths.to(device=kwargs["device"])
            batch = {
                "feats": speech,
                "waveform": cache["frontend"]["waveforms"],
                "is_final": kwargs["is_final"],
                "cache": cache
                "cache": cache,
                "is_streaming_input": is_streaming_input,
            }
            segments_i = self.forward(**batch)
            if len(segments_i) > 0:
                segments.extend(*segments_i)
        cache["prev_samples"] = audio_sample[:-m]
        if _is_final:
            self.init_cache(cache)
        ibest_writer = None
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer[f"{1}best_recog"]
        results = []
        result_i = {"key": key[0], "value": segments}
        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
            result_i = json.dumps(result_i)
        # if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
        #    result_i = json.dumps(result_i)
        results.append(result_i)
        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = segments
        return results, meta_data
    def export(self, **kwargs):
        from .export_meta import export_rebuild_model
        models = export_rebuild_model(model=self, **kwargs)
        return models
    def DetectCommonFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
            frame_state = self.GetFrameState(
                cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
            )
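            # ResetDetection trims the scores/decibel buffers, so frame states
            # are read at an index shifted by last_drop_frames while the state
            # machine still receives the absolute frame index.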
            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
        return 0
@@ -614,7 +770,9 @@
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(
                cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
            )
            if i != 0:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
            else:
@@ -622,7 +780,9 @@
        return 0
    def DetectOneFrame(
        self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}
    ) -> None:
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -631,7 +791,9 @@
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
        state_change = cache["windows_detector"].DetectOneFrame(
            tmp_cur_frm_state, cur_frm_idx, cache=cache
        )
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            silence_frame_count = cache["stats"].continous_silence_frame_count
@@ -639,7 +801,10 @@
            cache["stats"].pre_end_silence_detected = False
            start_frame = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                start_frame = max(
                    cache["stats"].data_buf_start_frame,
                    cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache),
                )
                self.OnVoiceStart(start_frame, cache=cache)
                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -647,8 +812,10 @@
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.OnVoiceDetected(t, cache=cache)
                if (
                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
@@ -662,8 +829,10 @@
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                pass
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if (
                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
@@ -675,8 +844,10 @@
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            cache["stats"].continous_silence_frame_count = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if (
                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    cache["stats"].max_time_out = True
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
@@ -690,9 +861,13 @@
            cache["stats"].continous_silence_frame_count += 1
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if (
                    (self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
                    and (
                        cache["stats"].continous_silence_frame_count * frm_shift_in_ms
                        > self.vad_opts.max_start_silence_time
                    )
                ) or (is_final_frame and cache["stats"].number_end_time_detected == 0):
                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t, cache=cache)
                    self.OnVoiceStart(0, True, cache=cache)
@@ -700,32 +875,43 @@
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                        self.OnSilenceDetected(
                            cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache
                        )
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                if (
                    cache["stats"].continous_silence_frame_count * frm_shift_in_ms
                    >= cache["stats"].max_end_sil_frame_cnt_thresh
                ):
                    lookback_frame = int(
                        cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms
                    )
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(
                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
                        )
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif (
                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    if cache["stats"].continous_silence_frame_count <= int(
                        self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
                    ):
                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
        if (
            cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
            and self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value
        ):
            self.ResetDetection(cache=cache)