From cf2f14345aa2c4f168ee51c200b8081c748980b8 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 12 一月 2024 00:01:25 +0800
Subject: [PATCH] funasr1.0 fsmn-vad streaming

---
 funasr/models/fsmn_vad/encoder.py                                 |   18 
 examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh  |   11 
 funasr/models/fsmn_vad_streaming/model.py                         |  781 ++++++++++++++++++++++++++++++++
 funasr/models/fsmn_vad/model.py                                   |   58 --
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py   |   11 
 funasr/models/paraformer_streaming/model.py                       |   15 
 funasr/models/fsmn_vad_streaming/__init__.py                      |    0 
 funasr/utils/load_utils.py                                        |   23 
 funasr/models/paraformer_streaming/template.yaml                  |  143 +++++
 funasr/models/fsmn_vad_streaming/encoder.py                       |  303 ++++++++++++
 examples/industrial_data_pretraining/paraformer_streaming/demo.py |    2 
 funasr/models/fsmn_vad_streaming/template.yaml                    |   62 ++
 12 files changed, 1,356 insertions(+), 71 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
new file mode 100644
index 0000000..2a157ee
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")
+
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh
new file mode 100644
index 0000000..dedd14a
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh
@@ -0,0 +1,11 @@
+
+
+model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model_revision="v2.0.0"
+
+python funasr/bin/inference.py \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++output_dir="./outputs/debug" \
++device="cpu" \
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 9923a04..6d464f2 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -12,8 +12,6 @@
 model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
 cache = {}
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
-            cache=cache,
-            is_final=True,
             chunk_size=chunk_size,
             encoder_chunk_look_back=encoder_chunk_look_back,
             decoder_chunk_look_back=decoder_chunk_look_back,
diff --git a/funasr/models/fsmn_vad/encoder.py b/funasr/models/fsmn_vad/encoder.py
index 54410ac..a0a379d 100755
--- a/funasr/models/fsmn_vad/encoder.py
+++ b/funasr/models/fsmn_vad/encoder.py
@@ -125,12 +125,12 @@
         self.affine = AffineTransform(proj_dim, linear_dim)
         self.relu = RectifiedLinear(linear_dim, linear_dim)
 
-    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
         x1 = self.linear(input)  # B T D
         cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
-        if cache_layer_name not in in_cache:
-            in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
-        x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
+        if cache_layer_name not in cache:
+            cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
         x3 = self.affine(x2)
         x4 = self.relu(x3)
         return x4
@@ -140,10 +140,10 @@
     def __init__(self, *args):
         super(FsmnStack, self).__init__(*args)
 
-    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
         x = input
         for module in self._modules.values():
-            x = module(x, in_cache)
+            x = module(x, cache)
         return x
 
 
@@ -199,19 +199,19 @@
     def forward(
             self,
             input: torch.Tensor,
-            in_cache: Dict[str, torch.Tensor]
+            cache: Dict[str, torch.Tensor]
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Args:
             input (torch.Tensor): Input tensor (B, T, D)
-            in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
+            cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs,
             {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
         """
 
         x1 = self.in_linear1(input)
         x2 = self.in_linear2(x1)
         x3 = self.relu(x2)
-        x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
+        x4 = self.fsmn(x3, cache)  # self.cache will update automatically in self.fsmn
         x5 = self.out_linear1(x4)
         x6 = self.out_linear2(x5)
         x7 = self.softmax(x6)
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index f6e0488..1ed0773 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -333,8 +333,8 @@
                 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                 0.000001))
 
-    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
-        scores = self.encoder(feats, in_cache).to('cpu')  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, cache).to('cpu')  # return B * T * D
         assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
         self.vad_opts.nn_eval_block_size = scores.shape[1]
         self.frm_cnt += scores.shape[1]  # count total frames
@@ -493,14 +493,14 @@
 
         return frame_state
 
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
                 ):
-        if not in_cache:
+        if not cache:
             self.AllResetDetection()
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
-        self.ComputeScores(feats, in_cache)
+        self.ComputeScores(feats, cache)
         if not is_final:
             self.DetectCommonFrames()
         else:
@@ -521,7 +521,7 @@
         if is_final:
             # reset class variables and clear the dict for the next query
             self.AllResetDetection()
-        return segments, in_cache
+        return segments, cache
 
     def generate(self,
                  data_in,
@@ -561,7 +561,7 @@
         feats = speech
         feats_len = speech_lengths.max().item()
         waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
-        in_cache = kwargs.get("in_cache", {})
+        cache = kwargs.get("cache", {})
         batch_size = kwargs.get("batch_size", 1)
         step = min(feats_len, 6000)
         segments = [[]] * batch_size
@@ -576,11 +576,11 @@
                 "feats": feats[:, t_offset:t_offset + step, :],
                 "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
                 "is_final": is_final,
-                "in_cache": in_cache
+                "cache": cache
             }
 
 
-            segments_part, in_cache = self.forward(**batch)
+            segments_part, cache = self.forward(**batch)
             if segments_part:
                 for batch_num in range(0, batch_size):
                     segments[batch_num] += segments_part[batch_num]
@@ -603,46 +603,6 @@
             results.append(result_i)
  
         return results, meta_data
-
-    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                       is_final: bool = False, max_end_sil: int = 800
-                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
-        if not in_cache:
-            self.AllResetDetection()
-        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
-        self.waveform = waveform  # compute decibel for each frame
-
-        self.ComputeScores(feats, in_cache)
-        self.ComputeDecibel()
-        if not is_final:
-            self.DetectCommonFrames()
-        else:
-            self.DetectLastFrames()
-        segments = []
-        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
-            segment_batch = []
-            if len(self.output_data_buf) > 0:
-                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point:
-                        continue
-                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
-                        continue
-                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
-                    if self.output_data_buf[i].contain_seg_end_point:
-                        end_ms = self.output_data_buf[i].end_ms
-                        self.next_seg = True
-                        self.output_data_buf_offset += 1
-                    else:
-                        end_ms = -1
-                        self.next_seg = False
-                    segment = [start_ms, end_ms]
-                    segment_batch.append(segment)
-            if segment_batch:
-                segments.append(segment_batch)
-        if is_final:
-            # reset class variables and clear the dict for the next query
-            self.AllResetDetection()
-        return segments, in_cache
 
     def DetectCommonFrames(self) -> int:
         if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
diff --git a/funasr/models/fsmn_vad_streaming/__init__.py b/funasr/models/fsmn_vad_streaming/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/__init__.py
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
new file mode 100755
index 0000000..ae91852
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -0,0 +1,303 @@
+from typing import Tuple, Dict
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from funasr.register import tables
+
+class LinearTransform(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(LinearTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.linear = nn.Linear(input_dim, output_dim, bias=False)
+
+    def forward(self, input):
+        output = self.linear(input)
+
+        return output
+
+
+class AffineTransform(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(AffineTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.linear = nn.Linear(input_dim, output_dim)
+
+    def forward(self, input):
+        output = self.linear(input)
+
+        return output
+
+
+class RectifiedLinear(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(RectifiedLinear, self).__init__()
+        self.dim = input_dim
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, input):
+        out = self.relu(input)
+        return out
+
+
+class FSMNBlock(nn.Module):
+
+    def __init__(
+            self,
+            input_dim: int,
+            output_dim: int,
+            lorder=None,
+            rorder=None,
+            lstride=1,
+            rstride=1,
+    ):
+        super(FSMNBlock, self).__init__()
+
+        self.dim = input_dim
+
+        if lorder is None:
+            return
+
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.conv_left = nn.Conv2d(
+            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
+
+        if self.rorder > 0:
+            self.conv_right = nn.Conv2d(
+                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
+        else:
+            self.conv_right = None
+
+    def forward(self, input: torch.Tensor, cache: torch.Tensor):
+        x = torch.unsqueeze(input, 1)
+        x_per = x.permute(0, 3, 2, 1)  # B D T C
+        
+        cache = cache.to(x_per.device)
+        y_left = torch.cat((cache, x_per), dim=2)
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+        y_left = self.conv_left(y_left)
+        out = x_per + y_left
+
+        if self.conv_right is not None:
+            # maybe need to check
+            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
+            y_right = y_right[:, :, self.rstride:, :]
+            y_right = self.conv_right(y_right)
+            out += y_right
+
+        out_per = out.permute(0, 3, 2, 1)
+        output = out_per.squeeze(1)
+
+        return output, cache
+
+
+class BasicBlock(nn.Module):
+    def __init__(self,
+                 linear_dim: int,
+                 proj_dim: int,
+                 lorder: int,
+                 rorder: int,
+                 lstride: int,
+                 rstride: int,
+                 stack_layer: int
+                 ):
+        super(BasicBlock, self).__init__()
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+        self.stack_layer = stack_layer
+        self.linear = LinearTransform(linear_dim, proj_dim)
+        self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
+        self.affine = AffineTransform(proj_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+        x1 = self.linear(input)  # B T D
+        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        if cache_layer_name not in cache:
+            cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+        x3 = self.affine(x2)
+        x4 = self.relu(x3)
+        return x4
+
+
+class FsmnStack(nn.Sequential):
+    def __init__(self, *args):
+        super(FsmnStack, self).__init__(*args)
+
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+        x = input
+        for module in self._modules.values():
+            x = module(x, cache)
+        return x
+
+
+'''
+FSMN net for keyword spotting
+input_dim:              input dimension
+linear_dim:             fsmn input dimensionll
+proj_dim:               fsmn projection dimension
+lorder:                 fsmn left order
+rorder:                 fsmn right order
+num_syn:                output dimension
+fsmn_layers:            no. of sequential fsmn layers
+'''
+
+@tables.register("encoder_classes", "FSMN")
+class FSMN(nn.Module):
+    def __init__(
+            self,
+            input_dim: int,
+            input_affine_dim: int,
+            fsmn_layers: int,
+            linear_dim: int,
+            proj_dim: int,
+            lorder: int,
+            rorder: int,
+            lstride: int,
+            rstride: int,
+            output_affine_dim: int,
+            output_dim: int
+    ):
+        super(FSMN, self).__init__()
+
+        self.input_dim = input_dim
+        self.input_affine_dim = input_affine_dim
+        self.fsmn_layers = fsmn_layers
+        self.linear_dim = linear_dim
+        self.proj_dim = proj_dim
+        self.output_affine_dim = output_affine_dim
+        self.output_dim = output_dim
+
+        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+        self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+                                range(fsmn_layers)])
+        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def fuse_modules(self):
+        pass
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            cache: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Args:
+            input (torch.Tensor): Input tensor (B, T, D)
+            cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs,
+            {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
+        """
+
+        x1 = self.in_linear1(input)
+        x2 = self.in_linear2(x1)
+        x3 = self.relu(x2)
+        x4 = self.fsmn(x3, cache)  # self.cache will update automatically in self.fsmn
+        x5 = self.out_linear1(x4)
+        x6 = self.out_linear2(x5)
+        x7 = self.softmax(x6)
+
+        return x7
+
+
+'''
+one deep fsmn layer
+dimproj:                projection dimension, input and output dimension of memory blocks
+dimlinear:              dimension of mapping layer
+lorder:                 left order
+rorder:                 right order
+lstride:                left stride
+rstride:                right stride
+'''
+
+@tables.register("encoder_classes", "DFSMN")
+class DFSMN(nn.Module):
+
+    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
+        super(DFSMN, self).__init__()
+
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.expand = AffineTransform(dimproj, dimlinear)
+        self.shrink = LinearTransform(dimlinear, dimproj)
+
+        self.conv_left = nn.Conv2d(
+            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
+
+        if rorder > 0:
+            self.conv_right = nn.Conv2d(
+                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
+        else:
+            self.conv_right = None
+
+    def forward(self, input):
+        f1 = F.relu(self.expand(input))
+        p1 = self.shrink(f1)
+
+        x = torch.unsqueeze(p1, 1)
+        x_per = x.permute(0, 3, 2, 1)
+
+        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
+        if self.conv_right is not None:
+            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
+            y_right = y_right[:, :, self.rstride:, :]
+            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
+        else:
+            out = x_per + self.conv_left(y_left)
+
+        out1 = out.permute(0, 3, 2, 1)
+        output = input + out1.squeeze(1)
+
+        return output
+
+
+'''
+build stacked dfsmn layers
+'''
+
+
+def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
+    repeats = [
+        nn.Sequential(
+            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
+        for i in range(fsmn_layers)
+    ]
+
+    return nn.Sequential(*repeats)
+
+
+if __name__ == '__main__':
+    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
+    print(fsmn)
+
+    num_params = sum(p.numel() for p in fsmn.parameters())
+    print('the number of model params: {}'.format(num_params))
+    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
+    y, _ = fsmn(x)  # batch-size * time * dim
+    print('input shape: {}'.format(x.shape))
+    print('output shape: {}'.format(y.shape))
+
+    print(fsmn.to_kaldi_net())
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
new file mode 100644
index 0000000..4c7e943
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -0,0 +1,781 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+import logging
+import os
+import json
+import torch
+from torch import nn
+import math
+from typing import Optional
+import time
+from funasr.register import tables
+from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
+from torch.nn.utils.rnn import pad_sequence
+
+class VadStateMachine(Enum):
+    kVadInStateStartPointNotDetected = 1
+    kVadInStateInSpeechSegment = 2
+    kVadInStateEndPointDetected = 3
+
+
+class FrameState(Enum):
+    kFrameStateInvalid = -1
+    kFrameStateSpeech = 1
+    kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+    kChangeStateSpeech2Speech = 0
+    kChangeStateSpeech2Sil = 1
+    kChangeStateSil2Sil = 2
+    kChangeStateSil2Speech = 3
+    kChangeStateNoBegin = 4
+    kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+    kVadSingleUtteranceDetectMode = 0
+    kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(
+            self,
+            sample_rate: int = 16000,
+            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+            snr_mode: int = 0,
+            max_end_silence_time: int = 800,
+            max_start_silence_time: int = 3000,
+            do_start_point_detection: bool = True,
+            do_end_point_detection: bool = True,
+            window_size_ms: int = 200,
+            sil_to_speech_time_thres: int = 150,
+            speech_to_sil_time_thres: int = 150,
+            speech_2_noise_ratio: float = 1.0,
+            do_extend: int = 1,
+            lookback_time_start_point: int = 200,
+            lookahead_time_end_point: int = 100,
+            max_single_segment_time: int = 60000,
+            nn_eval_block_size: int = 8,
+            dcd_block_size: int = 4,
+            snr_thres: int = -100.0,
+            noise_frame_num_used_for_snr: int = 100,
+            decibel_thres: int = -100.0,
+            speech_noise_thres: float = 0.6,
+            fe_prior_thres: float = 1e-4,
+            silence_pdf_num: int = 1,
+            sil_pdf_ids: List[int] = [0],
+            speech_noise_thresh_low: float = -0.1,
+            speech_noise_thresh_high: float = 0.3,
+            output_frame_probs: bool = False,
+            frame_in_ms: int = 10,
+            frame_length_ms: int = 25,
+            **kwargs,
+    ):
+        self.sample_rate = sample_rate
+        self.detect_mode = detect_mode
+        self.snr_mode = snr_mode
+        self.max_end_silence_time = max_end_silence_time
+        self.max_start_silence_time = max_start_silence_time
+        self.do_start_point_detection = do_start_point_detection
+        self.do_end_point_detection = do_end_point_detection
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time_thres = sil_to_speech_time_thres
+        self.speech_to_sil_time_thres = speech_to_sil_time_thres
+        self.speech_2_noise_ratio = speech_2_noise_ratio
+        self.do_extend = do_extend
+        self.lookback_time_start_point = lookback_time_start_point
+        self.lookahead_time_end_point = lookahead_time_end_point
+        self.max_single_segment_time = max_single_segment_time
+        self.nn_eval_block_size = nn_eval_block_size
+        self.dcd_block_size = dcd_block_size
+        self.snr_thres = snr_thres
+        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+        self.decibel_thres = decibel_thres
+        self.speech_noise_thres = speech_noise_thres
+        self.fe_prior_thres = fe_prior_thres
+        self.silence_pdf_num = silence_pdf_num
+        self.sil_pdf_ids = sil_pdf_ids
+        self.speech_noise_thresh_low = speech_noise_thresh_low
+        self.speech_noise_thresh_high = speech_noise_thresh_high
+        self.output_frame_probs = output_frame_probs
+        self.frame_in_ms = frame_in_ms
+        self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+    def Reset(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+
+class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.noise_prob = 0.0
+        self.speech_prob = 0.0
+        self.score = 0.0
+        self.frame_id = 0
+        self.frm_state = 0
+
+
+class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+                 speech_to_sil_time: int, frame_size_ms: int):
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time = sil_to_speech_time
+        self.speech_to_sil_time = speech_to_sil_time
+        self.frame_size_ms = frame_size_ms
+
+        self.win_size_frame = int(window_size_ms / frame_size_ms)
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
+
+        self.cur_win_pos = 0
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def Reset(self) -> None:
+        self.cur_win_pos = 0
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def GetWinSize(self) -> int:
+        return int(self.win_size_frame)
+
+    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+        cur_frame_state = FrameState.kFrameStateSil
+        if frameState == FrameState.kFrameStateSpeech:
+            cur_frame_state = 1
+        elif frameState == FrameState.kFrameStateSil:
+            cur_frame_state = 0
+        else:
+            return AudioChangeState.kChangeStateInvalid
+        self.win_sum -= self.win_state[self.cur_win_pos]
+        self.win_sum += cur_frame_state
+        self.win_state[self.cur_win_pos] = cur_frame_state
+        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSpeech
+            return AudioChangeState.kChangeStateSil2Speech
+
+        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSil
+            return AudioChangeState.kChangeStateSpeech2Sil
+
+        if self.pre_frame_state == FrameState.kFrameStateSil:
+            return AudioChangeState.kChangeStateSil2Sil
+        if self.pre_frame_state == FrameState.kFrameStateSpeech:
+            return AudioChangeState.kChangeStateSpeech2Speech
+        return AudioChangeState.kChangeStateInvalid
+
+    def FrameSizeMs(self) -> int:
+        return int(self.frame_size_ms)
+
+
+@tables.register("model_classes", "FsmnVADStreaming")
+class FsmnVADStreaming(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self,
+                 encoder: str = None,
+                 encoder_conf: Optional[Dict] = None,
+                 vad_post_args: Dict[str, Any] = None,
+                 **kwargs,
+                 ):
+        super().__init__()
+        self.vad_opts = VADXOptions(**kwargs)
+        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+                                               self.vad_opts.sil_to_speech_time_thres,
+                                               self.vad_opts.speech_to_sil_time_thres,
+                                               self.vad_opts.frame_in_ms)
+        
+        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder = encoder_class(**encoder_conf)
+        self.encoder = encoder
+        # init variables
+        self.data_buf_start_frame = 0
+        self.frm_cnt = 0
+        self.latest_confirmed_speech_frame = 0
+        self.lastest_confirmed_silence_frame = -1
+        self.continous_silence_frame_count = 0
+        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        self.confirmed_start_frame = -1
+        self.confirmed_end_frame = -1
+        self.number_end_time_detected = 0
+        self.sil_frame = 0
+        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+        self.noise_average_decibel = -100.0
+        self.pre_end_silence_detected = False
+        self.next_seg = True
+
+        self.output_data_buf = []
+        self.output_data_buf_offset = 0
+        self.frame_probs = []
+        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+        self.speech_noise_thres = self.vad_opts.speech_noise_thres
+        self.scores = None
+        self.max_time_out = False
+        self.decibel = []
+        self.data_buf = None
+        self.data_buf_all = None
+        self.waveform = None
+        self.last_drop_frames = 0
+
+    def AllResetDetection(self):
+        self.data_buf_start_frame = 0
+        self.frm_cnt = 0
+        self.latest_confirmed_speech_frame = 0
+        self.lastest_confirmed_silence_frame = -1
+        self.continous_silence_frame_count = 0
+        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        self.confirmed_start_frame = -1
+        self.confirmed_end_frame = -1
+        self.number_end_time_detected = 0
+        self.sil_frame = 0
+        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+        self.noise_average_decibel = -100.0
+        self.pre_end_silence_detected = False
+        self.next_seg = True
+
+        self.output_data_buf = []
+        self.output_data_buf_offset = 0
+        self.frame_probs = []
+        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+        self.speech_noise_thres = self.vad_opts.speech_noise_thres
+        self.scores = None
+        self.max_time_out = False
+        self.decibel = []
+        self.data_buf = None
+        self.data_buf_all = None
+        self.waveform = None
+        self.last_drop_frames = 0
+        self.windows_detector.Reset()
+
+    def ResetDetection(self):
+        self.continous_silence_frame_count = 0
+        self.latest_confirmed_speech_frame = 0
+        self.lastest_confirmed_silence_frame = -1
+        self.confirmed_start_frame = -1
+        self.confirmed_end_frame = -1
+        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        self.windows_detector.Reset()
+        self.sil_frame = 0
+        self.frame_probs = []
+
+        if self.output_data_buf:
+            assert self.output_data_buf[-1].contain_seg_end_point == True
+            drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+            real_drop_frames = drop_frames - self.last_drop_frames
+            self.last_drop_frames = drop_frames
+            self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+            self.decibel = self.decibel[real_drop_frames:]
+            self.scores = self.scores[:, real_drop_frames:, :]
+
+    def ComputeDecibel(self) -> None:
+        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+        if self.data_buf_all is None:
+            self.data_buf_all = self.waveform[0]  # self.data_buf is pointed to self.waveform[0]
+            self.data_buf = self.data_buf_all
+        else:
+            self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
+        for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+            self.decibel.append(
+                10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
+                                0.000001))
+
+    def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, cache).to('cpu')  # return B * T * D
+        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
+        self.vad_opts.nn_eval_block_size = scores.shape[1]
+        self.frm_cnt += scores.shape[1]  # count total frames
+        if self.scores is None:
+            self.scores = scores  # the first calculation
+        else:
+            self.scores = torch.cat((self.scores, scores), dim=1)
+
+    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
+        while self.data_buf_start_frame < frame_idx:
+            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+                self.data_buf_start_frame += 1
+                self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
+                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+
+    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
+        self.PopDataBufTillFrame(start_frm)
+        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+        if last_frm_is_end_point:
+            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
+                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+            expected_sample_number += int(extra_sample)
+        if end_point_is_sent_end:
+            expected_sample_number = max(expected_sample_number, len(self.data_buf))
+        if len(self.data_buf) < expected_sample_number:
+            print('error in calling pop data_buf\n')
+
+        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
+            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
+            self.output_data_buf[-1].Reset()
+            self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+            self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
+            self.output_data_buf[-1].doa = 0
+        cur_seg = self.output_data_buf[-1]
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print('warning\n')
+        out_pos = len(cur_seg.buffer)  # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
+        data_to_pop = 0
+        if end_point_is_sent_end:
+            data_to_pop = expected_sample_number
+        else:
+            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+        if data_to_pop > len(self.data_buf):
+            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
+            data_to_pop = len(self.data_buf)
+            expected_sample_number = len(self.data_buf)
+
+        cur_seg.doa = 0
+        for sample_cpy_out in range(0, data_to_pop):
+            # cur_seg.buffer[out_pos ++] = data_buf_.back();
+            out_pos += 1
+        for sample_cpy_out in range(data_to_pop, expected_sample_number):
+            # cur_seg.buffer[out_pos++] = data_buf_.back()
+            out_pos += 1
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print('Something wrong with the VAD algorithm\n')
+        self.data_buf_start_frame += frm_cnt
+        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+        if first_frm_is_start_point:
+            cur_seg.contain_seg_start_point = True
+        if last_frm_is_end_point:
+            cur_seg.contain_seg_end_point = True
+
+    def OnSilenceDetected(self, valid_frame: int):
+        self.lastest_confirmed_silence_frame = valid_frame
+        if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataBufTillFrame(valid_frame)
+        # silence_detected_callback_
+        # pass
+
+    def OnVoiceDetected(self, valid_frame: int) -> None:
+        self.latest_confirmed_speech_frame = valid_frame
+        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+
+    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
+        if self.vad_opts.do_start_point_detection:
+            pass
+        if self.confirmed_start_frame != -1:
+            print('not reset vad properly\n')
+        else:
+            self.confirmed_start_frame = start_frame
+
+        if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
+
+    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
+        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
+            self.OnVoiceDetected(t)
+        if self.vad_opts.do_end_point_detection:
+            pass
+        if self.confirmed_end_frame != -1:
+            print('not reset vad properly\n')
+        else:
+            self.confirmed_end_frame = end_frame
+        if not fake_result:
+            self.sil_frame = 0
+            self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
+        self.number_end_time_detected += 1
+
+    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
+        if is_final_frame:
+            self.OnVoiceEnd(cur_frm_idx, False, True)
+            self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+
+    def GetLatency(self) -> int:
+        return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
+
+    def LatencyFrmNumAtStartPoint(self) -> int:
+        vad_latency = self.windows_detector.GetWinSize()
+        if self.vad_opts.do_extend:
+            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+        return vad_latency
+
+    def GetFrameState(self, t: int):
+        frame_state = FrameState.kFrameStateInvalid
+        cur_decibel = self.decibel[t]
+        cur_snr = cur_decibel - self.noise_average_decibel
+        # for each frame, calc log posterior probability of each state
+        if cur_decibel < self.vad_opts.decibel_thres:
+            frame_state = FrameState.kFrameStateSil
+            self.DetectOneFrame(frame_state, t, False)
+            return frame_state
+
+        sum_score = 0.0
+        noise_prob = 0.0
+        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
+        if len(self.sil_pdf_ids) > 0:
+            assert len(self.scores) == 1  # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
+            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
+            sum_score = sum(sil_pdf_scores)
+            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+            total_score = 1.0
+            sum_score = total_score - sum_score
+        speech_prob = math.log(sum_score)
+        if self.vad_opts.output_frame_probs:
+            frame_prob = E2EVadFrameProb()
+            frame_prob.noise_prob = noise_prob
+            frame_prob.speech_prob = speech_prob
+            frame_prob.score = sum_score
+            frame_prob.frame_id = t
+            self.frame_probs.append(frame_prob)
+        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
+            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+                frame_state = FrameState.kFrameStateSpeech
+            else:
+                frame_state = FrameState.kFrameStateSil
+        else:
+            frame_state = FrameState.kFrameStateSil
+            if self.noise_average_decibel < -99.9:
+                self.noise_average_decibel = cur_decibel
+            else:
+                self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
+                        self.vad_opts.noise_frame_num_used_for_snr
+                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
+
+        return frame_state
+
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
+                is_final: bool = False
+                ):
+        if not cache:
+            self.AllResetDetection()
+        self.waveform = waveform  # compute decibel for each frame
+        self.ComputeDecibel()
+        self.ComputeScores(feats, cache)
+        if not is_final:
+            self.DetectCommonFrames()
+        else:
+            self.DetectLastFrames()
+        segments = []
+        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
+            segment_batch = []
+            if len(self.output_data_buf) > 0:
+                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+                    if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+                        i].contain_seg_end_point):
+                        continue
+                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+                    segment_batch.append(segment)
+                    self.output_data_buf_offset += 1  # need update this parameter
+            if segment_batch:
+                segments.append(segment_batch)
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+        return segments, cache
+
+    def init_cache(self, cache: dict = {}, **kwargs):
+        cache["frontend"] = {}
+        cache["prev_samples"] = torch.empty(0)
+        
+        return cache
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 cache: dict = {},
+                 **kwargs,
+                 ):
+    
+        if len(cache) == 0:
+            self.init_cache(cache, **kwargs)
+
+        meta_data = {}
+        chunk_size = kwargs.get("chunk_size", 50) # 50ms
+        chunk_stride_samples = chunk_size * 16
+
+        time1 = time.perf_counter()
+        cfg = {"is_final": kwargs.get("is_final", False)}
+        audio_sample_list = load_audio_text_image_video(data_in,
+                                                        fs=frontend.fs,
+                                                        audio_fs=kwargs.get("fs", 16000),
+                                                        data_type=kwargs.get("data_type", "sound"),
+                                                        tokenizer=tokenizer,
+                                                        **cfg,
+                                                        )
+        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
+
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        assert len(audio_sample_list) == 1, "batch_size must be set 1"
+
+        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+
+        n = len(audio_sample) // chunk_stride_samples + int(_is_final)
+        m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))
+        tokens = []
+        for i in range(n):
+            kwargs["is_final"] = _is_final and i == n - 1
+            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+    
+            # extract fbank feats
+            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend, cache=cache["frontend"],
+                                                   is_final=kwargs["is_final"])
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        meta_data = {}
+        audio_sample_list = [data_in]
+        if isinstance(data_in, torch.Tensor):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data[
+                "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+        # b. Forward Encoder streaming
+        t_offset = 0
+        feats = speech
+        feats_len = speech_lengths.max().item()
+        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
+        cache = kwargs.get("cache", {})
+        batch_size = kwargs.get("batch_size", 1)
+        step = min(feats_len, 6000)
+        segments = [[]] * batch_size
+
+        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
+            if t_offset + step >= feats_len - 1:
+                step = feats_len - t_offset
+                is_final = True
+            else:
+                is_final = False
+            batch = {
+                "feats": feats[:, t_offset:t_offset + step, :],
+                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
+                "is_final": is_final,
+                "cache": cache
+            }
+
+
+            segments_part, cache = self.forward(**batch)
+            if segments_part:
+                for batch_num in range(0, batch_size):
+                    segments[batch_num] += segments_part[batch_num]
+
+        ibest_writer = None
+        if ibest_writer is None and kwargs.get("output_dir") is not None:
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+
+        results = []
+        for i in range(batch_size):
+            
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                results[i] = json.dumps(results[i])
+                
+            if ibest_writer is not None:
+                ibest_writer["text"][key[i]] = segments[i]
+
+            result_i = {"key": key[i], "value": segments[i]}
+            results.append(result_i)
+ 
+        return results, meta_data
+
+
+    def DetectCommonFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
+            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+
+        return 0
+
+    def DetectLastFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
+            if i != 0:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+            else:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+
+        return 0
+
+    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
+        tmp_cur_frm_state = FrameState.kFrameStateInvalid
+        if cur_frm_state == FrameState.kFrameStateSpeech:
+            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+                tmp_cur_frm_state = FrameState.kFrameStateSpeech
+            else:
+                tmp_cur_frm_state = FrameState.kFrameStateSil
+        elif cur_frm_state == FrameState.kFrameStateSil:
+            tmp_cur_frm_state = FrameState.kFrameStateSil
+        state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
+        frm_shift_in_ms = self.vad_opts.frame_in_ms
+        if AudioChangeState.kChangeStateSil2Speech == state_change:
+            silence_frame_count = self.continous_silence_frame_count
+            self.continous_silence_frame_count = 0
+            self.pre_end_silence_detected = False
+            start_frame = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+                self.OnVoiceStart(start_frame)
+                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+                for t in range(start_frame + 1, cur_frm_idx + 1):
+                    self.OnVoiceDetected(t)
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
+                    self.OnVoiceDetected(t)
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+            self.continous_silence_frame_count = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                pass
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+            self.continous_silence_frame_count = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.max_time_out = True
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSil2Sil == state_change:
+            self.continous_silence_frame_count += 1
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                # silence timeout, return zero length decision
+                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
+                        self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+                        or (is_final_frame and self.number_end_time_detected == 0):
+                    for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
+                        self.OnSilenceDetected(t)
+                    self.OnVoiceStart(0, True)
+                    self.OnVoiceEnd(0, True, False);
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                else:
+                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
+                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
+                    lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+                    if self.vad_opts.do_extend:
+                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
+                        lookback_frame -= 1
+                        lookback_frame = max(0, lookback_frame)
+                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif self.vad_opts.do_extend and not is_final_frame:
+                    if self.continous_silence_frame_count <= int(
+                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
+                        self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
+                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
+            self.ResetDetection()
+
+
+
diff --git a/funasr/models/fsmn_vad_streaming/template.yaml b/funasr/models/fsmn_vad_streaming/template.yaml
new file mode 100644
index 0000000..e8a3a4f
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/template.yaml
@@ -0,0 +1,62 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: FsmnVADStreaming
+model_conf:
+    sample_rate: 16000
+    detect_mode: 1 
+    snr_mode: 0
+    max_end_silence_time: 800
+    max_start_silence_time: 3000
+    do_start_point_detection: True
+    do_end_point_detection: True
+    window_size_ms: 200
+    sil_to_speech_time_thres: 150
+    speech_to_sil_time_thres: 150
+    speech_2_noise_ratio: 1.0
+    do_extend: 1
+    lookback_time_start_point: 200
+    lookahead_time_end_point: 100
+    max_single_segment_time: 60000
+    snr_thres: -100.0
+    noise_frame_num_used_for_snr: 100
+    decibel_thres: -100.0
+    speech_noise_thres: 0.6
+    fe_prior_thres: 0.0001
+    silence_pdf_num: 1
+    sil_pdf_ids: [0]
+    speech_noise_thresh_low: -0.1
+    speech_noise_thresh_high: 0.3
+    output_frame_probs: False
+    frame_in_ms: 10
+    frame_length_ms: 25
+    
+encoder: FSMN
+encoder_conf:
+    input_dim: 400
+    input_affine_dim: 140
+    fsmn_layers: 4
+    linear_dim: 250
+    proj_dim: 128
+    lorder: 20
+    rorder: 0
+    lstride: 1
+    rstride: 0
+    output_affine_dim: 140
+    output_dim: 248
+
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    dither: 0.0
+    lfr_m: 5
+    lfr_n: 1
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 927b091..fdc0c93 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -519,16 +519,23 @@
 
 		if len(cache) == 0:
 			self.init_cache(cache, **kwargs)
-		_is_final = kwargs.get("is_final", False)
+		
 		
 		meta_data = {}
 		chunk_size = kwargs.get("chunk_size", [0, 10, 5])
 		chunk_stride_samples = chunk_size[1] * 960  # 600ms
 		
 		time1 = time.perf_counter()
-		audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
-		                                                data_type=kwargs.get("data_type", "sound"),
-		                                                tokenizer=tokenizer)
+		cfg = {"is_final": kwargs.get("is_final", False)}
+		audio_sample_list = load_audio_text_image_video(data_in,
+														fs=frontend.fs,
+														audio_fs=kwargs.get("fs", 16000),
+														data_type=kwargs.get("data_type", "sound"),
+														tokenizer=tokenizer,
+														**cfg,
+														)
+		_is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
+		
 		time2 = time.perf_counter()
 		meta_data["load_data"] = f"{time2 - time1:0.3f}"
 		assert len(audio_sample_list) == 1, "batch_size must be set 1"
diff --git a/funasr/models/paraformer_streaming/template.yaml b/funasr/models/paraformer_streaming/template.yaml
new file mode 100644
index 0000000..d1300ac
--- /dev/null
+++ b/funasr/models/paraformer_streaming/template.yaml
@@ -0,0 +1,143 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: ParaformerStreaming
+model_conf:
+    ctc_weight: 0.0
+    lsm_weight: 0.1
+    length_normalized_loss: true
+    predictor_weight: 1.0
+    predictor_bias: 1
+    sampling_ratio: 0.75
+
+# encoder
+encoder: SANMEncoderChunkOpt
+encoder_conf:
+    output_size: 512
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 50
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe_online
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+    chunk_size:
+    - 12
+    - 15
+    stride:
+    - 8
+    - 10
+    pad_left:
+    - 0
+    - 0
+    encoder_att_look_back_factor:
+    - 4
+    - 4
+    decoder_att_look_back_factor:
+    - 1
+    - 1
+
+# decoder
+decoder: ParaformerSANMDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 16
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    att_layer_num: 16
+    kernel_size: 11
+    sanm_shfit: 5
+
+predictor: CifPredictorV2
+predictor_conf:
+    idim: 512
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+    tail_threshold: 0.45
+
+# frontend related
+frontend: WavFrontendOnline
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+      - valid
+      - acc
+  best_model_criterion:
+  -   - valid
+      - acc
+      - max
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index bb9cf01..638e0ac 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -16,7 +16,7 @@
 
 
 
-def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None):
+def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
 	if isinstance(data_or_path_or_list, (list, tuple)):
 		if data_type is not None and isinstance(data_type, (list, tuple)):
 
@@ -26,20 +26,29 @@
 				
 				for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
 					
-					data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer)
+					data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
 					data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
 
 			return data_or_path_or_list_ret
 		else:
-			return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type) for audio in data_or_path_or_list]
-	if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):
+			return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
+	
+	if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
 		data_or_path_or_list = download_from_url(data_or_path_or_list)
-	if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):
+	
+	if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
 		if data_type is None or data_type == "sound":
 			data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
 			data_or_path_or_list = data_or_path_or_list[0, :]
-		# elif data_type == "text" and tokenizer is not None:
-		# 	data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+		elif data_type == "text" and tokenizer is not None:
+			data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+		elif data_type == "image": # undo
+			pass
+		elif data_type == "video": # undo
+			pass
+		
+		# if data_in is a file or url, set is_final=True
+		kwargs["is_final"] = True
 	elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
 		data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
 	elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point

--
Gitblit v1.9.1