From b28f3c9da94ae72a3a0b7bb5982b587be7cf4cd6 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 18 Jan 2024 22:00:58 +0800
Subject: [PATCH] fsmn-vad bugfix (#1270)

---
 funasr/models/paraformer/model.py                                       |    1 
 funasr/models/fsmn_vad_streaming/model.py                               | 1374 ++++++++++++++--------------
 funasr/models/sanm/model.py                                             |   11 
 funasr/models/scama/decoder.py                                          |   11 
 examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh |   12 
 README_zh.md                                                            |   11 
 funasr/models/uniasr/model.py                                           |  118 +-
 README.md                                                               |   11 
 funasr/models/scama/template.yaml                                       |  127 ++
 funasr/models/paraformer/template.yaml                                  |    8 
 examples/industrial_data_pretraining/scama/demo.py                      |   42 
 funasr/models/scama/model.py                                            |  669 ++++++++++++++
 examples/industrial_data_pretraining/scama/infer.sh                     |   11 
 funasr/models/sanm/decoder.py                                           |   10 
 funasr/models/sanm/template.yaml                                        |  121 ++
 funasr/models/sanm/encoder.py                                           |    8 
 funasr/models/uniasr/template.yaml                                      |  178 +++
 funasr/models/scama/encoder.py                                          |   10 
 18 files changed, 1951 insertions(+), 782 deletions(-)

diff --git a/README.md b/README.md
index 0094dc4..c9b9e89 100644
--- a/README.md
+++ b/README.md
@@ -91,12 +91,13 @@
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
-                  vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
-                  punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
-                  spk_model="cam++", spk_model_revision="v2.0.2")
+model = AutoModel(model="paraformer-zh", model_revision="v2.0.2",
+                  vad_model="fsmn-vad", vad_model_revision="v2.0.2",
+                  punc_model="ct-punc-c", punc_model_revision="v2.0.2",
+                  # spk_model="cam++", spk_model_revision="v2.0.2",
+                  )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
-                     batch_size=64, 
+                     batch_size_s=300, 
                      hotword='魔搭')
 print(res)
 ```
diff --git a/README_zh.md b/README_zh.md
index 57a6bbb..9cd1897 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -87,12 +87,13 @@
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
-                  vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
-                  punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
-                  spk_model="cam++", spk_model_revision="v2.0.2")
+model = AutoModel(model="paraformer-zh", model_revision="v2.0.2",
+                  vad_model="fsmn-vad", vad_model_revision="v2.0.2",
+                  punc_model="ct-punc-c", punc_model_revision="v2.0.2",
+                  # spk_model="cam++", spk_model_revision="v2.0.2",
+                  )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
-            batch_size=64, 
+            batch_size_s=300, 
             hotword='魔搭')
 print(res)
 ```
diff --git a/examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh b/examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh
new file mode 100644
index 0000000..df1e54a
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh
@@ -0,0 +1,12 @@
+
+
+python funasr/bin/inference.py \
+--config-path="/Users/zhifu/funasr_github/test_local/funasr_cli_egs" \
+--config-name="config.yaml" \
+++init_param="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/model.pt" \
++tokenizer_conf.token_list="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/tokens.txt" \
++frontend_conf.cmvn_file="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/am.mvn" \
++input="data/wav.scp" \
++output_dir="./outputs/debug" \
++device="cuda" \
+
diff --git a/examples/industrial_data_pretraining/scama/demo.py b/examples/industrial_data_pretraining/scama/demo.py
new file mode 100644
index 0000000..c805993
--- /dev/null
+++ b/examples/industrial_data_pretraining/scama/demo.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+chunk_size = [5, 10, 5]  # [0, 10, 5] is 600 ms, [0, 8, 4] is 480 ms
+encoder_chunk_look_back = 0  # number of chunks to look back for encoder self-attention
+decoder_chunk_look_back = 0  # number of encoder chunks to look back for decoder cross-attention
+
+model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_SCAMA_asr-zh-cn-16k-common-vocab8358-streaming", model_revision="v2.0.2")
+cache = {}
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+            chunk_size=chunk_size,
+            encoder_chunk_look_back=encoder_chunk_look_back,
+            decoder_chunk_look_back=decoder_chunk_look_back,
+            )
+print(res)
+
+
+import soundfile
+import os
+
+wav_file = os.path.join(model.model_path, "example/asr_example.wav")
+speech, sample_rate = soundfile.read(wav_file)
+
+chunk_stride = chunk_size[1] * 960  # 600 ms / 480 ms
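+# assuming 16 kHz input, one chunk unit is 960 samples (60 ms), so 10 units give a 600 ms stride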
+
+cache = {}
+total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
+for i in range(total_chunk_num):
+    speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+    is_final = i == total_chunk_num - 1
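+    # the last chunk is flagged is_final so the model can flush its streaming cache and finalize the output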
+    res = model.generate(input=speech_chunk,
+                         cache=cache,
+                         is_final=is_final,
+                         chunk_size=chunk_size,
+                         encoder_chunk_look_back=encoder_chunk_look_back,
+                         decoder_chunk_look_back=decoder_chunk_look_back,
+                         )
+    print(res)
diff --git a/examples/industrial_data_pretraining/scama/infer.sh b/examples/industrial_data_pretraining/scama/infer.sh
new file mode 100644
index 0000000..225f2a9
--- /dev/null
+++ b/examples/industrial_data_pretraining/scama/infer.sh
@@ -0,0 +1,11 @@
+
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+model_revision="v2.0.2"
+
+python funasr/bin/inference.py \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
++output_dir="./outputs/debug" \
++device="cpu" \
+
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 193feb0..943cb47 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -19,714 +19,718 @@
 
 
 class VadStateMachine(Enum):
-    kVadInStateStartPointNotDetected = 1
-    kVadInStateInSpeechSegment = 2
-    kVadInStateEndPointDetected = 3
+	kVadInStateStartPointNotDetected = 1
+	kVadInStateInSpeechSegment = 2
+	kVadInStateEndPointDetected = 3
 
 class FrameState(Enum):
-    kFrameStateInvalid = -1
-    kFrameStateSpeech = 1
-    kFrameStateSil = 0
+	kFrameStateInvalid = -1
+	kFrameStateSpeech = 1
+	kFrameStateSil = 0
 
 # final voiced/unvoiced state per frame
 class AudioChangeState(Enum):
-    kChangeStateSpeech2Speech = 0
-    kChangeStateSpeech2Sil = 1
-    kChangeStateSil2Sil = 2
-    kChangeStateSil2Speech = 3
-    kChangeStateNoBegin = 4
-    kChangeStateInvalid = 5
+	kChangeStateSpeech2Speech = 0
+	kChangeStateSpeech2Sil = 1
+	kChangeStateSil2Sil = 2
+	kChangeStateSil2Speech = 3
+	kChangeStateNoBegin = 4
+	kChangeStateInvalid = 5
 
 class VadDetectMode(Enum):
-    kVadSingleUtteranceDetectMode = 0
-    kVadMutipleUtteranceDetectMode = 1
+	kVadSingleUtteranceDetectMode = 0
+	kVadMutipleUtteranceDetectMode = 1
 
 class VADXOptions:
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(
-            self,
-            sample_rate: int = 16000,
-            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
-            snr_mode: int = 0,
-            max_end_silence_time: int = 800,
-            max_start_silence_time: int = 3000,
-            do_start_point_detection: bool = True,
-            do_end_point_detection: bool = True,
-            window_size_ms: int = 200,
-            sil_to_speech_time_thres: int = 150,
-            speech_to_sil_time_thres: int = 150,
-            speech_2_noise_ratio: float = 1.0,
-            do_extend: int = 1,
-            lookback_time_start_point: int = 200,
-            lookahead_time_end_point: int = 100,
-            max_single_segment_time: int = 60000,
-            nn_eval_block_size: int = 8,
-            dcd_block_size: int = 4,
-            snr_thres: int = -100.0,
-            noise_frame_num_used_for_snr: int = 100,
-            decibel_thres: int = -100.0,
-            speech_noise_thres: float = 0.6,
-            fe_prior_thres: float = 1e-4,
-            silence_pdf_num: int = 1,
-            sil_pdf_ids: List[int] = [0],
-            speech_noise_thresh_low: float = -0.1,
-            speech_noise_thresh_high: float = 0.3,
-            output_frame_probs: bool = False,
-            frame_in_ms: int = 10,
-            frame_length_ms: int = 25,
-            **kwargs,
-    ):
-        self.sample_rate = sample_rate
-        self.detect_mode = detect_mode
-        self.snr_mode = snr_mode
-        self.max_end_silence_time = max_end_silence_time
-        self.max_start_silence_time = max_start_silence_time
-        self.do_start_point_detection = do_start_point_detection
-        self.do_end_point_detection = do_end_point_detection
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time_thres = sil_to_speech_time_thres
-        self.speech_to_sil_time_thres = speech_to_sil_time_thres
-        self.speech_2_noise_ratio = speech_2_noise_ratio
-        self.do_extend = do_extend
-        self.lookback_time_start_point = lookback_time_start_point
-        self.lookahead_time_end_point = lookahead_time_end_point
-        self.max_single_segment_time = max_single_segment_time
-        self.nn_eval_block_size = nn_eval_block_size
-        self.dcd_block_size = dcd_block_size
-        self.snr_thres = snr_thres
-        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
-        self.decibel_thres = decibel_thres
-        self.speech_noise_thres = speech_noise_thres
-        self.fe_prior_thres = fe_prior_thres
-        self.silence_pdf_num = silence_pdf_num
-        self.sil_pdf_ids = sil_pdf_ids
-        self.speech_noise_thresh_low = speech_noise_thresh_low
-        self.speech_noise_thresh_high = speech_noise_thresh_high
-        self.output_frame_probs = output_frame_probs
-        self.frame_in_ms = frame_in_ms
-        self.frame_length_ms = frame_length_ms
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(
+		self,
+		sample_rate: int = 16000,
+		detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+		snr_mode: int = 0,
+		max_end_silence_time: int = 800,
+		max_start_silence_time: int = 3000,
+		do_start_point_detection: bool = True,
+		do_end_point_detection: bool = True,
+		window_size_ms: int = 200,
+		sil_to_speech_time_thres: int = 150,
+		speech_to_sil_time_thres: int = 150,
+		speech_2_noise_ratio: float = 1.0,
+		do_extend: int = 1,
+		lookback_time_start_point: int = 200,
+		lookahead_time_end_point: int = 100,
+		max_single_segment_time: int = 60000,
+		nn_eval_block_size: int = 8,
+		dcd_block_size: int = 4,
+		snr_thres: int = -100.0,
+		noise_frame_num_used_for_snr: int = 100,
+		decibel_thres: int = -100.0,
+		speech_noise_thres: float = 0.6,
+		fe_prior_thres: float = 1e-4,
+		silence_pdf_num: int = 1,
+		sil_pdf_ids: List[int] = [0],
+		speech_noise_thresh_low: float = -0.1,
+		speech_noise_thresh_high: float = 0.3,
+		output_frame_probs: bool = False,
+		frame_in_ms: int = 10,
+		frame_length_ms: int = 25,
+		**kwargs,
+	):
+		self.sample_rate = sample_rate
+		self.detect_mode = detect_mode
+		self.snr_mode = snr_mode
+		self.max_end_silence_time = max_end_silence_time
+		self.max_start_silence_time = max_start_silence_time
+		self.do_start_point_detection = do_start_point_detection
+		self.do_end_point_detection = do_end_point_detection
+		self.window_size_ms = window_size_ms
+		self.sil_to_speech_time_thres = sil_to_speech_time_thres
+		self.speech_to_sil_time_thres = speech_to_sil_time_thres
+		self.speech_2_noise_ratio = speech_2_noise_ratio
+		self.do_extend = do_extend
+		self.lookback_time_start_point = lookback_time_start_point
+		self.lookahead_time_end_point = lookahead_time_end_point
+		self.max_single_segment_time = max_single_segment_time
+		self.nn_eval_block_size = nn_eval_block_size
+		self.dcd_block_size = dcd_block_size
+		self.snr_thres = snr_thres
+		self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+		self.decibel_thres = decibel_thres
+		self.speech_noise_thres = speech_noise_thres
+		self.fe_prior_thres = fe_prior_thres
+		self.silence_pdf_num = silence_pdf_num
+		self.sil_pdf_ids = sil_pdf_ids
+		self.speech_noise_thresh_low = speech_noise_thresh_low
+		self.speech_noise_thresh_high = speech_noise_thresh_high
+		self.output_frame_probs = output_frame_probs
+		self.frame_in_ms = frame_in_ms
+		self.frame_length_ms = frame_length_ms
 
 
 class E2EVadSpeechBufWithDoa(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
-
-    def Reset(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self):
+		self.start_ms = 0
+		self.end_ms = 0
+		self.buffer = []
+		self.contain_seg_start_point = False
+		self.contain_seg_end_point = False
+		self.doa = 0
+	
+	def Reset(self):
+		self.start_ms = 0
+		self.end_ms = 0
+		self.buffer = []
+		self.contain_seg_start_point = False
+		self.contain_seg_end_point = False
+		self.doa = 0
 
 
 class E2EVadFrameProb(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.noise_prob = 0.0
-        self.speech_prob = 0.0
-        self.score = 0.0
-        self.frame_id = 0
-        self.frm_state = 0
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self):
+		self.noise_prob = 0.0
+		self.speech_prob = 0.0
+		self.score = 0.0
+		self.frame_id = 0
+		self.frm_state = 0
 
 
 class WindowDetector(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self, window_size_ms: int,
-                 sil_to_speech_time: int,
-                 speech_to_sil_time: int,
-                 frame_size_ms: int):
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time = sil_to_speech_time
-        self.speech_to_sil_time = speech_to_sil_time
-        self.frame_size_ms = frame_size_ms
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self, window_size_ms: int,
+	             sil_to_speech_time: int,
+	             speech_to_sil_time: int,
+	             frame_size_ms: int):
+		self.window_size_ms = window_size_ms
+		self.sil_to_speech_time = sil_to_speech_time
+		self.speech_to_sil_time = speech_to_sil_time
+		self.frame_size_ms = frame_size_ms
+		
+		self.win_size_frame = int(window_size_ms / frame_size_ms)
+		self.win_sum = 0
+		self.win_state = [0] * self.win_size_frame  # initialize the window
+		
+		self.cur_win_pos = 0
+		self.pre_frame_state = FrameState.kFrameStateSil
+		self.cur_frame_state = FrameState.kFrameStateSil
+		self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+		self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+		
+		self.voice_last_frame_count = 0
+		self.noise_last_frame_count = 0
+		self.hydre_frame_count = 0
+	
+	def Reset(self) -> None:
+		self.cur_win_pos = 0
+		self.win_sum = 0
+		self.win_state = [0] * self.win_size_frame
+		self.pre_frame_state = FrameState.kFrameStateSil
+		self.cur_frame_state = FrameState.kFrameStateSil
+		self.voice_last_frame_count = 0
+		self.noise_last_frame_count = 0
+		self.hydre_frame_count = 0
+	
+	def GetWinSize(self) -> int:
+		return int(self.win_size_frame)
+	
+	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
+		cur_frame_state = FrameState.kFrameStateSil
+		if frameState == FrameState.kFrameStateSpeech:
+			cur_frame_state = 1
+		elif frameState == FrameState.kFrameStateSil:
+			cur_frame_state = 0
+		else:
+			return AudioChangeState.kChangeStateInvalid
+		self.win_sum -= self.win_state[self.cur_win_pos]
+		self.win_sum += cur_frame_state
+		self.win_state[self.cur_win_pos] = cur_frame_state
+		self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
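+		# win_sum now counts the speech-labelled frames inside the sliding window; the
+		# thresholds below realise the sil->speech and speech->sil hangover decisions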
+		
+		if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+			self.pre_frame_state = FrameState.kFrameStateSpeech
+			return AudioChangeState.kChangeStateSil2Speech
+		
+		if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+			self.pre_frame_state = FrameState.kFrameStateSil
+			return AudioChangeState.kChangeStateSpeech2Sil
+		
+		if self.pre_frame_state == FrameState.kFrameStateSil:
+			return AudioChangeState.kChangeStateSil2Sil
+		if self.pre_frame_state == FrameState.kFrameStateSpeech:
+			return AudioChangeState.kChangeStateSpeech2Speech
+		return AudioChangeState.kChangeStateInvalid
+	
+	def FrameSizeMs(self) -> int:
+		return int(self.frame_size_ms)
 
-        self.win_size_frame = int(window_size_ms / frame_size_ms)
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame  # initialize the window
-
-        self.cur_win_pos = 0
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
-        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
-
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def Reset(self) -> None:
-        self.cur_win_pos = 0
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def GetWinSize(self) -> int:
-        return int(self.win_size_frame)
-
-    def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
-        cur_frame_state = FrameState.kFrameStateSil
-        if frameState == FrameState.kFrameStateSpeech:
-            cur_frame_state = 1
-        elif frameState == FrameState.kFrameStateSil:
-            cur_frame_state = 0
-        else:
-            return AudioChangeState.kChangeStateInvalid
-        self.win_sum -= self.win_state[self.cur_win_pos]
-        self.win_sum += cur_frame_state
-        self.win_state[self.cur_win_pos] = cur_frame_state
-        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
-
-        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSpeech
-            return AudioChangeState.kChangeStateSil2Speech
-
-        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSil
-            return AudioChangeState.kChangeStateSpeech2Sil
-
-        if self.pre_frame_state == FrameState.kFrameStateSil:
-            return AudioChangeState.kChangeStateSil2Sil
-        if self.pre_frame_state == FrameState.kFrameStateSpeech:
-            return AudioChangeState.kChangeStateSpeech2Speech
-        return AudioChangeState.kChangeStateInvalid
-
-    def FrameSizeMs(self) -> int:
-        return int(self.frame_size_ms)
+class Stats(object):
+	def __init__(self,
+	             sil_pdf_ids,
+	             max_end_sil_frame_cnt_thresh,
+	             speech_noise_thres,
+	             ):
+		
+		self.data_buf_start_frame = 0
+		self.frm_cnt = 0
+		self.latest_confirmed_speech_frame = 0
+		self.lastest_confirmed_silence_frame = -1
+		self.continous_silence_frame_count = 0
+		self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+		self.confirmed_start_frame = -1
+		self.confirmed_end_frame = -1
+		self.number_end_time_detected = 0
+		self.sil_frame = 0
+		self.sil_pdf_ids = sil_pdf_ids
+		self.noise_average_decibel = -100.0
+		self.pre_end_silence_detected = False
+		self.next_seg = True
+		
+		self.output_data_buf = []
+		self.output_data_buf_offset = 0
+		self.frame_probs = []
+		self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
+		self.speech_noise_thres = speech_noise_thres
+		self.scores = None
+		self.max_time_out = False
+		self.decibel = []
+		self.data_buf = None
+		self.data_buf_all = None
+		self.waveform = None
+		self.last_drop_frames = 0
 
 
-@dataclass
-class StatsItem:
-    
-    # init variables
-    data_buf_start_frame = 0
-    frm_cnt = 0
-    latest_confirmed_speech_frame = 0
-    lastest_confirmed_silence_frame = -1
-    continous_silence_frame_count = 0
-    vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
-    confirmed_start_frame = -1
-    confirmed_end_frame = -1
-    number_end_time_detected = 0
-    sil_frame = 0
-    sil_pdf_ids: list
-    noise_average_decibel = -100.0
-    pre_end_silence_detected = False
-    next_seg = True # unused
-    
-    output_data_buf = []
-    output_data_buf_offset = 0
-    frame_probs = [] # unused
-    max_end_sil_frame_cnt_thresh: int
-    speech_noise_thres: float
-    scores = None
-    max_time_out = False #unused
-    decibel = []
-    data_buf = None
-    data_buf_all = None
-    waveform = None
-    last_drop_frames = 0
-    
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self,
-                 encoder: str = None,
-                 encoder_conf: Optional[Dict] = None,
-                 vad_post_args: Dict[str, Any] = None,
-                 **kwargs,
-                 ):
-        super().__init__()
-        self.vad_opts = VADXOptions(**kwargs)
-
-        encoder_class = tables.encoder_classes.get(encoder)
-        encoder = encoder_class(**encoder_conf)
-        self.encoder = encoder
-
-
-    def ResetDetection(self, cache: dict = {}):
-        cache["stats"].continous_silence_frame_count = 0
-        cache["stats"].latest_confirmed_speech_frame = 0
-        cache["stats"].lastest_confirmed_silence_frame = -1
-        cache["stats"].confirmed_start_frame = -1
-        cache["stats"].confirmed_end_frame = -1
-        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
-        cache["windows_detector"].Reset()
-        cache["stats"].sil_frame = 0
-        cache["stats"].frame_probs = []
-
-        if cache["stats"].output_data_buf:
-            assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
-            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
-            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
-            cache["stats"].last_drop_frames = drop_frames
-            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
-            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
-
-    def ComputeDecibel(self, cache: dict = {}) -> None:
-        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
-        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        if cache["stats"].data_buf_all is None:
-            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
-            cache["stats"].data_buf = cache["stats"].data_buf_all
-        else:
-            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
-        for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
-            cache["stats"].decibel.append(
-                10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
-                                0.000001))
-
-    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
-        scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
-        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
-        self.vad_opts.nn_eval_block_size = scores.shape[1]
-        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
-        if cache["stats"].scores is None:
-            cache["stats"].scores = scores  # the first calculation
-        else:
-            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
-
-    def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
-        while cache["stats"].data_buf_start_frame < frame_idx:
-            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
-                cache["stats"].data_buf_start_frame += 1
-                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-
-    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
-        self.PopDataBufTillFrame(start_frm, cache=cache)
-        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
-        if last_frm_is_end_point:
-            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
-                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
-            expected_sample_number += int(extra_sample)
-        if end_point_is_sent_end:
-            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
-        if len(cache["stats"].data_buf) < expected_sample_number:
-            print('error in calling pop data_buf\n')
-
-        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
-            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
-            cache["stats"].output_data_buf[-1].Reset()
-            cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
-            cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
-            cache["stats"].output_data_buf[-1].doa = 0
-        cur_seg = cache["stats"].output_data_buf[-1]
-        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-            print('warning\n')
-        out_pos = len(cur_seg.buffer)  # cur_seg.buffer is left untouched here for now
-        data_to_pop = 0
-        if end_point_is_sent_end:
-            data_to_pop = expected_sample_number
-        else:
-            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        if data_to_pop > len(cache["stats"].data_buf):
-            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
-            data_to_pop = len(cache["stats"].data_buf)
-            expected_sample_number = len(cache["stats"].data_buf)
-
-        cur_seg.doa = 0
-        for sample_cpy_out in range(0, data_to_pop):
-            # cur_seg.buffer[out_pos ++] = data_buf_.back();
-            out_pos += 1
-        for sample_cpy_out in range(data_to_pop, expected_sample_number):
-            # cur_seg.buffer[out_pos++] = data_buf_.back()
-            out_pos += 1
-        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-            print('Something wrong with the VAD algorithm\n')
-        cache["stats"].data_buf_start_frame += frm_cnt
-        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
-        if first_frm_is_start_point:
-            cur_seg.contain_seg_start_point = True
-        if last_frm_is_end_point:
-            cur_seg.contain_seg_end_point = True
-
-    def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
-        cache["stats"].lastest_confirmed_silence_frame = valid_frame
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-            self.PopDataBufTillFrame(valid_frame, cache=cache)
-        # silence_detected_callback_
-        # pass
-
-    def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
-        cache["stats"].latest_confirmed_speech_frame = valid_frame
-        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
-
-    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
-        if self.vad_opts.do_start_point_detection:
-            pass
-        if cache["stats"].confirmed_start_frame != -1:
-            print('not reset vad properly\n')
-        else:
-            cache["stats"].confirmed_start_frame = start_frame
-
-        if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-            self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
-
-    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
-        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
-            self.OnVoiceDetected(t, cache=cache)
-        if self.vad_opts.do_end_point_detection:
-            pass
-        if cache["stats"].confirmed_end_frame != -1:
-            print('not reset vad properly\n')
-        else:
-            cache["stats"].confirmed_end_frame = end_frame
-        if not fake_result:
-            cache["stats"].sil_frame = 0
-            self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
-        cache["stats"].number_end_time_detected += 1
-
-    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
-        if is_final_frame:
-            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
-            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-
-    def GetLatency(self, cache: dict = {}) -> int:
-        return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
-
-    def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
-        vad_latency = cache["windows_detector"].GetWinSize()
-        if self.vad_opts.do_extend:
-            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
-        return vad_latency
-
-    def GetFrameState(self, t: int, cache: dict = {}):
-        frame_state = FrameState.kFrameStateInvalid
-        cur_decibel = cache["stats"].decibel[t]
-        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
-        # for each frame, calc log posterior probability of each state
-        if cur_decibel < self.vad_opts.decibel_thres:
-            frame_state = FrameState.kFrameStateSil
-            self.DetectOneFrame(frame_state, t, False, cache=cache)
-            return frame_state
-
-        sum_score = 0.0
-        noise_prob = 0.0
-        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
-        if len(cache["stats"].sil_pdf_ids) > 0:
-            assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
-            sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
-            sum_score = sum(sil_pdf_scores)
-            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
-            total_score = 1.0
-            sum_score = total_score - sum_score
-        speech_prob = math.log(sum_score)
-        if self.vad_opts.output_frame_probs:
-            frame_prob = E2EVadFrameProb()
-            frame_prob.noise_prob = noise_prob
-            frame_prob.speech_prob = speech_prob
-            frame_prob.score = sum_score
-            frame_prob.frame_id = t
-            cache["stats"].frame_probs.append(frame_prob)
-        if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
-            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
-                frame_state = FrameState.kFrameStateSpeech
-            else:
-                frame_state = FrameState.kFrameStateSil
-        else:
-            frame_state = FrameState.kFrameStateSil
-            if cache["stats"].noise_average_decibel < -99.9:
-                cache["stats"].noise_average_decibel = cur_decibel
-            else:
-                cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
-                        self.vad_opts.noise_frame_num_used_for_snr
-                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
-
-        return frame_state
-
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
-                is_final: bool = False
-                ):
-        # if len(cache) == 0:
-        #     self.AllResetDetection()
-        # self.waveform = waveform  # compute decibel for each frame
-        cache["stats"].waveform = waveform
-        self.ComputeDecibel(cache=cache)
-        self.ComputeScores(feats, cache=cache)
-        if not is_final:
-            self.DetectCommonFrames(cache=cache)
-        else:
-            self.DetectLastFrames(cache=cache)
-        segments = []
-        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
-            segment_batch = []
-            if len(cache["stats"].output_data_buf) > 0:
-                for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-                    if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
-                        i].contain_seg_end_point):
-                        continue
-                    segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
-                    segment_batch.append(segment)
-                    cache["stats"].output_data_buf_offset += 1  # need update this parameter
-            if segment_batch:
-                segments.append(segment_batch)
-        # if is_final:
-        #     # reset class variables and clear the dict for the next query
-        #     self.AllResetDetection()
-        return segments
-
-    def init_cache(self, cache: dict = {}, **kwargs):
-        cache["frontend"] = {}
-        cache["prev_samples"] = torch.empty(0)
-        cache["encoder"] = {}
-        windows_detector = WindowDetector(self.vad_opts.window_size_ms,
-                                          self.vad_opts.sil_to_speech_time_thres,
-                                          self.vad_opts.speech_to_sil_time_thres,
-                                          self.vad_opts.frame_in_ms)
-
-        stats = StatsItem(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
-                          max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
-                          speech_noise_thres=self.vad_opts.speech_noise_thres,
-                      )
-        cache["windows_detector"] = windows_detector
-        cache["stats"] = stats
-        return cache
-    
-    def inference(self,
-                 data_in,
-                 data_lengths=None,
-                 key: list = None,
-                 tokenizer=None,
-                 frontend=None,
-                 cache: dict = {},
-                 **kwargs,
-                 ):
-    
-        if len(cache) == 0:
-            self.init_cache(cache, **kwargs)
-
-        meta_data = {}
-        chunk_size = kwargs.get("chunk_size", 60000) # 50ms
-        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
-
-        time1 = time.perf_counter()
-        cfg = {"is_final": kwargs.get("is_final", False)}
-        audio_sample_list = load_audio_text_image_video(data_in,
-                                                        fs=frontend.fs,
-                                                        audio_fs=kwargs.get("fs", 16000),
-                                                        data_type=kwargs.get("data_type", "sound"),
-                                                        tokenizer=tokenizer,
-                                                        cache=cfg,
-                                                        )
-        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
-
-        time2 = time.perf_counter()
-        meta_data["load_data"] = f"{time2 - time1:0.3f}"
-        assert len(audio_sample_list) == 1, "batch_size must be set 1"
-
-        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
-
-        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
-        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
-        segments = []
-        for i in range(n):
-            kwargs["is_final"] = _is_final and i == n - 1
-            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
-    
-            # extract fbank feats
-            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend, cache=cache["frontend"],
-                                                   is_final=kwargs["is_final"])
-            time3 = time.perf_counter()
-            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-            speech = speech.to(device=kwargs["device"])
-            speech_lengths = speech_lengths.to(device=kwargs["device"])
-            
-            batch = {
-                "feats": speech,
-                "waveform": cache["frontend"]["waveforms"],
-                "is_final": kwargs["is_final"],
-                "cache": cache
-            }
-            segments_i = self.forward(**batch)
-            if len(segments_i) > 0:
-                segments.extend(*segments_i)
-
-
-        cache["prev_samples"] = audio_sample[:-m]
-        if _is_final:
-            self.init_cache(cache, **kwargs)
-
-        ibest_writer = None
-        if ibest_writer is None and kwargs.get("output_dir") is not None:
-            writer = DatadirWriter(kwargs.get("output_dir"))
-            ibest_writer = writer[f"{1}best_recog"]
-
-        results = []
-        result_i = {"key": key[0], "value": segments}
-        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-            result_i = json.dumps(result_i)
-
-        results.append(result_i)
-            
-        if ibest_writer is not None:
-            ibest_writer["text"][key[0]] = segments
-
- 
-        return results, meta_data
-
-
-    def DetectCommonFrames(self, cache: dict = {}) -> int:
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
-            return 0
-        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
-            frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
-            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
-
-        return 0
-
-    def DetectLastFrames(self, cache: dict = {}) -> int:
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
-            return 0
-        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
-            frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
-            if i != 0:
-                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
-            else:
-                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
-
-        return 0
-
-    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
-        tmp_cur_frm_state = FrameState.kFrameStateInvalid
-        if cur_frm_state == FrameState.kFrameStateSpeech:
-            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
-                tmp_cur_frm_state = FrameState.kFrameStateSpeech
-            else:
-                tmp_cur_frm_state = FrameState.kFrameStateSil
-        elif cur_frm_state == FrameState.kFrameStateSil:
-            tmp_cur_frm_state = FrameState.kFrameStateSil
-        state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
-        frm_shift_in_ms = self.vad_opts.frame_in_ms
-        if AudioChangeState.kChangeStateSil2Speech == state_change:
-            silence_frame_count = cache["stats"].continous_silence_frame_count
-            cache["stats"].continous_silence_frame_count = 0
-            cache["stats"].pre_end_silence_detected = False
-            start_frame = 0
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
-                self.OnVoiceStart(start_frame, cache=cache)
-                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
-                for t in range(start_frame + 1, cur_frm_idx + 1):
-                    self.OnVoiceDetected(t, cache=cache)
-            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
-                    self.OnVoiceDetected(t, cache=cache)
-                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
-            cache["stats"].continous_silence_frame_count = 0
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                pass
-            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
-            cache["stats"].continous_silence_frame_count = 0
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    cache["stats"].max_time_out = True
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-        elif AudioChangeState.kChangeStateSil2Sil == state_change:
-            cache["stats"].continous_silence_frame_count += 1
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                # silence timeout, return zero length decision
-                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-                        cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
-                        or (is_final_frame and cache["stats"].number_end_time_detected == 0):
-                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
-                        self.OnSilenceDetected(t, cache=cache)
-                    self.OnVoiceStart(0, True, cache=cache)
-                    self.OnVoiceEnd(0, True, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                else:
-                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
-                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
-            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
-                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
-                    if self.vad_opts.do_extend:
-                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
-                        lookback_frame -= 1
-                        lookback_frame = max(0, lookback_frame)
-                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif self.vad_opts.do_extend and not is_final_frame:
-                    if cache["stats"].continous_silence_frame_count <= int(
-                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
-                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
-                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
-            self.ResetDetection(cache=cache)
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self,
+	             encoder: str = None,
+	             encoder_conf: Optional[Dict] = None,
+	             vad_post_args: Dict[str, Any] = None,
+	             **kwargs,
+	             ):
+		super().__init__()
+		self.vad_opts = VADXOptions(**kwargs)
+		
+		encoder_class = tables.encoder_classes.get(encoder)
+		encoder = encoder_class(**encoder_conf)
+		self.encoder = encoder
+	
+	
+	def ResetDetection(self, cache: dict = {}):
+		cache["stats"].continous_silence_frame_count = 0
+		cache["stats"].latest_confirmed_speech_frame = 0
+		cache["stats"].lastest_confirmed_silence_frame = -1
+		cache["stats"].confirmed_start_frame = -1
+		cache["stats"].confirmed_end_frame = -1
+		cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+		cache["windows_detector"].Reset()
+		cache["stats"].sil_frame = 0
+		cache["stats"].frame_probs = []
+		
+		if cache["stats"].output_data_buf:
+			assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
+			drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+			real_drop_frames = drop_frames - cache["stats"].last_drop_frames
+			cache["stats"].last_drop_frames = drop_frames
+			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+			cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
+			cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
+	
+	def ComputeDecibel(self, cache: dict = {}) -> None:
+		frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+		frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+		if cache["stats"].data_buf_all is None:
+			cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf points to cache["stats"].waveform[0]
+			cache["stats"].data_buf = cache["stats"].data_buf_all
+		else:
+			cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
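+		# per-frame energy in dB: 10 * log10(sum of squared samples + 1e-6) over each frame_length_ms window, shifted by frame_in_ms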
+		for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+			cache["stats"].decibel.append(
+				10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
+				                0.000001))
+	
+	def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
+		scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
+		assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
+		self.vad_opts.nn_eval_block_size = scores.shape[1]
+		cache["stats"].frm_cnt += scores.shape[1]  # count total frames
+		if cache["stats"].scores is None:
+			cache["stats"].scores = scores  # the first calculation
+		else:
+			cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
+	
+	def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
+		while cache["stats"].data_buf_start_frame < frame_idx:
+			if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+				cache["stats"].data_buf_start_frame += 1
+				cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
+					self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+	
+	def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
+		self.PopDataBufTillFrame(start_frm, cache=cache)
+		expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+		if last_frm_is_end_point:
+			extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
+			                          self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+			expected_sample_number += int(extra_sample)
+		if end_point_is_sent_end:
+			expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
+		if len(cache["stats"].data_buf) < expected_sample_number:
+			print('error in calling pop data_buf\n')
+		
+		if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
+			cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
+			cache["stats"].output_data_buf[-1].Reset()
+			cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+			cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
+			cache["stats"].output_data_buf[-1].doa = 0
+		cur_seg = cache["stats"].output_data_buf[-1]
+		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+			print('warning\n')
+		out_pos = len(cur_seg.buffer)  # cur_seg.buffer is left untouched here for now
+		data_to_pop = 0
+		if end_point_is_sent_end:
+			data_to_pop = expected_sample_number
+		else:
+			data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+		if data_to_pop > len(cache["stats"].data_buf):
+			print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
+			data_to_pop = len(cache["stats"].data_buf)
+			expected_sample_number = len(cache["stats"].data_buf)
+		
+		cur_seg.doa = 0
+		for sample_cpy_out in range(0, data_to_pop):
+			# cur_seg.buffer[out_pos ++] = data_buf_.back();
+			out_pos += 1
+		for sample_cpy_out in range(data_to_pop, expected_sample_number):
+			# cur_seg.buffer[out_pos++] = data_buf_.back()
+			out_pos += 1
+		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+			print('Something wrong with the VAD algorithm\n')
+		cache["stats"].data_buf_start_frame += frm_cnt
+		cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+		if first_frm_is_start_point:
+			cur_seg.contain_seg_start_point = True
+		if last_frm_is_end_point:
+			cur_seg.contain_seg_end_point = True
+	
+	def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
+		cache["stats"].lastest_confirmed_silence_frame = valid_frame
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+			self.PopDataBufTillFrame(valid_frame, cache=cache)
+		# silence_detected_callback_
+		# pass
+	
+	def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
+		cache["stats"].latest_confirmed_speech_frame = valid_frame
+		self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
+	
+	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
+		if self.vad_opts.do_start_point_detection:
+			pass
+		if cache["stats"].confirmed_start_frame != -1:
+			print('not reset vad properly\n')
+		else:
+			cache["stats"].confirmed_start_frame = start_frame
+		
+		if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+			self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
+	
+	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
+		for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
+			self.OnVoiceDetected(t, cache=cache)
+		if self.vad_opts.do_end_point_detection:
+			pass
+		if cache["stats"].confirmed_end_frame != -1:
+			print('not reset vad properly\n')
+		else:
+			cache["stats"].confirmed_end_frame = end_frame
+		if not fake_result:
+			cache["stats"].sil_frame = 0
+			self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
+		cache["stats"].number_end_time_detected += 1
+	
+	def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
+		if is_final_frame:
+			self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
+			cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+	
+	def GetLatency(self, cache: dict = {}) -> int:
+		return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
+	
+	def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
+		vad_latency = cache["windows_detector"].GetWinSize()
+		if self.vad_opts.do_extend:
+			vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+		return vad_latency
+	
+	def GetFrameState(self, t: int, cache: dict = {}):
+		frame_state = FrameState.kFrameStateInvalid
+		cur_decibel = cache["stats"].decibel[t]
+		cur_snr = cur_decibel - cache["stats"].noise_average_decibel
+		# for each frame, calc log posterior probability of each state
+		if cur_decibel < self.vad_opts.decibel_thres:
+			frame_state = FrameState.kFrameStateSil
+			self.DetectOneFrame(frame_state, t, False, cache=cache)
+			return frame_state
+		
+		sum_score = 0.0
+		noise_prob = 0.0
+		assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
+		if len(cache["stats"].sil_pdf_ids) > 0:
+			assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
+			sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
+			sum_score = sum(sil_pdf_scores)
+			noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+			total_score = 1.0
+			sum_score = total_score - sum_score
+		speech_prob = math.log(sum_score)
+		if self.vad_opts.output_frame_probs:
+			frame_prob = E2EVadFrameProb()
+			frame_prob.noise_prob = noise_prob
+			frame_prob.speech_prob = speech_prob
+			frame_prob.score = sum_score
+			frame_prob.frame_id = t
+			cache["stats"].frame_probs.append(frame_prob)
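+		# compare speech vs. noise posteriors (in the probability domain) against speech_noise_thres, then gate the decision by SNR and absolute energy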
+		if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
+			if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+				frame_state = FrameState.kFrameStateSpeech
+			else:
+				frame_state = FrameState.kFrameStateSil
+		else:
+			frame_state = FrameState.kFrameStateSil
+			if cache["stats"].noise_average_decibel < -99.9:
+				cache["stats"].noise_average_decibel = cur_decibel
+			else:
+				cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
+					self.vad_opts.noise_frame_num_used_for_snr
+					- 1)) / self.vad_opts.noise_frame_num_used_for_snr
+		
+		return frame_state
+	
+	def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
+	            is_final: bool = False
+	            ):
+		# if len(cache) == 0:
+		#     self.AllResetDetection()
+		# self.waveform = waveform  # compute decibel for each frame
+		cache["stats"].waveform = waveform
+		self.ComputeDecibel(cache=cache)
+		self.ComputeScores(feats, cache=cache)
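+		# run frame-level detection: streaming path for intermediate chunks, flush path for the final chunk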
+		if not is_final:
+			self.DetectCommonFrames(cache=cache)
+		else:
+			self.DetectLastFrames(cache=cache)
+		segments = []
+		for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
+			segment_batch = []
+			if len(cache["stats"].output_data_buf) > 0:
+				for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
+					if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
+						i].contain_seg_end_point):
+						continue
+					segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
+					segment_batch.append(segment)
+					cache["stats"].output_data_buf_offset += 1  # need update this parameter
+			if segment_batch:
+				segments.append(segment_batch)
+		# if is_final:
+		#     # reset class variables and clear the dict for the next query
+		#     self.AllResetDetection()
+		return segments
+	
+	def init_cache(self, cache: dict = {}, **kwargs):
+		cache["frontend"] = {}
+		cache["prev_samples"] = torch.empty(0)
+		cache["encoder"] = {}
+		windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+		                                  self.vad_opts.sil_to_speech_time_thres,
+		                                  self.vad_opts.speech_to_sil_time_thres,
+		                                  self.vad_opts.frame_in_ms)
+		windows_detector.Reset()
+		
+		stats = Stats(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
+		              max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
+		              speech_noise_thres=self.vad_opts.speech_noise_thres
+		              )
+		cache["windows_detector"] = windows_detector
+		cache["stats"] = stats
+		return cache
+	
+	def inference(self,
+	              data_in,
+	              data_lengths=None,
+	              key: list = None,
+	              tokenizer=None,
+	              frontend=None,
+	              cache: dict = {},
+	              **kwargs,
+	              ):
+		
+		if len(cache) == 0:
+			self.init_cache(cache, **kwargs)
+		
+		meta_data = {}
+		chunk_size = kwargs.get("chunk_size", 60000)  # chunk size in ms (default 60000 ms)
+		chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
+		
+		time1 = time.perf_counter()
+		cfg = {"is_final": kwargs.get("is_final", False)}
+		audio_sample_list = load_audio_text_image_video(data_in,
+		                                                fs=frontend.fs,
+		                                                audio_fs=kwargs.get("fs", 16000),
+		                                                data_type=kwargs.get("data_type", "sound"),
+		                                                tokenizer=tokenizer,
+		                                                cache=cfg,
+		                                                )
+		_is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
+		
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		assert len(audio_sample_list) == 1, "batch_size must be set 1"
+		
+		audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+		
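+		# n: number of full chunks to process in this call (plus the tail chunk when is_final); m: leftover samples carried to the next call when not final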
+		n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+		m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
+		segments = []
+		for i in range(n):
+			kwargs["is_final"] = _is_final and i == n - 1
+			audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+			
+			# extract fbank feats
+			speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+			                                       frontend=frontend, cache=cache["frontend"],
+			                                       is_final=kwargs["is_final"])
+			time3 = time.perf_counter()
+			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+			speech = speech.to(device=kwargs["device"])
+			speech_lengths = speech_lengths.to(device=kwargs["device"])
+			
+			batch = {
+				"feats": speech,
+				"waveform": cache["frontend"]["waveforms"],
+				"is_final": kwargs["is_final"],
+				"cache": cache
+			}
+			segments_i = self.forward(**batch)
+			if len(segments_i) > 0:
+				segments.extend(*segments_i)
+		
+		
+		cache["prev_samples"] = audio_sample[:-m]
+		if _is_final:
+			# reset the shared cache dict in place so the next utterance starts clean
+			self.init_cache(cache, **kwargs)
+		
+		ibest_writer = None
+		if ibest_writer is None and kwargs.get("output_dir") is not None:
+			writer = DatadirWriter(kwargs.get("output_dir"))
+			ibest_writer = writer[f"{1}best_recog"]
+		
+		results = []
+		result_i = {"key": key[0], "value": segments}
+		if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+			result_i = json.dumps(result_i)
+		
+		results.append(result_i)
+		
+		if ibest_writer is not None:
+			ibest_writer["text"][key[0]] = segments
+		
+		
+		return results, meta_data
+	
+	
+	def DetectCommonFrames(self, cache: dict = {}) -> int:
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+			return 0
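+		# evaluate the newest nn_eval_block_size frames (oldest first) and update the state machine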
+		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+			frame_state = FrameState.kFrameStateInvalid
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
+		
+		return 0
+	
+	def DetectLastFrames(self, cache: dict = {}) -> int:
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+			return 0
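+		# same as DetectCommonFrames, except the very last frame is flagged as final so any open segment is closed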
+		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+			frame_state = FrameState.kFrameStateInvalid
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			if i != 0:
+				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
+			else:
+				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
+		
+		return 0
+	
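+	# per-frame state-machine update: maps window-detector state changes (sil<->speech) to voice start/end events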
+	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
+		tmp_cur_frm_state = FrameState.kFrameStateInvalid
+		if cur_frm_state == FrameState.kFrameStateSpeech:
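+			# note: the operand is a constant 1.0, so this check passes whenever fe_prior_thres < 1.0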
+			if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+				tmp_cur_frm_state = FrameState.kFrameStateSpeech
+			else:
+				tmp_cur_frm_state = FrameState.kFrameStateSil
+		elif cur_frm_state == FrameState.kFrameStateSil:
+			tmp_cur_frm_state = FrameState.kFrameStateSil
+		state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
+		frm_shift_in_ms = self.vad_opts.frame_in_ms
+		if AudioChangeState.kChangeStateSil2Speech == state_change:
+			silence_frame_count = cache["stats"].continous_silence_frame_count
+			cache["stats"].continous_silence_frame_count = 0
+			cache["stats"].pre_end_silence_detected = False
+			start_frame = 0
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+				start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+				self.OnVoiceStart(start_frame, cache=cache)
+				cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+				for t in range(start_frame + 1, cur_frm_idx + 1):
+					self.OnVoiceDetected(t, cache=cache)
+			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
+					self.OnVoiceDetected(t, cache=cache)
+				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif not is_final_frame:
+					self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+			cache["stats"].continous_silence_frame_count = 0
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+				pass
+			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif not is_final_frame:
+					self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+			cache["stats"].continous_silence_frame_count = 0
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					cache["stats"].max_time_out = True
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif not is_final_frame:
+					self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		elif AudioChangeState.kChangeStateSil2Sil == state_change:
+			cache["stats"].continous_silence_frame_count += 1
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+				# silence timeout, return zero length decision
+				if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
+					cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+					or (is_final_frame and cache["stats"].number_end_time_detected == 0):
+					for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
+						self.OnSilenceDetected(t, cache=cache)
+					self.OnVoiceStart(0, True, cache=cache)
+					self.OnVoiceEnd(0, True, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				else:
+					if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
+						self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
+			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+					lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+					if self.vad_opts.do_extend:
+						lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
+						lookback_frame -= 1
+						lookback_frame = max(0, lookback_frame)
+					self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif self.vad_opts.do_extend and not is_final_frame:
+					if cache["stats"].continous_silence_frame_count <= int(
+						self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
+						self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
+			self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
+			self.ResetDetection(cache=cache)
 
 
 
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 9f3c3f3..468d23f 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -33,7 +33,6 @@
     
     def __init__(
         self,
-        # token_list: Union[Tuple[str, ...], List[str]],
         specaug: Optional[str] = None,
         specaug_conf: Optional[Dict] = None,
         normalize: str = None,
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
index 3972caa..bccf638 100644
--- a/funasr/models/paraformer/template.yaml
+++ b/funasr/models/paraformer/template.yaml
@@ -6,7 +6,6 @@
 # tables.print()
 
 # network architecture
-#model: funasr.models.paraformer.model:Paraformer
 model: Paraformer
 model_conf:
     ctc_weight: 0.0
@@ -87,13 +86,6 @@
   accum_grad: 1
   grad_clip: 5
   max_epoch: 150
-  val_scheduler_criterion:
-      - valid
-      - acc
-  best_model_criterion:
-  -   - valid
-      - acc
-      - max
   keep_nbest_models: 10
   avg_nbest_model: 5
   log_interval: 50
diff --git a/funasr/models/sanm/decoder.py b/funasr/models/sanm/decoder.py
index 190ada0..3575282 100644
--- a/funasr/models/sanm/decoder.py
+++ b/funasr/models/sanm/decoder.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 from typing import List
 from typing import Tuple
 import logging
@@ -193,10 +198,9 @@
 @tables.register("decoder_classes", "FsmnDecoder")
 class FsmnDecoder(BaseTransformerDecoder):
     """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
+    San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
-
     """
     
     def __init__(
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index cb4e21a..069c527 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -156,10 +161,9 @@
 @tables.register("encoder_classes", "SANMEncoder")
 class SANMEncoder(nn.Module):
     """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
-
     """
 
     def __init__(
diff --git a/funasr/models/sanm/model.py b/funasr/models/sanm/model.py
index 4dc8825..0cef540 100644
--- a/funasr/models/sanm/model.py
+++ b/funasr/models/sanm/model.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import logging
 
 import torch
@@ -7,7 +12,11 @@
 
 @tables.register("model_classes", "SANM")
 class SANM(Transformer):
-    """CTC-attention hybrid Encoder-Decoder model"""
+    """
+    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
+    San-m: Memory equipped self-attention for end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+    """
 
     def __init__(
         self,
diff --git a/funasr/models/sanm/template.yaml b/funasr/models/sanm/template.yaml
new file mode 100644
index 0000000..156926f
--- /dev/null
+++ b/funasr/models/sanm/template.yaml
@@ -0,0 +1,121 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: SANM
+model_conf:
+    ctc_weight: 0.0
+    lsm_weight: 0.1
+    length_normalized_loss: true
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+    output_size: 512
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 50
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+
+# decoder
+decoder: FsmnDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 16
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    att_layer_num: 16
+    kernel_size: 11
+    sanm_shfit: 0
+
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
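+    # low frame rate (LFR): lfr_m consecutive frames are stacked and the sequence is subsampled by lfr_n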
+    lfr_m: 7
+    lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+      - valid
+      - acc
+  best_model_criterion:
+  -   - valid
+      - acc
+      - max
+  keep_nbest_models: 10
+  avg_nbest_model: 5
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 2048 # filter out samples whose source_token_len+target_token_len > max_token_length
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+
+normalize: null
diff --git a/funasr/models/scama/sanm_decoder.py b/funasr/models/scama/decoder.py
similarity index 99%
rename from funasr/models/scama/sanm_decoder.py
rename to funasr/models/scama/decoder.py
index 4222e5f..9dcb9da 100644
--- a/funasr/models/scama/sanm_decoder.py
+++ b/funasr/models/scama/decoder.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 from typing import List
 from typing import Tuple
 import logging
@@ -192,11 +197,11 @@
 @tables.register("decoder_classes", "FsmnDecoderSCAMAOpt")
 class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
     """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-
+    https://arxiv.org/abs/2006.01712
     """
+    
     def __init__(
             self,
             vocab_size: int,
diff --git a/funasr/models/scama/sanm_encoder.py b/funasr/models/scama/encoder.py
similarity index 98%
rename from funasr/models/scama/sanm_encoder.py
rename to funasr/models/scama/encoder.py
index 5e28db7..3651e61 100644
--- a/funasr/models/scama/sanm_encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -157,10 +162,9 @@
 @tables.register("encoder_classes", "SANMEncoderChunkOpt")
 class SANMEncoderChunkOpt(nn.Module):
     """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-
+    https://arxiv.org/abs/2006.01712
     """
 
     def __init__(
diff --git a/funasr/models/scama/model.py b/funasr/models/scama/model.py
new file mode 100644
index 0000000..aec6fe3
--- /dev/null
+++ b/funasr/models/scama/model.py
@@ -0,0 +1,669 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import torch.nn as nn
+import torch.functional as F
+import logging
+from typing import Dict, Tuple
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.model import Paraformer
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.models.scama.utils import sequence_mask
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+@tables.register("model_classes", "SCAMA")
+class SCAMA(nn.Module):
+    """
+    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01712
+    """
+
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        ctc: str = None,
+        ctc_conf: dict = None,
+        ctc_weight: float = 0.5,
+        predictor: str = None,
+        predictor_conf: dict = None,
+        predictor_bias: int = 0,
+        predictor_weight: float = 0.0,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        share_embedding: bool = False,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+            
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+            
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        decoder_class = tables.decoder_classes.get(decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            **decoder_conf,
+        )
+        if ctc_weight > 0.0:
+    
+            if ctc_conf is None:
+                ctc_conf = {}
+    
+            ctc = CTC(
+                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+            )
+
+        predictor_class = tables.predictor_classes.get(predictor)
+        predictor = predictor_class(**predictor_conf)
+
+        # note that eos is the same as sos (equivalent ID)
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        
+        self.specaug = specaug
+        self.normalize = normalize
+        
+        self.encoder = encoder
+
+
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+            
+        self.predictor = predictor
+        self.predictor_weight = predictor_weight
+        self.predictor_bias = predictor_bias
+
+        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+
+        self.share_embedding = share_embedding
+        if self.share_embedding:
+            self.decoder.embed = None
+
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+        self.error_calculator = None
+        
+        if self.encoder.overlap_chunk_cls is not None:
+            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+            self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+
+        decoding_ind = kwargs.get("decoding_ind")
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+    
+        batch_size = speech.shape[0]
+    
+        # Encoder
+        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+
+    
+        loss_ctc, cer_ctc = None, None
+        loss_pre = None
+        stats = dict()
+    
+        # decoder: CTC branch
+    
+        if self.ctc_weight > 0.0:
+
+            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                                encoder_out_lens,
+                                                                                                chunk_outs=None)
+
+        
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+            )
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+    
+        # decoder: Attention decoder branch
+        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+    
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight
+        else:
+            loss = self.ctc_weight * loss_ctc + (
+                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+    
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+    
+        stats["loss"] = torch.clone(loss.detach())
+    
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum()
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+        
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+        
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+    
+        # Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+    
+        return encoder_out, encoder_out_lens
+
+    def encode_chunk(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                cache: dict
+        """
+        with autocast(False):
+        
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+        
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+    
+        # Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+    
+        return encoder_out, torch.tensor([encoder_out.size(1)])
+
+    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
+
+        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
+
+    def _calc_att_predictor_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
+                                         device=encoder_out.device)[:, None, :]
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                           device=encoder_out.device,
+                                                                                           batch_size=encoder_out.size(
+                                                                                               0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                                   batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+                                                                              ys_out_pad,
+                                                                              encoder_out_mask,
+                                                                              ignore_id=self.ignore_id,
+                                                                              mask_chunk_predictor=mask_chunk_predictor,
+                                                                              target_label_length=ys_in_lens,
+                                                                              )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens)
+
+
+        encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+        attention_chunk_center_bias = 0
+        attention_chunk_size = encoder_chunk_size
+        decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+        mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
+                                                                                                       device=encoder_out.device,
+                                                                                                       batch_size=encoder_out.size(
+                                                                                                           0))
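+        # build the chunk-wise cross-attention mask for the SCAMA decoder from the predictor's frame alignments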
+        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+            predictor_alignments=predictor_alignments,
+            encoder_sequence_length=encoder_out_lens,
+            chunk_size=1,
+            encoder_chunk_size=encoder_chunk_size,
+            attention_chunk_center_bias=attention_chunk_center_bias,
+            attention_chunk_size=attention_chunk_size,
+            attention_chunk_type=self.decoder_attention_chunk_type,
+            step=None,
+            predictor_mask_chunk_hopping=mask_chunk_predictor,
+            decoder_att_look_back_factor=decoder_att_look_back_factor,
+            mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+            target_length=ys_in_lens,
+            is_training=self.training,
+        )
+
+
+        # try:
+        # 1. Forward decoder
+        decoder_out, _ = self.decoder(
+            encoder_out,
+            encoder_out_lens,
+            ys_in_pad,
+            ys_in_lens,
+            chunk_mask=scama_mask,
+            pre_acoustic_embeds=pre_acoustic_embeds,
+
+        )
+
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+        # predictor loss
+        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+    def calc_predictor_mask(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor = None,
+        ys_pad_lens: torch.Tensor = None,
+    ):
+        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        # ys_in_lens = ys_pad_lens + 1
+        ys_out_pad, ys_in_lens = None, None
+
+        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
+                                         device=encoder_out.device)[:, None, :]
+        mask_chunk_predictor = None
+
+        mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                       device=encoder_out.device,
+                                                                                       batch_size=encoder_out.size(
+                                                                                           0))
+        mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                               batch_size=encoder_out.size(0))
+        encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+                                                                              ys_out_pad,
+                                                                              encoder_out_mask,
+                                                                              ignore_id=self.ignore_id,
+                                                                              mask_chunk_predictor=mask_chunk_predictor,
+                                                                              target_label_length=ys_in_lens,
+                                                                              )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens)
+    
+
+        encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+        attention_chunk_center_bias = 0
+        attention_chunk_size = encoder_chunk_size
+        decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+        mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
+                                                                                                       device=encoder_out.device,
+                                                                                                       batch_size=encoder_out.size(
+                                                                                                           0))
+        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+            predictor_alignments=predictor_alignments,
+            encoder_sequence_length=encoder_out_lens,
+            chunk_size=1,
+            encoder_chunk_size=encoder_chunk_size,
+            attention_chunk_center_bias=attention_chunk_center_bias,
+            attention_chunk_size=attention_chunk_size,
+            attention_chunk_type=self.decoder_attention_chunk_type,
+            step=None,
+            predictor_mask_chunk_hopping=mask_chunk_predictor,
+            decoder_att_look_back_factor=decoder_att_look_back_factor,
+            mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+            target_length=ys_in_lens,
+            is_training=self.training,
+        )
+
+        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
+
+    def init_beam_search(self,
+                         **kwargs,
+                         ):
+        from funasr.models.scama.beam_search import BeamSearchScama
+        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+    
+        # 1. Build ASR model
+        scorers = {}
+    
+        if self.ctc != None:
+            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+            scorers.update(
+                ctc=ctc
+            )
+        token_list = kwargs.get("token_list")
+        scorers.update(
+            decoder=self.decoder,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+    
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+    
+        weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            ctc=kwargs.get("decoding_ctc_weight", 0.0),
+            lm=kwargs.get("lm_weight", 0.0),
+            ngram=kwargs.get("ngram_weight", 0.0),
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearchScama(
+            beam_size=kwargs.get("beam_size", 2),
+            weights=weights,
+            scorers=scorers,
+            sos=self.sos,
+            eos=self.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        )
+        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+        # for scorer in scorers.values():
+        #     if isinstance(scorer, torch.nn.Module):
+        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+        self.beam_search = beam_search
+
+    def generate_chunk(self,
+                       speech,
+                       speech_lengths=None,
+                       key: list = None,
+                       tokenizer=None,
+                       frontend=None,
+                       **kwargs,
+                       ):
+        cache = kwargs.get("cache", {})
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+    
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache,
+                                                          is_final=kwargs.get("is_final", False))
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        # predictor
+        predictor_outs = self.calc_predictor_chunk(encoder_out,
+                                                   encoder_out_lens,
+                                                   cache=cache,
+                                                   is_final=kwargs.get("is_final", False),
+                                                   )
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+                                                                        predictor_outs[2], predictor_outs[3]
+        pre_token_length = pre_token_length.round().long()
+
+
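+        # nothing to decode in this chunk if the predictor emits no tokens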
+        if torch.max(pre_token_length) < 1:
+            return []
+        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
+                                                             encoder_out_lens,
+                                                             pre_acoustic_embeds,
+                                                             pre_token_length,
+                                                             cache=cache
+                                                             )
+        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+    
+        results = []
+        b, n, d = decoder_out.size()
+        if isinstance(key[0], (list, tuple)):
+            key = key[0]
+        for i in range(b):
+            x = encoder_out[i, :encoder_out_lens[i], :]
+            am_scores = decoder_out[i, :pre_token_length[i], :]
+            if self.beam_search is not None:
+                nbest_hyps = self.beam_search(
+                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+                    minlenratio=kwargs.get("minlenratio", 0.0)
+                )
+            
+                nbest_hyps = nbest_hyps[: self.nbest]
+            else:
+            
+                yseq = am_scores.argmax(dim=-1)
+                score = am_scores.max(dim=-1)[0]
+                score = torch.sum(score, dim=-1)
+                # wrap the greedy 1-best output with sos/eos ids so it matches the beam-search hypothesis format
+                yseq = torch.tensor(
+                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+                )
+                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+            
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+            
+                # remove blank symbol id, which is assumed to be 0
+                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+            
+                # Change integer-ids to tokens
+                token = tokenizer.ids2tokens(token_int)
+                # text = tokenizer.tokens2text(token)
+            
+                result_i = token
+            
+                results.extend(result_i)
+    
+        return results
+
+    def init_cache(self, cache: dict = {}, **kwargs):
+        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+        batch_size = 1
+    
+        enc_output_size = kwargs["encoder_conf"]["output_size"]
+        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+                         "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
+                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+                         "tail_chunk": False}
+        cache["encoder"] = cache_encoder
+    
+        cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
+                         "chunk_size": chunk_size}
+        cache["decoder"] = cache_decoder
+        cache["frontend"] = {}
+        cache["prev_samples"] = torch.empty(0)
+    
+        return cache
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  cache: dict = {},
+                  **kwargs,
+                  ):
+    
+        # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+        if self.beam_search is None and (is_use_lm or is_use_ctc):
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+    
+        if len(cache) == 0:
+            self.init_cache(cache, **kwargs)
+    
+        meta_data = {}
+        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+        chunk_stride_samples = int(chunk_size[1] * 960)  # 600ms
+    
+        time1 = time.perf_counter()
+        cfg = {"is_final": kwargs.get("is_final", False)}
+        audio_sample_list = load_audio_text_image_video(data_in,
+                                                        fs=frontend.fs,
+                                                        audio_fs=kwargs.get("fs", 16000),
+                                                        data_type=kwargs.get("data_type", "sound"),
+                                                        tokenizer=tokenizer,
+                                                        cache=cfg,
+                                                        )
+        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
+    
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        assert len(audio_sample_list) == 1, "batch_size must be set 1"
+    
+        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+    
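+        # n: number of full chunks to decode now (plus the tail chunk when is_final); m: leftover samples carried to the next call when not final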
+        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
+        tokens = []
+        for i in range(n):
+            kwargs["is_final"] = _is_final and i == n - 1
+            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+        
+            # extract fbank feats
+            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend, cache=cache["frontend"],
+                                                   is_final=kwargs["is_final"])
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+        
+            tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache,
+                                           frontend=frontend, **kwargs)
+            tokens.extend(tokens_i)
+    
+        text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
+    
+        result_i = {"key": key[0], "text": text_postprocessed}
+        result = [result_i]
+    
+        cache["prev_samples"] = audio_sample[:-m]
+        if _is_final:
+            self.init_cache(cache, **kwargs)
+    
+        if kwargs.get("output_dir"):
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+            ibest_writer["token"][key[0]] = " ".join(tokens)
+            ibest_writer["text"][key[0]] = text_postprocessed
+    
+        return result, meta_data
diff --git a/funasr/models/scama/template.yaml b/funasr/models/scama/template.yaml
new file mode 100644
index 0000000..f647a92
--- /dev/null
+++ b/funasr/models/scama/template.yaml
@@ -0,0 +1,127 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: SCAMA
+model_conf:
+    ctc_weight: 0.0
+    lsm_weight: 0.1
+    length_normalized_loss: true
+
+# encoder
+encoder: SANMEncoderChunkOpt
+encoder_conf:
+    output_size: 512
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 50
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+
+# decoder
+decoder: FsmnDecoderSCAMAOpt
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 16
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    att_layer_num: 16
+    kernel_size: 11
+    sanm_shfit: 0
+
+predictor: CifPredictorV2
+predictor_conf:
+    idim: 512
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+    tail_threshold: 0.45
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+      - valid
+      - acc
+  best_model_criterion:
+  -   - valid
+      - acc
+      - max
+  keep_nbest_models: 10
+  avg_nbest_model: 5
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, batch_size is the summed source_token_len + target_token_len
+    max_token_length: 2048 # filter out samples whose source_token_len + target_token_len exceeds max_token_length
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+
+normalize: null
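
A quick sanity check on the `frontend_conf` above: with a 25 ms window, 10 ms shift, and LFR stacking `lfr_m: 7` / `lfr_n: 6`, each low-frame-rate feature frame spans 85 ms and hops by 60 ms, which is the same `frame_shift * lfr_n` factor the inference code uses when reporting `batch_data_time`. A small illustrative computation (values copied from this template):

```python
# Illustrative only: LFR frame timing derived from frontend_conf.
frame_length_ms = 25      # frame_length
frame_shift_ms = 10       # frame_shift
lfr_m, lfr_n = 7, 6       # stack 7 frames, hop by 6

lfr_window_ms = frame_length_ms + (lfr_m - 1) * frame_shift_ms   # 85 ms covered per LFR frame
lfr_hop_ms = lfr_n * frame_shift_ms                               # 60 ms between LFR frames

seconds_of_audio = 1.0
approx_lfr_frames = int(seconds_of_audio * 1000 / lfr_hop_ms)     # roughly 16 frames per second
print(lfr_window_ms, lfr_hop_ms, approx_lfr_frames)
```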
diff --git a/funasr/models/uniasr/e2e_uni_asr.py b/funasr/models/uniasr/model.py
similarity index 95%
rename from funasr/models/uniasr/e2e_uni_asr.py
rename to funasr/models/uniasr/model.py
index 390d274..de80d4a 100644
--- a/funasr/models/uniasr/e2e_uni_asr.py
+++ b/funasr/models/uniasr/model.py
@@ -1,85 +1,73 @@
-import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
+import time
 import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
 
-from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
 from funasr.metrics.compute_acc import th_accuracy
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
-)
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.frontends.abs_frontend import AbsFrontend
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
-from funasr.models.scama.chunk_utilis import sequence_mask
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
 from funasr.models.paraformer.cif_predictor import mae_loss
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-    from torch.cuda.amp import autocast
-else:
-    # Nothing to do if torch<1.6.0
-    @contextmanager
-    def autocast(enabled=True):
-        yield
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
-class UniASR(FunASRModel):
+@tables.register("model_classes", "UniASR")
+class UniASR(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     """
 
     def __init__(
         self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
-        encoder: AbsEncoder,
-        decoder: AbsDecoder,
-        ctc: CTC,
+        specaug: Optional[str] = None,
+        specaug_conf: Optional[Dict] = None,
+        normalize: str = None,
+        normalize_conf: Optional[Dict] = None,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        decoder: str = None,
+        decoder_conf: Optional[Dict] = None,
+        ctc: str = None,
+        ctc_conf: Optional[Dict] = None,
+        predictor: str = None,
+        predictor_conf: Optional[Dict] = None,
         ctc_weight: float = 0.5,
-        interctc_weight: float = 0.0,
+        input_size: int = 80,
+        vocab_size: int = -1,
         ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
-        extract_feats_in_collect_stats: bool = True,
-        predictor=None,
+        # report_cer: bool = True,
+        # report_wer: bool = True,
+        # sym_space: str = "<space>",
+        # sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        # predictor=None,
         predictor_weight: float = 0.0,
-        decoder_attention_chunk_type: str = 'chunk',
-        encoder2: AbsEncoder = None,
-        decoder2: AbsDecoder = None,
-        ctc2: CTC = None,
-        ctc_weight2: float = 0.5,
-        interctc_weight2: float = 0.0,
-        predictor2=None,
-        predictor_weight2: float = 0.0,
-        decoder_attention_chunk_type2: str = 'chunk',
-        stride_conv=None,
-        loss_weight_model1: float = 0.5,
-        enable_maas_finetune: bool = False,
-        freeze_encoder2: bool = False,
-        preencoder: Optional[AbsPreEncoder] = None,
-        postencoder: Optional[AbsPostEncoder] = None,
+        predictor_bias: int = 0,
+        sampling_ratio: float = 0.2,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        use_1st_decoder_loss: bool = False,
         encoder1_encoder2_joint_training: bool = True,
+        **kwargs,
+        
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -443,10 +431,8 @@
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
             batch_size = int((text_lengths + 1).sum())
-<<<<<<< HEAD:funasr/models/uniasr/e2e_uni_asr.py
 
-=======
->>>>>>> main:funasr/models/e2e_uni_asr.py
+
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
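
The signature change above is the main point of this rename: instead of receiving pre-built encoder/decoder objects, the refactored `UniASR` takes `(name, conf)` pairs and resolves them through the register table, following the same pattern as the other refactored models in this patch. A minimal sketch of that lookup, assuming the `tables.encoder_classes.get()` accessor matches the one used elsewhere in FunASR (treat the accessor name and the conf values as assumptions, not an excerpt of model.py):

```python
import torch
from funasr.register import tables

# Hypothetical values mirroring uniasr/template.yaml.
encoder_name = "SANMEncoderChunkOpt"
encoder_conf = {"output_size": 320, "attention_heads": 4,
                "linear_units": 1280, "num_blocks": 35}

encoder_class = tables.encoder_classes.get(encoder_name)   # assumed accessor
encoder: torch.nn.Module = encoder_class(input_size=80, **encoder_conf)
print(encoder.output_size())
```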
 
diff --git a/funasr/models/uniasr/template.yaml b/funasr/models/uniasr/template.yaml
new file mode 100644
index 0000000..f4815c1
--- /dev/null
+++ b/funasr/models/uniasr/template.yaml
@@ -0,0 +1,178 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: UniASR
+model_conf:
+    ctc_weight: 0.0
+    lsm_weight: 0.1
+    length_normalized_loss: true
+    predictor_weight: 1.0
+    decoder_attention_chunk_type: chunk
+    ctc_weight2: 0.0
+    predictor_weight2: 1.0
+    decoder_attention_chunk_type2: chunk
+    loss_weight_model1: 0.5
+
+# encoder
+encoder: SANMEncoderChunkOpt
+encoder_conf:
+    output_size: 320
+    attention_heads: 4
+    linear_units: 1280
+    num_blocks: 35
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+    chunk_size: [20, 60]
+    stride: [10, 40]
+    pad_left: [5, 10]
+    encoder_att_look_back_factor: [0, 0]
+    decoder_att_look_back_factor: [0, 0]
+
+# decoder
+decoder: FsmnDecoderSCAMAOpt
+decoder_conf:
+    attention_dim: 256
+    attention_heads: 4
+    linear_units: 1024
+    num_blocks: 12
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    att_layer_num: 6
+    kernel_size: 11
+    concat_embeds: true
+
+predictor: CifPredictorV2
+predictor_conf:
+    idim: 320
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+
+encoder2: SANMEncoderChunkOpt
+encoder2_conf:
+    output_size: 320
+    attention_heads: 4
+    linear_units: 1280
+    num_blocks: 20
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 21
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+    chunk_size: [45, 70]
+    stride: [35, 50]
+    pad_left: [5, 10]
+    encoder_att_look_back_factor: [0, 0]
+    decoder_att_look_back_factor: [0, 0]
+
+decoder2: FsmnDecoderSCAMAOpt
+decoder2_conf:
+    attention_dim: 320
+    attention_heads: 4
+    linear_units: 1280
+    num_blocks: 12
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    att_layer_num: 6
+    kernel_size: 11
+    concat_embeds: true
+
+predictor2: CifPredictorV2
+predictor2_conf:
+    idim: 320
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+
+stride_conv: stride_conv1d
+stride_conv_conf:
+    kernel_size: 2
+    stride: 2
+    pad: [0, 1]
+
+# frontend related
+frontend: WavFrontendOnline
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  keep_nbest_models: 10
+  avg_nbest_model: 5
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, batch_size is the summed source_token_len + target_token_len
+    max_token_length: 2048 # filter out samples whose source_token_len + target_token_len exceeds max_token_length
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
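
The `dataset_conf` comments in both templates describe two batching modes; the sketch below is a hypothetical illustration of that description (the sample records and packing loop are not FunASR's internal types or sampler logic):

```python
# Hypothetical illustration of max_token_length filtering and batch_type: "length".
samples = [
    {"source_token_len": 520, "target_token_len": 18},
    {"source_token_len": 610, "target_token_len": 22},
    {"source_token_len": 2300, "target_token_len": 40},   # dropped: exceeds 2048
]

max_token_length = 2048
kept = [s for s in samples
        if s["source_token_len"] + s["target_token_len"] <= max_token_length]

# With batch_type: "example", batch_size simply counts samples; with "length",
# batch_size is measured in accumulated source_token_len + target_token_len.
batch_cost = sum(s["source_token_len"] + s["target_token_len"] for s in kept)
print(len(kept), batch_cost)
```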

--
Gitblit v1.9.1