zhifu gao
2024-01-18 b28f3c9da94ae72a3a0b7bb5982b587be7cf4cd6
fsmn-vad bugfix (#1270)

* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* update with main (#1264)

* Funasr1.0 (#1261)

* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* bug fix

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* funasr1.0 sanm scama

* funasr1.0 infer_after_finetune

* funasr1.0 fsmn-vad bug fix

* funasr1.0 fsmn-vad bug fix

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
8个文件已修改
7个文件已添加
3 文件已重命名
2733 ■■■■ 已修改文件
README.md 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README_zh.md 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/scama/demo.py 42 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/scama/infer.sh 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/model.py 1374 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/template.yaml 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/decoder.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/encoder.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/model.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/template.yaml 121 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/decoder.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/encoder.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/model.py 669 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/template.yaml 127 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/uniasr/model.py 118 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/uniasr/template.yaml 178 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md
@@ -91,12 +91,13 @@
from funasr import AutoModel

# paraformer-zh is a multi-functional ASR model.
# Enable the vad, punc and spk sub-models as needed.
# The model is built exactly once; building it a second time only repeats
# the loading work, because the first instance is dropped on rebinding.
model = AutoModel(model="paraformer-zh", model_revision="v2.0.2",
                  vad_model="fsmn-vad", vad_model_revision="v2.0.2",
                  punc_model="ct-punc-c", punc_model_revision="v2.0.2",
                  # spk_model="cam++", spk_model_revision="v2.0.2",
                  )
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
                     batch_size=64,
                     batch_size_s=300,
                     hotword='魔搭')
print(res)
```
README_zh.md
@@ -87,12 +87,13 @@
from funasr import AutoModel

# paraformer-zh is a multi-functional ASR model.
# Enable the vad, punc and spk sub-models as needed.
# The model is built exactly once; building it a second time only repeats
# the loading work, because the first instance is dropped on rebinding.
model = AutoModel(model="paraformer-zh", model_revision="v2.0.2",
                  vad_model="fsmn-vad", vad_model_revision="v2.0.2",
                  punc_model="ct-punc-c", punc_model_revision="v2.0.2",
                  # spk_model="cam++", spk_model_revision="v2.0.2",
                  )
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
            batch_size=64,
            batch_size_s=300,
            hotword='魔搭')
print(res)
```
examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh
New file
@@ -0,0 +1,12 @@
# Run funasr/bin/inference.py against a locally stored config, checkpoint,
# token list and CMVN file. Every absolute path below is machine-specific
# and has to be adapted before use.
# The final argument carries no trailing backslash, so nothing that follows
# this command can be swallowed as an extra argument.
python funasr/bin/inference.py \
--config-path="/Users/zhifu/funasr_github/test_local/funasr_cli_egs" \
--config-name="config.yaml" \
++init_param="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/model.pt" \
+tokenizer_conf.token_list="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/tokens.txt" \
+frontend_conf.cmvn_file="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/am.mvn" \
+input="data/wav.scp" \
+output_dir="./outputs/debug" \
+device="cuda"
examples/industrial_data_pretraining/scama/demo.py
New file
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
"""Demo script: run one model on a whole file, then chunk by chunk.

The first ``generate`` call passes a complete audio URL. The loop afterwards
reads the example wav shipped in the model directory and feeds it in
fixed-size slices, sharing one ``cache`` dict across calls and flagging the
last slice with ``is_final``.
"""
import os

import soundfile
from funasr import AutoModel

chunk_size = [5, 10, 5]  # [0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 0  # number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0  # number of encoder chunks to lookback for decoder cross-attention

# NOTE(review): machine-specific local model directory; adapt before running.
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_SCAMA_asr-zh-cn-16k-common-vocab8358-streaming", model_revision="v2.0.2")

# Whole-file call.
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
            chunk_size=chunk_size,
            encoder_chunk_look_back=encoder_chunk_look_back,
            decoder_chunk_look_back=decoder_chunk_look_back,
            )
print(res)

# Chunk-by-chunk calls.
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)

chunk_stride = chunk_size[1] * 960  # 600ms, 480ms

cache = {}
# Ceiling division of the sample count by the stride. The subtraction has to
# apply to len(speech), not to the array itself; otherwise an audio length
# that is an exact multiple of the stride yields one extra, empty chunk.
total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
    speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(input=speech_chunk,
                         cache=cache,
                         is_final=is_final,
                         chunk_size=chunk_size,
                         encoder_chunk_look_back=encoder_chunk_look_back,
                         decoder_chunk_look_back=decoder_chunk_look_back,
                         )
    print(res)
examples/industrial_data_pretraining/scama/infer.sh
New file
@@ -0,0 +1,11 @@
# Run funasr/bin/inference.py on a remote example wav using the model id and
# revision configured below.
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
model_revision="v2.0.2"

# Expansions are quoted so that values containing spaces or glob characters
# stay a single argument; the last argument has no trailing backslash.
python funasr/bin/inference.py \
+model="${model}" \
+model_revision="${model_revision}" \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
+output_dir="./outputs/debug" \
+device="cpu"
funasr/models/fsmn_vad_streaming/model.py
@@ -19,714 +19,718 @@
class VadStateMachine(Enum):
    """States stored in the ``vad_state_machine`` attribute of the stats record.

    Each member is declared exactly once: ``Enum`` raises ``TypeError`` when a
    member name is reused inside the class body.
    """

    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3
class FrameState(Enum):
    """Per-frame classification: invalid, speech or silence.

    Each member is declared exactly once: ``Enum`` raises ``TypeError`` when a
    member name is reused inside the class body.
    """

    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
    """Transition kinds between the previous and the current speech/silence state.

    Each member is declared exactly once: ``Enum`` raises ``TypeError`` when a
    member name is reused inside the class body.
    """

    kChangeStateSpeech2Speech = 0
    kChangeStateSpeech2Sil = 1
    kChangeStateSil2Sil = 2
    kChangeStateSil2Speech = 3
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5
class VadDetectMode(Enum):
    """Detection modes: single utterance or multiple utterances.

    Each member is declared exactly once: ``Enum`` raises ``TypeError`` when a
    member name is reused inside the class body. The spelling ``Mutiple`` is
    part of the public member name and is kept as is.
    """

    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Plain option container: every named constructor argument is stored on the
    instance under the same attribute name. Parameters whose names end in
    ``_ms`` or ``_time`` are presumably expressed in milliseconds - confirm
    against the code that consumes them.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
        snr_mode: int = 0,
        max_end_silence_time: int = 800,
        max_start_silence_time: int = 3000,
        do_start_point_detection: bool = True,
        do_end_point_detection: bool = True,
        window_size_ms: int = 200,
        sil_to_speech_time_thres: int = 150,
        speech_to_sil_time_thres: int = 150,
        speech_2_noise_ratio: float = 1.0,
        do_extend: int = 1,
        lookback_time_start_point: int = 200,
        lookahead_time_end_point: int = 100,
        max_single_segment_time: int = 60000,
        nn_eval_block_size: int = 8,
        dcd_block_size: int = 4,
        snr_thres: float = -100.0,
        noise_frame_num_used_for_snr: int = 100,
        decibel_thres: float = -100.0,
        speech_noise_thres: float = 0.6,
        fe_prior_thres: float = 1e-4,
        silence_pdf_num: int = 1,
        sil_pdf_ids: List[int] = [0],
        speech_noise_thresh_low: float = -0.1,
        speech_noise_thresh_high: float = 0.3,
        output_frame_probs: bool = False,
        frame_in_ms: int = 10,
        frame_length_ms: int = 25,
        **kwargs,
    ):
        """Store the given options as attributes.

        Args:
            sil_pdf_ids: copied into a new list, so the default list object
                and caller-owned lists are never shared with the instance.
            **kwargs: accepted and ignored, so callers may pass a larger
                configuration mapping.

        All remaining arguments are stored unchanged.
        """
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        # Copy: the default ``[0]`` is one list object created at definition
        # time; storing it directly would let one instance's mutation leak
        # into every other instance that relied on the default.
        self.sil_pdf_ids = list(sil_pdf_ids)
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Record with start/end times, a data buffer, start/end-point flags and a
    ``doa`` value. ``Reset`` restores every field to its initial value.
    """

    def __init__(self):
        # Initial state and reset state are the same, so there is a single
        # definition of the field values.
        self.Reset()

    def Reset(self):
        """Restore all fields to their initial values (fresh empty buffer)."""
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0
class E2EVadFrameProb(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Record holding noise/speech probabilities, a score, a frame id and a
    frame state for one frame; every field starts at zero.
    """

    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
        self.score = 0.0
        self.frame_id = 0
        self.frm_state = 0
class WindowDetector(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Keeps a circular window of the most recent per-frame speech flags
    (1 = speech, 0 = silence) together with their running sum, and reports a
    state change when that sum crosses the configured frame-count thresholds.
    """

    def __init__(self, window_size_ms: int,
                 sil_to_speech_time: int,
                 speech_to_sil_time: int,
                 frame_size_ms: int):
        """Derive frame-count sizes from the millisecond settings.

        Args:
            window_size_ms: window length; divided by ``frame_size_ms`` and
                truncated to get the number of window slots.
            sil_to_speech_time: divided by ``frame_size_ms`` and truncated to
                get the silence-to-speech threshold in frames.
            speech_to_sil_time: divided by ``frame_size_ms`` and truncated to
                get the speech-to-silence threshold in frames.
            frame_size_ms: duration of one frame.
        """
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms

        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame  # initialise the window
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)

        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def Reset(self) -> None:
        """Clear the window and counters; sizes and thresholds are kept."""
        self.cur_win_pos = 0
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        """Return the window length in frames."""
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: Optional[dict] = None) -> AudioChangeState:
        """Push one frame into the window and classify the resulting transition.

        Args:
            frameState: classification of the incoming frame. Anything other
                than speech or silence returns ``kChangeStateInvalid`` without
                touching the window.
            frame_count: accepted for interface compatibility; not used.
            cache: accepted for interface compatibility; not used.

        Returns:
            ``kChangeStateSil2Speech`` when the previous state is silence and
            the window sum reaches the silence-to-speech threshold;
            ``kChangeStateSpeech2Sil`` when the previous state is speech and
            the window sum falls to or below the speech-to-silence threshold;
            otherwise ``kChangeStateSil2Sil`` or ``kChangeStateSpeech2Speech``
            according to the unchanged previous state.
        """
        if frameState == FrameState.kFrameStateSpeech:
            frame_flag = 1
        elif frameState == FrameState.kFrameStateSil:
            frame_flag = 0
        else:
            return AudioChangeState.kChangeStateInvalid

        # Replace the oldest slot and keep the running sum in step.
        self.win_sum -= self.win_state[self.cur_win_pos]
        self.win_sum += frame_flag
        self.win_state[self.cur_win_pos] = frame_flag
        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame

        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSpeech
            return AudioChangeState.kChangeStateSil2Speech

        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSil
            return AudioChangeState.kChangeStateSpeech2Sil

        if self.pre_frame_state == FrameState.kFrameStateSil:
            return AudioChangeState.kChangeStateSil2Sil
        if self.pre_frame_state == FrameState.kFrameStateSpeech:
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        """Return the configured frame duration."""
        return int(self.frame_size_ms)
class Stats(object):
    """Bookkeeping record created with its starting values.

    Only the three constructor arguments come from the caller; every other
    attribute starts from a fixed initial value. Presumably this is the object
    reached through ``cache["stats"]`` elsewhere in this file - confirm
    against the caller.
    """

    def __init__(self,
                 sil_pdf_ids,
                 max_end_sil_frame_cnt_thresh,
                 speech_noise_thres,
                 ):
        # Caller-supplied values, stored as given.
        self.sil_pdf_ids = sil_pdf_ids
        self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
        self.speech_noise_thres = speech_noise_thres

        # The state machine starts with no start point detected.
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected

        # Integer counters and offsets that start at zero.
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
        self.continous_silence_frame_count = 0
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.output_data_buf_offset = 0
        self.last_drop_frames = 0

        # Frame indices that use -1 as their "not set" value.
        self.lastest_confirmed_silence_frame = -1
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1

        # Flags and the initial noise level.
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.max_time_out = False

        # Per-instance containers: each object gets its own lists.
        self.output_data_buf = []
        self.frame_probs = []
        self.decibel = []

        # Filled in later by the owner of this object.
        self.scores = None
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
@dataclass
class StatsItem:
    """Bookkeeping record with three constructor arguments.

    Only the annotated names (``sil_pdf_ids``,
    ``max_end_sil_frame_cnt_thresh``, ``speech_noise_thres``) are dataclass
    fields and therefore ``__init__`` parameters. The un-annotated names are
    ordinary class attributes that serve as immutable starting values; the
    list-valued ones are created per instance in ``__post_init__`` so that two
    instances never share (and corrupt) the same list.
    """

    # init variables
    data_buf_start_frame = 0
    frm_cnt = 0
    latest_confirmed_speech_frame = 0
    lastest_confirmed_silence_frame = -1
    continous_silence_frame_count = 0
    vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
    confirmed_start_frame = -1
    confirmed_end_frame = -1
    number_end_time_detected = 0
    sil_frame = 0
    sil_pdf_ids: list
    noise_average_decibel = -100.0
    pre_end_silence_detected = False
    next_seg = True  # unused
    output_data_buf_offset = 0
    max_end_sil_frame_cnt_thresh: int
    speech_noise_thres: float
    scores = None
    max_time_out = False  # unused
    data_buf = None
    data_buf_all = None
    waveform = None
    last_drop_frames = 0

    def __post_init__(self):
        # A class-level ``[]`` is a single object shared by every instance;
        # appending through one instance would show up in all others.
        self.output_data_buf = []
        self.frame_probs = []  # unused
        self.decibel = []
@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self,
                 encoder: str = None,
                 encoder_conf: Optional[Dict] = None,
                 vad_post_args: Dict[str, Any] = None,
                 **kwargs,
                 ):
        super().__init__()
        self.vad_opts = VADXOptions(**kwargs)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder
    def ResetDetection(self, cache: dict = {}):
        cache["stats"].continous_silence_frame_count = 0
        cache["stats"].latest_confirmed_speech_frame = 0
        cache["stats"].lastest_confirmed_silence_frame = -1
        cache["stats"].confirmed_start_frame = -1
        cache["stats"].confirmed_end_frame = -1
        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        cache["windows_detector"].Reset()
        cache["stats"].sil_frame = 0
        cache["stats"].frame_probs = []
        if cache["stats"].output_data_buf:
            assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
            cache["stats"].last_drop_frames = drop_frames
            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
    def ComputeDecibel(self, cache: dict = {}) -> None:
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if cache["stats"].data_buf_all is None:
            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
            cache["stats"].data_buf = cache["stats"].data_buf_all
        else:
            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
        for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            cache["stats"].decibel.append(
                10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
        scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
        if cache["stats"].scores is None:
            cache["stats"].scores = scores  # the first calculation
        else:
            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
    def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
        while cache["stats"].data_buf_start_frame < frame_idx:
            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                cache["stats"].data_buf_start_frame += 1
                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
        self.PopDataBufTillFrame(start_frm, cache=cache)
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:
            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
        if len(cache["stats"].data_buf) < expected_sample_number:
            print('error in calling pop data_buf\n')
        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
            cache["stats"].output_data_buf[-1].Reset()
            cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
            cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
            cache["stats"].output_data_buf[-1].doa = 0
        cur_seg = cache["stats"].output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning\n')
        out_pos = len(cur_seg.buffer)  # cur_seg.buff现在没做任何操作
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if data_to_pop > len(cache["stats"].data_buf):
            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
            data_to_pop = len(cache["stats"].data_buf)
            expected_sample_number = len(cache["stats"].data_buf)
        cur_seg.doa = 0
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
            out_pos += 1
        for sample_cpy_out in range(data_to_pop, expected_sample_number):
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('Something wrong with the VAD algorithm\n')
        cache["stats"].data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
            cur_seg.contain_seg_start_point = True
        if last_frm_is_end_point:
            cur_seg.contain_seg_end_point = True
    def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
        cache["stats"].lastest_confirmed_silence_frame = valid_frame
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataBufTillFrame(valid_frame, cache=cache)
        # silence_detected_callback_
        # pass
    def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
        cache["stats"].latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if cache["stats"].confirmed_start_frame != -1:
            print('not reset vad properly\n')
        else:
            cache["stats"].confirmed_start_frame = start_frame
        if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t, cache=cache)
        if self.vad_opts.do_end_point_detection:
            pass
        if cache["stats"].confirmed_end_frame != -1:
            print('not reset vad properly\n')
        else:
            cache["stats"].confirmed_end_frame = end_frame
        if not fake_result:
            cache["stats"].sil_frame = 0
            self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
        cache["stats"].number_end_time_detected += 1
    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
        if is_final_frame:
            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
    def GetLatency(self, cache: dict = {}) -> int:
        return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
    def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
        vad_latency = cache["windows_detector"].GetWinSize()
        if self.vad_opts.do_extend:
            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return vad_latency
    def GetFrameState(self, t: int, cache: dict = {}):
        """Classify buffered frame ``t`` as speech or silence.

        Args:
            t: index into ``cache["stats"].decibel`` and the second axis of
                ``cache["stats"].scores``.
            cache: streaming cache holding ``"stats"``.

        Returns:
            ``FrameState.kFrameStateSpeech`` or ``FrameState.kFrameStateSil``.

        Side effects: updates the running ``noise_average_decibel`` on the
        score-based silence path, and appends to ``frame_probs`` when
        ``output_frame_probs`` is enabled.
        """
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = cache["stats"].decibel[t]
        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            # Too quiet: classified as silence without consulting the scores.
            # NOTE(review): DetectOneFrame is invoked here with ``t``, and the
            # callers visible in this file invoke it again with the returned
            # state -- confirm the double invocation is intended.
            frame_state = FrameState.kFrameStateSil
            self.DetectOneFrame(frame_state, t, False, cache=cache)
            return frame_state
        sum_score = 0.0
        noise_prob = 0.0
        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(cache["stats"].sil_pdf_ids) > 0:
            assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
            sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            # Speech score is the complement of the summed silence scores.
            total_score = 1.0
            sum_score = total_score - sum_score
        speech_prob = math.log(sum_score)
        if self.vad_opts.output_frame_probs:
            frame_prob = E2EVadFrameProb()
            frame_prob.noise_prob = noise_prob
            frame_prob.speech_prob = speech_prob
            frame_prob.score = sum_score
            frame_prob.frame_id = t
            cache["stats"].frame_probs.append(frame_prob)
        if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
            # Scores say speech; it must also pass the SNR and decibel thresholds.
            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
                frame_state = FrameState.kFrameStateSpeech
            else:
                frame_state = FrameState.kFrameStateSil
        else:
            frame_state = FrameState.kFrameStateSil
            # Update the noise-level estimate: the first value is taken as-is
            # (the initial level is below -99.9), later values are folded in
            # as a running average over noise_frame_num_used_for_snr frames.
            if cache["stats"].noise_average_decibel < -99.9:
                cache["stats"].noise_average_decibel = cur_decibel
            else:
                cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
                        self.vad_opts.noise_frame_num_used_for_snr
                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
        return frame_state
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
                is_final: bool = False
                ):
        # if len(cache) == 0:
        #     self.AllResetDetection()
        # self.waveform = waveform  # compute decibel for each frame
        cache["stats"].waveform = waveform
        self.ComputeDecibel(cache=cache)
        self.ComputeScores(feats, cache=cache)
        if not is_final:
            self.DetectCommonFrames(cache=cache)
        else:
            self.DetectLastFrames(cache=cache)
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(cache["stats"].output_data_buf) > 0:
                for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
                    if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
                        i].contain_seg_end_point):
                        continue
                    segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
                    segment_batch.append(segment)
                    cache["stats"].output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
        # if is_final:
        #     # reset class variables and clear the dict for the next query
        #     self.AllResetDetection()
        return segments
    def init_cache(self, cache: dict = {}, **kwargs):
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}
        windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                          self.vad_opts.sil_to_speech_time_thres,
                                          self.vad_opts.speech_to_sil_time_thres,
                                          self.vad_opts.frame_in_ms)
        stats = StatsItem(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
                          max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
                          speech_noise_thres=self.vad_opts.speech_noise_thres,
                      )
        cache["windows_detector"] = windows_detector
        cache["stats"] = stats
        return cache
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
                 tokenizer=None,
                 frontend=None,
                 cache: dict = {},
                 **kwargs,
                 ):
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)
        meta_data = {}
        chunk_size = kwargs.get("chunk_size", 60000) # 50ms
        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"),
                                                        tokenizer=tokenizer,
                                                        cache=cfg,
                                                        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set 1"
        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
        segments = []
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
            # extract fbank feats
            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend, cache=cache["frontend"],
                                                   is_final=kwargs["is_final"])
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            speech = speech.to(device=kwargs["device"])
            speech_lengths = speech_lengths.to(device=kwargs["device"])
            batch = {
                "feats": speech,
                "waveform": cache["frontend"]["waveforms"],
                "is_final": kwargs["is_final"],
                "cache": cache
            }
            segments_i = self.forward(**batch)
            if len(segments_i) > 0:
                segments.extend(*segments_i)
        cache["prev_samples"] = audio_sample[:-m]
        if _is_final:
            self.init_cache(cache, **kwargs)
        ibest_writer = None
        if ibest_writer is None and kwargs.get("output_dir") is not None:
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]
        results = []
        result_i = {"key": key[0], "value": segments}
        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
            result_i = json.dumps(result_i)
        results.append(result_i)
        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = segments
        return results, meta_data
    def DetectCommonFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
        return 0
    def DetectLastFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
            if i != 0:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
            else:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
        return 0
    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
        """Advance the VAD state machine by one frame.

        The frame state is passed through the window detector, which reports
        one of four transitions (sil->speech, speech->sil, speech->speech,
        sil->sil).  Depending on the transition and the current state machine
        value this starts a segment, extends it, or ends it.  In
        multi-utterance mode the detector is reset right after an end point
        so that the next segment can be found.

        Args:
            cur_frm_state: per-frame classification.
            cur_frm_idx: frame index handed to the window detector and used
                for segment boundaries.
            is_final_frame: true for the last frame of the stream.
            cache: streaming cache holding ``"stats"`` and ``"windows_detector"``.
        """
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            # The left-hand side is the constant 1.0, so speech is kept only
            # when fe_prior_thres is below 1.0.
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            silence_frame_count = cache["stats"].continous_silence_frame_count
            cache["stats"].continous_silence_frame_count = 0
            cache["stats"].pre_end_silence_detected = False
            start_frame = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # New segment: back-date the start by the detection latency,
                # but never before the start of the buffered data.
                start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
                self.OnVoiceStart(start_frame, cache=cache)
                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                for t in range(start_frame + 1, cur_frm_idx + 1):
                    self.OnVoiceDetected(t, cache=cache)
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # Already inside a segment: catch up on frames not yet added.
                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.OnVoiceDetected(t, cache=cache)
                # Force an end when the segment exceeds max_single_segment_time.
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
            cache["stats"].continous_silence_frame_count = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                pass
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            cache["stats"].continous_silence_frame_count = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    cache["stats"].max_time_out = True
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
            cache["stats"].continous_silence_frame_count += 1
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
                        cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
                        or (is_final_frame and cache["stats"].number_end_time_detected == 0):
                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t, cache=cache)
                    # Fake start/end: bookkeeping only, no output segment is created.
                    self.OnVoiceStart(0, True, cache=cache)
                    self.OnVoiceEnd(0, True, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # Enough trailing silence: end the segment, placing the end
                # point ``lookback_frame`` frames before the current one.
                if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    # Keep short silences inside the segment (look-ahead padding).
                    if cache["stats"].continous_silence_frame_count <= int(
                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        # In multi-utterance mode, re-arm the detector right after an end point.
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection(cache=cache)
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self,
                 encoder: str = None,
                 encoder_conf: Optional[Dict] = None,
                 vad_post_args: Dict[str, Any] = None,
                 **kwargs,
                 ):
        super().__init__()
        self.vad_opts = VADXOptions(**kwargs)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder
    def ResetDetection(self, cache: dict = {}):
        cache["stats"].continous_silence_frame_count = 0
        cache["stats"].latest_confirmed_speech_frame = 0
        cache["stats"].lastest_confirmed_silence_frame = -1
        cache["stats"].confirmed_start_frame = -1
        cache["stats"].confirmed_end_frame = -1
        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        cache["windows_detector"].Reset()
        cache["stats"].sil_frame = 0
        cache["stats"].frame_probs = []
        if cache["stats"].output_data_buf:
            assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
            cache["stats"].last_drop_frames = drop_frames
            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
    def ComputeDecibel(self, cache: dict = {}) -> None:
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if cache["stats"].data_buf_all is None:
            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
            cache["stats"].data_buf = cache["stats"].data_buf_all
        else:
            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
        for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            cache["stats"].decibel.append(
                10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
        scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
        if cache["stats"].scores is None:
            cache["stats"].scores = scores  # the first calculation
        else:
            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
    def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
        while cache["stats"].data_buf_start_frame < frame_idx:
            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                cache["stats"].data_buf_start_frame += 1
                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
        self.PopDataBufTillFrame(start_frm, cache=cache)
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:
            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
        if len(cache["stats"].data_buf) < expected_sample_number:
            print('error in calling pop data_buf\n')
        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
            cache["stats"].output_data_buf[-1].Reset()
            cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
            cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
            cache["stats"].output_data_buf[-1].doa = 0
        cur_seg = cache["stats"].output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning\n')
        out_pos = len(cur_seg.buffer)  # cur_seg.buff现在没做任何操作
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if data_to_pop > len(cache["stats"].data_buf):
            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
            data_to_pop = len(cache["stats"].data_buf)
            expected_sample_number = len(cache["stats"].data_buf)
        cur_seg.doa = 0
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
            out_pos += 1
        for sample_cpy_out in range(data_to_pop, expected_sample_number):
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('Something wrong with the VAD algorithm\n')
        cache["stats"].data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
            cur_seg.contain_seg_start_point = True
        if last_frm_is_end_point:
            cur_seg.contain_seg_end_point = True
    def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
        cache["stats"].lastest_confirmed_silence_frame = valid_frame
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataBufTillFrame(valid_frame, cache=cache)
        # silence_detected_callback_
        # pass
    def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
        cache["stats"].latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if cache["stats"].confirmed_start_frame != -1:
            print('not reset vad properly\n')
        else:
            cache["stats"].confirmed_start_frame = start_frame
        if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t, cache=cache)
        if self.vad_opts.do_end_point_detection:
            pass
        if cache["stats"].confirmed_end_frame != -1:
            print('not reset vad properly\n')
        else:
            cache["stats"].confirmed_end_frame = end_frame
        if not fake_result:
            cache["stats"].sil_frame = 0
            self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
        cache["stats"].number_end_time_detected += 1
    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
        if is_final_frame:
            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
    def GetLatency(self, cache: dict = {}) -> int:
        return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
    def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
        vad_latency = cache["windows_detector"].GetWinSize()
        if self.vad_opts.do_extend:
            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return vad_latency
    def GetFrameState(self, t: int, cache: dict = {}):
        """Classify frame ``t`` as speech or silence.

        The decision combines the frame's decibel level, its SNR against the
        running noise estimate, and the silence scores stored in
        ``cache["stats"].scores``.

        Args:
            t: Index used to look up ``cache["stats"].decibel`` and
                ``cache["stats"].scores[0]``.
            cache: Streaming state; ``cache["stats"]`` is read and, for frames
                judged as noise, its ``noise_average_decibel`` is updated.

        Returns:
            ``FrameState.kFrameStateSpeech`` or ``FrameState.kFrameStateSil``.
        """
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = cache["stats"].decibel[t]
        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            # Too quiet: treated as silence without consulting the model scores.
            frame_state = FrameState.kFrameStateSil
            # NOTE(review): DetectCommonFrames / DetectLastFrames also call
            # DetectOneFrame with the state returned here, so a below-threshold
            # frame appears to be fed to the detector twice - confirm this is intended.
            self.DetectOneFrame(frame_state, t, False, cache=cache)
            return frame_state
        sum_score = 0.0
        noise_prob = 0.0
        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(cache["stats"].sil_pdf_ids) > 0:
            assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
            sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            # From here on sum_score is the complement of the summed silence scores.
            sum_score = total_score - sum_score
        # NOTE(review): when sil_pdf_ids is empty sum_score is still 0.0 here, and
        # math.log(0.0) raises ValueError - confirm that configuration is never used.
        speech_prob = math.log(sum_score)
        if self.vad_opts.output_frame_probs:
            # Keep the per-frame probabilities for later inspection.
            frame_prob = E2EVadFrameProb()
            frame_prob.noise_prob = noise_prob
            frame_prob.speech_prob = speech_prob
            frame_prob.score = sum_score
            frame_prob.frame_id = t
            cache["stats"].frame_probs.append(frame_prob)
        if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
            # The model favours speech; still require enough SNR and loudness.
            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
                frame_state = FrameState.kFrameStateSpeech
            else:
                frame_state = FrameState.kFrameStateSil
        else:
            frame_state = FrameState.kFrameStateSil
            # Update the noise-level estimate with this frame.
            # -99.9 is presumably the "not yet initialised" sentinel - TODO confirm.
            if cache["stats"].noise_average_decibel < -99.9:
                cache["stats"].noise_average_decibel = cur_decibel
            else:
                # Running average weighted over noise_frame_num_used_for_snr frames.
                cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
                    self.vad_opts.noise_frame_num_used_for_snr
                    - 1)) / self.vad_opts.noise_frame_num_used_for_snr
        return frame_state
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
                is_final: bool = False
                ):
        # if len(cache) == 0:
        #     self.AllResetDetection()
        # self.waveform = waveform  # compute decibel for each frame
        cache["stats"].waveform = waveform
        self.ComputeDecibel(cache=cache)
        self.ComputeScores(feats, cache=cache)
        if not is_final:
            self.DetectCommonFrames(cache=cache)
        else:
            self.DetectLastFrames(cache=cache)
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(cache["stats"].output_data_buf) > 0:
                for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
                    if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
                        i].contain_seg_end_point):
                        continue
                    segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
                    segment_batch.append(segment)
                    cache["stats"].output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
        # if is_final:
        #     # reset class variables and clear the dict for the next query
        #     self.AllResetDetection()
        return segments
    def init_cache(self, cache: dict = {}, **kwargs):
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}
        windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                          self.vad_opts.sil_to_speech_time_thres,
                                          self.vad_opts.speech_to_sil_time_thres,
                                          self.vad_opts.frame_in_ms)
        windows_detector.Reset()
        stats = Stats(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
                      max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
                      speech_noise_thres=self.vad_opts.speech_noise_thres
                      )
        cache["windows_detector"] = windows_detector
        cache["stats"] = stats
        return cache
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  cache: dict = {},
                  **kwargs,
                  ):
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)
        meta_data = {}
        chunk_size = kwargs.get("chunk_size", 60000) # 50ms
        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"),
                                                        tokenizer=tokenizer,
                                                        cache=cfg,
                                                        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set 1"
        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
        segments = []
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
            # extract fbank feats
            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend, cache=cache["frontend"],
                                                   is_final=kwargs["is_final"])
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            speech = speech.to(device=kwargs["device"])
            speech_lengths = speech_lengths.to(device=kwargs["device"])
            batch = {
                "feats": speech,
                "waveform": cache["frontend"]["waveforms"],
                "is_final": kwargs["is_final"],
                "cache": cache
            }
            segments_i = self.forward(**batch)
            if len(segments_i) > 0:
                segments.extend(*segments_i)
        cache["prev_samples"] = audio_sample[:-m]
        if _is_final:
            cache = {}
        ibest_writer = None
        if ibest_writer is None and kwargs.get("output_dir") is not None:
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]
        results = []
        result_i = {"key": key[0], "value": segments}
        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
            result_i = json.dumps(result_i)
        results.append(result_i)
        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = segments
        return results, meta_data
    def DetectCommonFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
        return 0
    def DetectLastFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
            if i != 0:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
            else:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
        return 0
    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
        """Advance the VAD state machine by one frame.

        The frame state is passed to the window detector, which reports one of
        four transitions (sil->speech, speech->sil, speech->speech, sil->sil).
        Depending on the transition and on the current
        ``cache["stats"].vad_state_machine`` value, the start/end/voice/silence
        callbacks are invoked and the state machine is updated.

        Args:
            cur_frm_state: Per-frame speech/silence decision.
            cur_frm_idx: Index of the frame being processed.
            is_final_frame: True when this is the last frame of the input.
            cache: Streaming state; ``cache["stats"]`` and
                ``cache["windows_detector"]`` are read and mutated.
        """
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            # NOTE(review): the compared value is the constant 1.0, so this only
            # depends on fe_prior_thres - presumably a placeholder for a real
            # prior; confirm.
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            # Silence -> speech: a segment starts, or speech resumes inside one.
            silence_frame_count = cache["stats"].continous_silence_frame_count
            cache["stats"].continous_silence_frame_count = 0
            cache["stats"].pre_end_silence_detected = False
            start_frame = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # Back-date the start by the detection latency, but never before
                # the first frame still held in the data buffer.
                start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
                self.OnVoiceStart(start_frame, cache=cache)
                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                # Report every frame from just after the start up to the current one.
                for t in range(start_frame + 1, cur_frm_idx + 1):
                    self.OnVoiceDetected(t, cache=cache)
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # Catch up on frames skipped since the last confirmed speech frame.
                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.OnVoiceDetected(t, cache=cache)
                # Force an end point once the segment exceeds the maximum length.
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                    self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
            # Speech -> silence: the frame still counts as part of the segment.
            cache["stats"].continous_silence_frame_count = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                pass
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                    self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            # Ongoing speech.
            cache["stats"].continous_silence_frame_count = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                    self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    # Only this branch records that the segment was cut by the length limit.
                    cache["stats"].max_time_out = True
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
            # Ongoing silence.
            cache["stats"].continous_silence_frame_count += 1
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
                    cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
                    or (is_final_frame and cache["stats"].number_end_time_detected == 0):
                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t, cache=cache)
                    # Fake start/end at frame 0: nothing is pushed to the output buffer.
                    self.OnVoiceStart(0, True, cache=cache)
                    self.OnVoiceEnd(0, True, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    # Confirm silence only for frames older than the start-point latency.
                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
                    # Enough trailing silence: end the segment, placing the end
                    # point lookback_frame frames before the current frame.
                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                    self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    # Keep the first few silent frames (the look-ahead) in the segment.
                    if cache["stats"].continous_silence_frame_count <= int(
                        self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        # In multiple-utterance mode, reset after an end point so the next
        # utterance can be detected.
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
            self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection(cache=cache)
funasr/models/paraformer/model.py
@@ -33,7 +33,6 @@
    
    def __init__(
        self,
        # token_list: Union[Tuple[str, ...], List[str]],
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
funasr/models/paraformer/template.yaml
@@ -6,7 +6,6 @@
# tables.print()
# network architecture
#model: funasr.models.paraformer.model:Paraformer
model: Paraformer
model_conf:
    ctc_weight: 0.0
@@ -87,13 +86,6 @@
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  val_scheduler_criterion:
      - valid
      - acc
  best_model_criterion:
  -   - valid
      - acc
      - max
  keep_nbest_models: 10
  avg_nbest_model: 5
  log_interval: 50
funasr/models/sanm/decoder.py
@@ -1,3 +1,8 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from typing import List
from typing import Tuple
import logging
@@ -193,10 +198,9 @@
@tables.register("decoder_classes", "FsmnDecoder")
class FsmnDecoder(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """
    
    def __init__(
funasr/models/sanm/encoder.py
@@ -1,3 +1,8 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from typing import List
from typing import Optional
from typing import Sequence
@@ -156,10 +161,9 @@
@tables.register("encoder_classes", "SANMEncoder")
class SANMEncoder(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
funasr/models/sanm/model.py
@@ -1,3 +1,8 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import logging
import torch
@@ -7,7 +12,11 @@
@tables.register("model_classes", "SANM")
class SANM(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""
    """
    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
        self,
funasr/models/sanm/template.yaml
New file
@@ -0,0 +1,121 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: SANM
model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: true
# encoder
encoder: SANMEncoder
encoder_conf:
    output_size: 512
    attention_heads: 4
    linear_units: 2048
    num_blocks: 50
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: pe
    pos_enc_class: SinusoidalPositionEncoder
    normalize_before: true
    kernel_size: 11
    sanm_shfit: 0
    selfattention_layer_type: sanm
# decoder
decoder: FsmnDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 16
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    att_layer_num: 16
    kernel_size: 11
    sanm_shfit: 0
# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 7
    lfr_n: 6
specaug: SpecAugLFR
specaug_conf:
    apply_time_warp: false
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    lfr_rate: 6
    num_freq_mask: 1
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 12
    num_time_mask: 1
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  val_scheduler_criterion:
      - valid
      - acc
  best_model_criterion:
  -   - valid
      - acc
      - max
  keep_nbest_models: 10
  avg_nbest_model: 5
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: DynamicBatchLocalShuffleSampler
    batch_type: example # example or length
    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if batch_type is length, batch_size is source_token_len + target_token_len
    max_token_length: 2048 # filter out samples where source_token_len + target_token_len > max_token_length
    buffer_size: 500
    shuffle: True
    num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  split_with_space: true
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
funasr/models/scama/decoder.py
File was renamed from funasr/models/scama/sanm_decoder.py
@@ -1,3 +1,8 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from typing import List
from typing import Tuple
import logging
@@ -192,11 +197,11 @@
@tables.register("decoder_classes", "FsmnDecoderSCAMAOpt")
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    https://arxiv.org/abs/2006.01712
    """
    def __init__(
            self,
            vocab_size: int,
funasr/models/scama/encoder.py
File was renamed from funasr/models/scama/sanm_encoder.py
@@ -1,3 +1,8 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from typing import List
from typing import Optional
from typing import Sequence
@@ -157,10 +162,9 @@
@tables.register("encoder_classes", "SANMEncoderChunkOpt")
class SANMEncoderChunkOpt(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    https://arxiv.org/abs/2006.01712
    """
    def __init__(
funasr/models/scama/model.py
New file
@@ -0,0 +1,669 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import torch
import torch.nn as nn
import torch.functional as F
import logging
from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.scama.utils import sequence_mask
# Use torch's autocast when torch >= 1.6.0; otherwise fall back to a no-op
# context manager so that `with autocast(...)` still works.
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    @contextmanager
    def autocast(enabled=True):
        """No-op replacement for autocast; ``enabled`` is accepted and ignored."""
        yield
else:
    from torch.cuda.amp import autocast
@tables.register("model_classes", "SCAMA")
class SCAMA(nn.Module):
    """
    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01712
    """
    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        predictor: str = None,
        predictor_conf: dict = None,
        predictor_bias: int = 0,
        predictor_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        share_embedding: bool = False,
        **kwargs,
    ):
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )
        predictor_class = tables.predictor_classes.get(predictor)
        predictor = predictor_class(**predictor_conf)
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
        self.error_calculator = None
        if self.encoder.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # Encoder
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()
        # decoder: CTC branch
        if self.ctc_weight > 0.0:
            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
                                                                                                encoder_out_lens,
                                                                                                chunk_outs=None)
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        return encoder_out, encoder_out_lens
    def encode_chunk(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        return encoder_out, torch.tensor([encoder_out.size(1)])
    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
        is_final = kwargs.get("is_final", False)
        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
                                         device=encoder_out.device)[:, None, :]
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
                                                                                           device=encoder_out.device,
                                                                                           batch_size=encoder_out.size(
                                                                                               0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
                                                                                   batch_size=encoder_out.size(0))
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
                                                                              ys_out_pad,
                                                                              encoder_out_mask,
                                                                              ignore_id=self.ignore_id,
                                                                              mask_chunk_predictor=mask_chunk_predictor,
                                                                              target_label_length=ys_in_lens,
                                                                              )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
                                                                                             encoder_out_lens)
        encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
        attention_chunk_center_bias = 0
        attention_chunk_size = encoder_chunk_size
        decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
        mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
                                                                                                       device=encoder_out.device,
                                                                                                       batch_size=encoder_out.size(
                                                                                                           0))
        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
            predictor_alignments=predictor_alignments,
            encoder_sequence_length=encoder_out_lens,
            chunk_size=1,
            encoder_chunk_size=encoder_chunk_size,
            attention_chunk_center_bias=attention_chunk_center_bias,
            attention_chunk_size=attention_chunk_size,
            attention_chunk_type=self.decoder_attention_chunk_type,
            step=None,
            predictor_mask_chunk_hopping=mask_chunk_predictor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
            mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
            target_length=ys_in_lens,
            is_training=self.training,
        )
        # try:
        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # predictor loss
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre
    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        ys_out_pad, ys_in_lens = None, None
        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
                                         device=encoder_out.device)[:, None, :]
        mask_chunk_predictor = None
        mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
                                                                                       device=encoder_out.device,
                                                                                       batch_size=encoder_out.size(
                                                                                           0))
        mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
                                                                               batch_size=encoder_out.size(0))
        encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
                                                                              ys_out_pad,
                                                                              encoder_out_mask,
                                                                              ignore_id=self.ignore_id,
                                                                              mask_chunk_predictor=mask_chunk_predictor,
                                                                              target_label_length=ys_in_lens,
                                                                              )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
                                                                                             encoder_out_lens)
        encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
        attention_chunk_center_bias = 0
        attention_chunk_size = encoder_chunk_size
        decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
        mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
                                                                                                       device=encoder_out.device,
                                                                                                       batch_size=encoder_out.size(
                                                                                                           0))
        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
            predictor_alignments=predictor_alignments,
            encoder_sequence_length=encoder_out_lens,
            chunk_size=1,
            encoder_chunk_size=encoder_chunk_size,
            attention_chunk_center_bias=attention_chunk_center_bias,
            attention_chunk_size=attention_chunk_size,
            attention_chunk_type=self.decoder_attention_chunk_type,
            step=None,
            predictor_mask_chunk_hopping=mask_chunk_predictor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
            mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
            target_length=ys_in_lens,
            is_training=self.training,
        )
        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.scama.beam_search import BeamSearchScama
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # 1. Build ASR model
        scorers = {}
        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(
                ctc=ctc
            )
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearchScama(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
        #     if isinstance(scorer, torch.nn.Module):
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search
    def generate_chunk(self,
                       speech,
                       speech_lengths=None,
                       key: list = None,
                       tokenizer=None,
                       frontend=None,
                       **kwargs,
                       ):
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache,
                                                          is_final=kwargs.get("is_final", False))
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # predictor
        predictor_outs = self.calc_predictor_chunk(encoder_out,
                                                   encoder_out_lens,
                                                   cache=cache,
                                                   is_final=kwargs.get("is_final", False),
                                                   )
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                        predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
                                                             encoder_out_lens,
                                                             pre_acoustic_embeds,
                                                             pre_token_length,
                                                             cache=cache
                                                             )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0)
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                # text = tokenizer.tokens2text(token)
                result_i = token
                results.extend(result_i)
        return results
    def init_cache(self, cache: dict = {}, **kwargs):
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1
        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
                         "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
                         "tail_chunk": False}
        cache["encoder"] = cache_encoder
        cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
                         "chunk_size": chunk_size}
        cache["decoder"] = cache_decoder
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        return cache
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  cache: dict = {},
                  **kwargs,
                  ):
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)
        meta_data = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        chunk_stride_samples = int(chunk_size[1] * 960)  # 600ms
        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"),
                                                        tokenizer=tokenizer,
                                                        cache=cfg,
                                                        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set 1"
        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
        tokens = []
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
            # extract fbank feats
            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend, cache=cache["frontend"],
                                                   is_final=kwargs["is_final"])
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache,
                                           frontend=frontend, **kwargs)
            tokens.extend(tokens_i)
        text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
        result_i = {"key": key[0], "text": text_postprocessed}
        result = [result_i]
        cache["prev_samples"] = audio_sample[:-m]
        if _is_final:
            self.init_cache(cache, **kwargs)
        if kwargs.get("output_dir"):
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]
            ibest_writer["token"][key[0]] = " ".join(tokens)
            ibest_writer["text"][key[0]] = text_postprocessed
        return result, meta_data
funasr/models/scama/template.yaml
New file
@@ -0,0 +1,127 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: SCAMA
model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: true
# encoder
encoder: SANMEncoderChunkOpt
encoder_conf:
    output_size: 512
    attention_heads: 4
    linear_units: 2048
    num_blocks: 50
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: pe
    pos_enc_class: SinusoidalPositionEncoder
    normalize_before: true
    kernel_size: 11
    sanm_shfit: 0
    selfattention_layer_type: sanm
# decoder
decoder: FsmnDecoderSCAMAOpt
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 16
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    att_layer_num: 16
    kernel_size: 11
    sanm_shfit: 0
predictor: CifPredictorV2
predictor_conf:
    idim: 512
    threshold: 1.0
    l_order: 1
    r_order: 1
    tail_threshold: 0.45
# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 7
    lfr_n: 6
specaug: SpecAugLFR
specaug_conf:
    apply_time_warp: false
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    lfr_rate: 6
    num_freq_mask: 1
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 12
    num_time_mask: 1
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  val_scheduler_criterion:
      - valid
      - acc
  best_model_criterion:
  -   - valid
      - acc
      - max
  keep_nbest_models: 10
  avg_nbest_model: 5
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: DynamicBatchLocalShuffleSampler
    batch_type: example # example or length
    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
    max_token_length: 2048 # filter out samples whose source_token_len+target_token_len > max_token_length
    buffer_size: 500
    shuffle: True
    num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  split_with_space: true
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
funasr/models/uniasr/model.py
File was renamed from funasr/models/uniasr/e2e_uni_asr.py
@@ -1,85 +1,73 @@
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import torch
import logging
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.models.scama.chunk_utilis import sequence_mask
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
class UniASR(FunASRModel):
@tables.register("model_classes", "UniASR")
class UniASR(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        predictor: str = None,
        predictor_conf: Optional[Dict] = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        # report_cer: bool = True,
        # report_wer: bool = True,
        # sym_space: str = "<space>",
        # sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        # predictor=None,
        predictor_weight: float = 0.0,
        decoder_attention_chunk_type: str = 'chunk',
        encoder2: AbsEncoder = None,
        decoder2: AbsDecoder = None,
        ctc2: CTC = None,
        ctc_weight2: float = 0.5,
        interctc_weight2: float = 0.0,
        predictor2=None,
        predictor_weight2: float = 0.0,
        decoder_attention_chunk_type2: str = 'chunk',
        stride_conv=None,
        loss_weight_model1: float = 0.5,
        enable_maas_finetune: bool = False,
        freeze_encoder2: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        use_1st_decoder_loss: bool = False,
        encoder1_encoder2_joint_training: bool = True,
        **kwargs,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -443,10 +431,8 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
<<<<<<< HEAD:funasr/models/uniasr/e2e_uni_asr.py
=======
>>>>>>> main:funasr/models/e2e_uni_asr.py
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
funasr/models/uniasr/template.yaml
New file
@@ -0,0 +1,178 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# To print the register_table, run:
#   from funasr.register import tables
#   tables.print()
# network architecture
model: UniASR
model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: true
    predictor_weight: 1.0
    decoder_attention_chunk_type: chunk
    ctc_weight2: 0.0
    predictor_weight2: 1.0
    decoder_attention_chunk_type2: chunk
    loss_weight_model1: 0.5
# encoder
encoder: SANMEncoderChunkOpt
encoder_conf:
    output_size: 320
    attention_heads: 4
    linear_units: 1280
    num_blocks: 35
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: pe
    pos_enc_class: SinusoidalPositionEncoder
    normalize_before: true
    kernel_size: 11
    sanm_shfit: 0
    selfattention_layer_type: sanm
    chunk_size: [20, 60]
    stride: [10, 40]
    pad_left: [5, 10]
    encoder_att_look_back_factor: [0, 0]
    decoder_att_look_back_factor: [0, 0]
# decoder
decoder: FsmnDecoderSCAMAOpt
decoder_conf:
    attention_dim: 256
    attention_heads: 4
    linear_units: 1024
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    att_layer_num: 6
    kernel_size: 11
    concat_embeds: true
predictor: CifPredictorV2
predictor_conf:
    idim: 320
    threshold: 1.0
    l_order: 1
    r_order: 1
encoder2: SANMEncoderChunkOpt
encoder2_conf:
    output_size: 320
    attention_heads: 4
    linear_units: 1280
    num_blocks: 20
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: pe
    pos_enc_class: SinusoidalPositionEncoder
    normalize_before: true
    kernel_size: 21
    sanm_shfit: 0
    selfattention_layer_type: sanm
    chunk_size: [45, 70]
    stride: [35, 50]
    pad_left: [5, 10]
    encoder_att_look_back_factor: [0, 0]
    decoder_att_look_back_factor: [0, 0]
decoder2: FsmnDecoderSCAMAOpt
decoder2_conf:
    attention_dim: 320
    attention_heads: 4
    linear_units: 1280
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    att_layer_num: 6
    kernel_size: 11
    concat_embeds: true
predictor2: CifPredictorV2
predictor2_conf:
    idim: 320
    threshold: 1.0
    l_order: 1
    r_order: 1
stride_conv: stride_conv1d
stride_conv_conf:
    kernel_size: 2
    stride: 2
    pad: [0, 1]
# frontend related
frontend: WavFrontendOnline
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 7
    lfr_n: 6
specaug: SpecAugLFR
specaug_conf:
    apply_time_warp: false
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    lfr_rate: 6
    num_freq_mask: 1
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 12
    num_time_mask: 1
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  keep_nbest_models: 10
  avg_nbest_model: 5
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: DynamicBatchLocalShuffleSampler
    batch_type: example # example or length
    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if batch_type is length, batch_size is source_token_len+target_token_len
    max_token_length: 2048 # samples with source_token_len+target_token_len > max_token_length are filtered
    buffer_size: 500
    shuffle: True
    num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  split_with_space: true
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null