From cf2f14345aa2c4f168ee51c200b8081c748980b8 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 12 Jan 2024 00:01:25 +0800
Subject: [PATCH] funasr1.0 fsmn-vad streaming
---
funasr/models/fsmn_vad/encoder.py | 18
examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh | 11
funasr/models/fsmn_vad_streaming/model.py | 781 ++++++++++++++++++++++++++++++++
funasr/models/fsmn_vad/model.py | 58 --
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py | 11
funasr/models/paraformer_streaming/model.py | 15
funasr/models/fsmn_vad_streaming/__init__.py | 0
funasr/utils/load_utils.py | 23
funasr/models/paraformer_streaming/template.yaml | 143 +++++
funasr/models/fsmn_vad_streaming/encoder.py | 303 ++++++++++++
examples/industrial_data_pretraining/paraformer_streaming/demo.py | 2
funasr/models/fsmn_vad_streaming/template.yaml | 62 ++
 12 files changed, 1356 insertions(+), 71 deletions(-)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
new file mode 100644
index 0000000..2a157ee
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")
+
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh
new file mode 100644
index 0000000..dedd14a
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh
@@ -0,0 +1,11 @@
+
+
+model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model_revision="v2.0.0"
+
+python funasr/bin/inference.py \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++output_dir="./outputs/debug" \
++device="cpu" \
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 9923a04..6d464f2 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -12,8 +12,6 @@
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
cache = {}
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
- cache=cache,
- is_final=True,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
diff --git a/funasr/models/fsmn_vad/encoder.py b/funasr/models/fsmn_vad/encoder.py
index 54410ac..a0a379d 100755
--- a/funasr/models/fsmn_vad/encoder.py
+++ b/funasr/models/fsmn_vad/encoder.py
@@ -125,12 +125,12 @@
self.affine = AffineTransform(proj_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
- def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
+ def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
x1 = self.linear(input) # B T D
cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
- if cache_layer_name not in in_cache:
- in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
- x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
+ if cache_layer_name not in cache:
+ cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+ x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
x3 = self.affine(x2)
x4 = self.relu(x3)
return x4
@@ -140,10 +140,10 @@
def __init__(self, *args):
super(FsmnStack, self).__init__(*args)
- def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
+ def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
x = input
for module in self._modules.values():
- x = module(x, in_cache)
+ x = module(x, cache)
return x
@@ -199,19 +199,19 @@
def forward(
self,
input: torch.Tensor,
- in_cache: Dict[str, torch.Tensor]
+ cache: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
- in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
+ cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs,
{'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
"""
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
- x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
+ x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
x7 = self.softmax(x6)
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index f6e0488..1ed0773 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -333,8 +333,8 @@
10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
0.000001))
- def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
- scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D
+ def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+ scores = self.encoder(feats, cache).to('cpu') # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
@@ -493,14 +493,14 @@
return frame_state
- def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
):
- if not in_cache:
+ if not cache:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
- self.ComputeScores(feats, in_cache)
+ self.ComputeScores(feats, cache)
if not is_final:
self.DetectCommonFrames()
else:
@@ -521,7 +521,7 @@
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
- return segments, in_cache
+ return segments, cache
def generate(self,
data_in,
@@ -561,7 +561,7 @@
feats = speech
feats_len = speech_lengths.max().item()
waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
- in_cache = kwargs.get("in_cache", {})
+ cache = kwargs.get("cache", {})
batch_size = kwargs.get("batch_size", 1)
step = min(feats_len, 6000)
segments = [[]] * batch_size
@@ -576,11 +576,11 @@
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final": is_final,
- "in_cache": in_cache
+ "cache": cache
}
- segments_part, in_cache = self.forward(**batch)
+ segments_part, cache = self.forward(**batch)
if segments_part:
for batch_num in range(0, batch_size):
segments[batch_num] += segments_part[batch_num]
@@ -603,46 +603,6 @@
results.append(result_i)
return results, meta_data
-
- def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
- is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
- if not in_cache:
- self.AllResetDetection()
- self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
- self.waveform = waveform # compute decibel for each frame
-
- self.ComputeScores(feats, in_cache)
- self.ComputeDecibel()
- if not is_final:
- self.DetectCommonFrames()
- else:
- self.DetectLastFrames()
- segments = []
- for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
- segment_batch = []
- if len(self.output_data_buf) > 0:
- for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
- if not self.output_data_buf[i].contain_seg_start_point:
- continue
- if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
- continue
- start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
- if self.output_data_buf[i].contain_seg_end_point:
- end_ms = self.output_data_buf[i].end_ms
- self.next_seg = True
- self.output_data_buf_offset += 1
- else:
- end_ms = -1
- self.next_seg = False
- segment = [start_ms, end_ms]
- segment_batch.append(segment)
- if segment_batch:
- segments.append(segment_batch)
- if is_final:
- # reset class variables and clear the dict for the next query
- self.AllResetDetection()
- return segments, in_cache
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
diff --git a/funasr/models/fsmn_vad_streaming/__init__.py b/funasr/models/fsmn_vad_streaming/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/__init__.py
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
new file mode 100755
index 0000000..ae91852
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -0,0 +1,303 @@
+from typing import Tuple, Dict
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from funasr.register import tables
+
+class LinearTransform(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(LinearTransform, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
+
+ def forward(self, input):
+ output = self.linear(input)
+
+ return output
+
+
+class AffineTransform(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(AffineTransform, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.linear = nn.Linear(input_dim, output_dim)
+
+ def forward(self, input):
+ output = self.linear(input)
+
+ return output
+
+
+class RectifiedLinear(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(RectifiedLinear, self).__init__()
+ self.dim = input_dim
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, input):
+ out = self.relu(input)
+ return out
+
+
+class FSMNBlock(nn.Module):
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ lorder=None,
+ rorder=None,
+ lstride=1,
+ rstride=1,
+ ):
+ super(FSMNBlock, self).__init__()
+
+ self.dim = input_dim
+
+ if lorder is None:
+ return
+
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+
+ self.conv_left = nn.Conv2d(
+ self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
+
+ if self.rorder > 0:
+ self.conv_right = nn.Conv2d(
+ self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
+ else:
+ self.conv_right = None
+
+ def forward(self, input: torch.Tensor, cache: torch.Tensor):
+ x = torch.unsqueeze(input, 1)
+ x_per = x.permute(0, 3, 2, 1) # B D T C
+
+ cache = cache.to(x_per.device)
+ y_left = torch.cat((cache, x_per), dim=2)
+ cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+ y_left = self.conv_left(y_left)
+ out = x_per + y_left
+
+ if self.conv_right is not None:
+ # maybe need to check
+ y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
+ y_right = y_right[:, :, self.rstride:, :]
+ y_right = self.conv_right(y_right)
+ out += y_right
+
+ out_per = out.permute(0, 3, 2, 1)
+ output = out_per.squeeze(1)
+
+ return output, cache
+
+
+class BasicBlock(nn.Module):
+ def __init__(self,
+ linear_dim: int,
+ proj_dim: int,
+ lorder: int,
+ rorder: int,
+ lstride: int,
+ rstride: int,
+ stack_layer: int
+ ):
+ super(BasicBlock, self).__init__()
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+ self.stack_layer = stack_layer
+ self.linear = LinearTransform(linear_dim, proj_dim)
+ self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
+ self.affine = AffineTransform(proj_dim, linear_dim)
+ self.relu = RectifiedLinear(linear_dim, linear_dim)
+
+ def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+ x1 = self.linear(input) # B T D
+ cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ if cache_layer_name not in cache:
+ cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+ x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+ x3 = self.affine(x2)
+ x4 = self.relu(x3)
+ return x4
+
+
+class FsmnStack(nn.Sequential):
+ def __init__(self, *args):
+ super(FsmnStack, self).__init__(*args)
+
+ def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+ x = input
+ for module in self._modules.values():
+ x = module(x, cache)
+ return x
+
+
+'''
+FSMN net for keyword spotting
+input_dim: input dimension
+linear_dim: fsmn input dimensionll
+proj_dim: fsmn projection dimension
+lorder: fsmn left order
+rorder: fsmn right order
+num_syn: output dimension
+fsmn_layers: no. of sequential fsmn layers
+'''
+
+@tables.register("encoder_classes", "FSMN")
+class FSMN(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ input_affine_dim: int,
+ fsmn_layers: int,
+ linear_dim: int,
+ proj_dim: int,
+ lorder: int,
+ rorder: int,
+ lstride: int,
+ rstride: int,
+ output_affine_dim: int,
+ output_dim: int
+ ):
+ super(FSMN, self).__init__()
+
+ self.input_dim = input_dim
+ self.input_affine_dim = input_affine_dim
+ self.fsmn_layers = fsmn_layers
+ self.linear_dim = linear_dim
+ self.proj_dim = proj_dim
+ self.output_affine_dim = output_affine_dim
+ self.output_dim = output_dim
+
+ self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+ self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+ self.relu = RectifiedLinear(linear_dim, linear_dim)
+ self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+ range(fsmn_layers)])
+ self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+ self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def fuse_modules(self):
+ pass
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ cache: Dict[str, torch.Tensor]
+    ) -> torch.Tensor:
+ """
+ Args:
+ input (torch.Tensor): Input tensor (B, T, D)
+ cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs,
+ {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
+ """
+
+ x1 = self.in_linear1(input)
+ x2 = self.in_linear2(x1)
+ x3 = self.relu(x2)
+ x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn
+ x5 = self.out_linear1(x4)
+ x6 = self.out_linear2(x5)
+ x7 = self.softmax(x6)
+
+ return x7
+
+
+'''
+one deep fsmn layer
+dimproj: projection dimension, input and output dimension of memory blocks
+dimlinear: dimension of mapping layer
+lorder: left order
+rorder: right order
+lstride: left stride
+rstride: right stride
+'''
+
+@tables.register("encoder_classes", "DFSMN")
+class DFSMN(nn.Module):
+
+ def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
+ super(DFSMN, self).__init__()
+
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+
+ self.expand = AffineTransform(dimproj, dimlinear)
+ self.shrink = LinearTransform(dimlinear, dimproj)
+
+ self.conv_left = nn.Conv2d(
+ dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
+
+ if rorder > 0:
+ self.conv_right = nn.Conv2d(
+ dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
+ else:
+ self.conv_right = None
+
+ def forward(self, input):
+ f1 = F.relu(self.expand(input))
+ p1 = self.shrink(f1)
+
+ x = torch.unsqueeze(p1, 1)
+ x_per = x.permute(0, 3, 2, 1)
+
+ y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
+ if self.conv_right is not None:
+ y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
+ y_right = y_right[:, :, self.rstride:, :]
+ out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
+ else:
+ out = x_per + self.conv_left(y_left)
+
+ out1 = out.permute(0, 3, 2, 1)
+ output = input + out1.squeeze(1)
+
+ return output
+
+
+'''
+build stacked dfsmn layers
+'''
+
+
+def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
+ repeats = [
+ nn.Sequential(
+ DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
+ for i in range(fsmn_layers)
+ ]
+
+ return nn.Sequential(*repeats)
+
+
+if __name__ == '__main__':
+ fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
+ print(fsmn)
+
+ num_params = sum(p.numel() for p in fsmn.parameters())
+ print('the number of model params: {}'.format(num_params))
+ x = torch.zeros(128, 200, 400) # batch-size * time * dim
+    y = fsmn(x) # batch-size * time * dim
+ print('input shape: {}'.format(x.shape))
+ print('output shape: {}'.format(y.shape))
+
+    # print(fsmn.to_kaldi_net())  # to_kaldi_net() is not defined on this FSMN class
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
new file mode 100644
index 0000000..4c7e943
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -0,0 +1,781 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+import logging
+import os
+import json
+import torch
+from torch import nn
+import math
+from typing import Optional
+import time
+from funasr.register import tables
+from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
+from torch.nn.utils.rnn import pad_sequence
+
+class VadStateMachine(Enum):
+ kVadInStateStartPointNotDetected = 1
+ kVadInStateInSpeechSegment = 2
+ kVadInStateEndPointDetected = 3
+
+
+class FrameState(Enum):
+ kFrameStateInvalid = -1
+ kFrameStateSpeech = 1
+ kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+ kChangeStateSpeech2Speech = 0
+ kChangeStateSpeech2Sil = 1
+ kChangeStateSil2Sil = 2
+ kChangeStateSil2Speech = 3
+ kChangeStateNoBegin = 4
+ kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+ kVadSingleUtteranceDetectMode = 0
+ kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+ snr_mode: int = 0,
+ max_end_silence_time: int = 800,
+ max_start_silence_time: int = 3000,
+ do_start_point_detection: bool = True,
+ do_end_point_detection: bool = True,
+ window_size_ms: int = 200,
+ sil_to_speech_time_thres: int = 150,
+ speech_to_sil_time_thres: int = 150,
+ speech_2_noise_ratio: float = 1.0,
+ do_extend: int = 1,
+ lookback_time_start_point: int = 200,
+ lookahead_time_end_point: int = 100,
+ max_single_segment_time: int = 60000,
+ nn_eval_block_size: int = 8,
+ dcd_block_size: int = 4,
+        snr_thres: float = -100.0,
+ noise_frame_num_used_for_snr: int = 100,
+        decibel_thres: float = -100.0,
+ speech_noise_thres: float = 0.6,
+ fe_prior_thres: float = 1e-4,
+ silence_pdf_num: int = 1,
+ sil_pdf_ids: List[int] = [0],
+ speech_noise_thresh_low: float = -0.1,
+ speech_noise_thresh_high: float = 0.3,
+ output_frame_probs: bool = False,
+ frame_in_ms: int = 10,
+ frame_length_ms: int = 25,
+ **kwargs,
+ ):
+ self.sample_rate = sample_rate
+ self.detect_mode = detect_mode
+ self.snr_mode = snr_mode
+ self.max_end_silence_time = max_end_silence_time
+ self.max_start_silence_time = max_start_silence_time
+ self.do_start_point_detection = do_start_point_detection
+ self.do_end_point_detection = do_end_point_detection
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time_thres = sil_to_speech_time_thres
+ self.speech_to_sil_time_thres = speech_to_sil_time_thres
+ self.speech_2_noise_ratio = speech_2_noise_ratio
+ self.do_extend = do_extend
+ self.lookback_time_start_point = lookback_time_start_point
+ self.lookahead_time_end_point = lookahead_time_end_point
+ self.max_single_segment_time = max_single_segment_time
+ self.nn_eval_block_size = nn_eval_block_size
+ self.dcd_block_size = dcd_block_size
+ self.snr_thres = snr_thres
+ self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+ self.decibel_thres = decibel_thres
+ self.speech_noise_thres = speech_noise_thres
+ self.fe_prior_thres = fe_prior_thres
+ self.silence_pdf_num = silence_pdf_num
+ self.sil_pdf_ids = sil_pdf_ids
+ self.speech_noise_thresh_low = speech_noise_thresh_low
+ self.speech_noise_thresh_high = speech_noise_thresh_high
+ self.output_frame_probs = output_frame_probs
+ self.frame_in_ms = frame_in_ms
+ self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+ def Reset(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+
+class E2EVadFrameProb(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self):
+ self.noise_prob = 0.0
+ self.speech_prob = 0.0
+ self.score = 0.0
+ self.frame_id = 0
+ self.frm_state = 0
+
+
+class WindowDetector(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+ speech_to_sil_time: int, frame_size_ms: int):
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time = sil_to_speech_time
+ self.speech_to_sil_time = speech_to_sil_time
+ self.frame_size_ms = frame_size_ms
+
+ self.win_size_frame = int(window_size_ms / frame_size_ms)
+ self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # initialize sliding-window state buffer
+
+ self.cur_win_pos = 0
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+ self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def Reset(self) -> None:
+ self.cur_win_pos = 0
+ self.win_sum = 0
+ self.win_state = [0] * self.win_size_frame
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def GetWinSize(self) -> int:
+ return int(self.win_size_frame)
+
+ def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+ cur_frame_state = FrameState.kFrameStateSil
+ if frameState == FrameState.kFrameStateSpeech:
+ cur_frame_state = 1
+ elif frameState == FrameState.kFrameStateSil:
+ cur_frame_state = 0
+ else:
+ return AudioChangeState.kChangeStateInvalid
+ self.win_sum -= self.win_state[self.cur_win_pos]
+ self.win_sum += cur_frame_state
+ self.win_state[self.cur_win_pos] = cur_frame_state
+ self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+ if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSpeech
+ return AudioChangeState.kChangeStateSil2Speech
+
+ if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSil
+ return AudioChangeState.kChangeStateSpeech2Sil
+
+ if self.pre_frame_state == FrameState.kFrameStateSil:
+ return AudioChangeState.kChangeStateSil2Sil
+ if self.pre_frame_state == FrameState.kFrameStateSpeech:
+ return AudioChangeState.kChangeStateSpeech2Speech
+ return AudioChangeState.kChangeStateInvalid
+
+ def FrameSizeMs(self) -> int:
+ return int(self.frame_size_ms)
+
+
+@tables.register("model_classes", "FsmnVADStreaming")
+class FsmnVADStreaming(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ vad_post_args: Dict[str, Any] = None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.vad_opts = VADXOptions(**kwargs)
+ self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+ self.vad_opts.sil_to_speech_time_thres,
+ self.vad_opts.speech_to_sil_time_thres,
+ self.vad_opts.frame_in_ms)
+
+ encoder_class = tables.encoder_classes.get(encoder.lower())
+ encoder = encoder_class(**encoder_conf)
+ self.encoder = encoder
+ # init variables
+ self.data_buf_start_frame = 0
+ self.frm_cnt = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.continous_silence_frame_count = 0
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.number_end_time_detected = 0
+ self.sil_frame = 0
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+ self.noise_average_decibel = -100.0
+ self.pre_end_silence_detected = False
+ self.next_seg = True
+
+ self.output_data_buf = []
+ self.output_data_buf_offset = 0
+ self.frame_probs = []
+ self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
+ self.scores = None
+ self.max_time_out = False
+ self.decibel = []
+ self.data_buf = None
+ self.data_buf_all = None
+ self.waveform = None
+ self.last_drop_frames = 0
+
+ def AllResetDetection(self):
+ self.data_buf_start_frame = 0
+ self.frm_cnt = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.continous_silence_frame_count = 0
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.number_end_time_detected = 0
+ self.sil_frame = 0
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+ self.noise_average_decibel = -100.0
+ self.pre_end_silence_detected = False
+ self.next_seg = True
+
+ self.output_data_buf = []
+ self.output_data_buf_offset = 0
+ self.frame_probs = []
+ self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
+ self.scores = None
+ self.max_time_out = False
+ self.decibel = []
+ self.data_buf = None
+ self.data_buf_all = None
+ self.waveform = None
+ self.last_drop_frames = 0
+ self.windows_detector.Reset()
+
+ def ResetDetection(self):
+ self.continous_silence_frame_count = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.windows_detector.Reset()
+ self.sil_frame = 0
+ self.frame_probs = []
+
+ if self.output_data_buf:
+ assert self.output_data_buf[-1].contain_seg_end_point == True
+ drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+ real_drop_frames = drop_frames - self.last_drop_frames
+ self.last_drop_frames = drop_frames
+ self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+ self.decibel = self.decibel[real_drop_frames:]
+ self.scores = self.scores[:, real_drop_frames:, :]
+
+ def ComputeDecibel(self) -> None:
+ frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+ frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+ if self.data_buf_all is None:
+ self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0]
+ self.data_buf = self.data_buf_all
+ else:
+ self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
+ for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+ self.decibel.append(
+ 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
+ 0.000001))
+
+ def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+ scores = self.encoder(feats, cache).to('cpu') # return B * T * D
+ assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
+ self.vad_opts.nn_eval_block_size = scores.shape[1]
+ self.frm_cnt += scores.shape[1] # count total frames
+ if self.scores is None:
+ self.scores = scores # the first calculation
+ else:
+ self.scores = torch.cat((self.scores, scores), dim=1)
+
+ def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
+ while self.data_buf_start_frame < frame_idx:
+ if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+ self.data_buf_start_frame += 1
+ self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+
+ def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+ last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
+ self.PopDataBufTillFrame(start_frm)
+ expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+ if last_frm_is_end_point:
+ extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
+ self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+ expected_sample_number += int(extra_sample)
+ if end_point_is_sent_end:
+ expected_sample_number = max(expected_sample_number, len(self.data_buf))
+ if len(self.data_buf) < expected_sample_number:
+ print('error in calling pop data_buf\n')
+
+ if len(self.output_data_buf) == 0 or first_frm_is_start_point:
+ self.output_data_buf.append(E2EVadSpeechBufWithDoa())
+ self.output_data_buf[-1].Reset()
+ self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+ self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
+ self.output_data_buf[-1].doa = 0
+ cur_seg = self.output_data_buf[-1]
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+ print('warning\n')
+        out_pos = len(cur_seg.buffer)  # cur_seg.buffer is not actually modified here yet
+ data_to_pop = 0
+ if end_point_is_sent_end:
+ data_to_pop = expected_sample_number
+ else:
+ data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+ if data_to_pop > len(self.data_buf):
+ print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
+ data_to_pop = len(self.data_buf)
+ expected_sample_number = len(self.data_buf)
+
+ cur_seg.doa = 0
+ for sample_cpy_out in range(0, data_to_pop):
+ # cur_seg.buffer[out_pos ++] = data_buf_.back();
+ out_pos += 1
+ for sample_cpy_out in range(data_to_pop, expected_sample_number):
+ # cur_seg.buffer[out_pos++] = data_buf_.back()
+ out_pos += 1
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+ print('Something wrong with the VAD algorithm\n')
+ self.data_buf_start_frame += frm_cnt
+ cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+ if first_frm_is_start_point:
+ cur_seg.contain_seg_start_point = True
+ if last_frm_is_end_point:
+ cur_seg.contain_seg_end_point = True
+
+ def OnSilenceDetected(self, valid_frame: int):
+ self.lastest_confirmed_silence_frame = valid_frame
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ self.PopDataBufTillFrame(valid_frame)
+ # silence_detected_callback_
+ # pass
+
    def OnVoiceDetected(self, valid_frame: int) -> None:
        """Confirm `valid_frame` as speech and emit one frame to the output buffer."""
        self.latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
        """Register a detected speech start at `start_frame`.

        `fake_result` marks a synthetic start emitted on silence timeout; it
        records the frame but does not open a real output segment.
        """
        if self.vad_opts.do_start_point_detection:
            pass  # start-point post-processing not implemented in this port
        if self.confirmed_start_frame != -1:
            print('not reset vad properly\n')
        else:
            self.confirmed_start_frame = start_frame

        if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            # Open a new output segment beginning at the confirmed start frame.
            self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
+
    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
        """Register a detected speech end at `end_frame`.

        Flushes the not-yet-emitted speech frames up to the end point, then
        closes the current output segment. `fake_result` marks a synthetic end
        (silence timeout); `is_last_frame` lets the close consume the buffer tail.
        """
        # Emit any remaining confirmed speech frames before the end point.
        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t)
        if self.vad_opts.do_end_point_detection:
            pass  # end-point post-processing not implemented in this port
        if self.confirmed_end_frame != -1:
            print('not reset vad properly\n')
        else:
            self.confirmed_end_frame = end_frame
        if not fake_result:
            self.sil_frame = 0
            # Close the segment; marks contain_seg_end_point on the output buffer.
            self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
        self.number_end_time_detected += 1
+
+ def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
+ if is_final_frame:
+ self.OnVoiceEnd(cur_frm_idx, False, True)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+
+ def GetLatency(self) -> int:
+ return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
+
+ def LatencyFrmNumAtStartPoint(self) -> int:
+ vad_latency = self.windows_detector.GetWinSize()
+ if self.vad_opts.do_extend:
+ vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+ return vad_latency
+
    def GetFrameState(self, t: int):
        """Classify frame `t` as speech or silence.

        Combines a decibel gate, the NN silence-pdf posteriors, and an SNR check;
        also updates the running noise-level estimate when the frame is silence.
        """
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = self.decibel[t]
        cur_snr = cur_decibel - self.noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            # Too quiet: short-circuit to silence without consulting the NN scores.
            frame_state = FrameState.kFrameStateSil
            self.DetectOneFrame(frame_state, t, False)
            return frame_state

        sum_score = 0.0
        noise_prob = 0.0
        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(self.sil_pdf_ids) > 0:
            assert len(self.scores) == 1  # only supports batch_size = 1
            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            sum_score = total_score - sum_score  # remaining posterior mass = speech probability
            speech_prob = math.log(sum_score)
            if self.vad_opts.output_frame_probs:
                frame_prob = E2EVadFrameProb()
                frame_prob.noise_prob = noise_prob
                frame_prob.speech_prob = speech_prob
                frame_prob.score = sum_score
                frame_prob.frame_id = t
                self.frame_probs.append(frame_prob)
        # NOTE(review): speech_prob is only bound inside the branch above, so an
        # empty sil_pdf_ids would raise NameError here — confirm sil_pdf_ids is
        # always non-empty in deployed configs.
        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
                frame_state = FrameState.kFrameStateSpeech
            else:
                frame_state = FrameState.kFrameStateSil
        else:
            frame_state = FrameState.kFrameStateSil
            # Silence frame: fold its level into the running noise average (EMA
            # over noise_frame_num_used_for_snr frames).
            if self.noise_average_decibel < -99.9:
                self.noise_average_decibel = cur_decibel
            else:
                self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
                        self.vad_opts.noise_frame_num_used_for_snr
                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr

        return frame_state
+
+ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
+ is_final: bool = False
+ ):
+ if not cache:
+ self.AllResetDetection()
+ self.waveform = waveform # compute decibel for each frame
+ self.ComputeDecibel()
+ self.ComputeScores(feats, cache)
+ if not is_final:
+ self.DetectCommonFrames()
+ else:
+ self.DetectLastFrames()
+ segments = []
+ for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
+ segment_batch = []
+ if len(self.output_data_buf) > 0:
+ for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+ if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point):
+ continue
+ segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+ segment_batch.append(segment)
+ self.output_data_buf_offset += 1 # need update this parameter
+ if segment_batch:
+ segments.append(segment_batch)
+ if is_final:
+ # reset class variables and clear the dict for the next query
+ self.AllResetDetection()
+ return segments, cache
+
+ def init_cache(self, cache: dict = {}, **kwargs):
+ cache["frontend"] = {}
+ cache["prev_samples"] = torch.empty(0)
+
+ return cache
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ cache: dict = {},
+ **kwargs,
+ ):
+
+ if len(cache) == 0:
+ self.init_cache(cache, **kwargs)
+
+ meta_data = {}
+ chunk_size = kwargs.get("chunk_size", 50) # 50ms
+ chunk_stride_samples = chunk_size * 16
+
+ time1 = time.perf_counter()
+ cfg = {"is_final": kwargs.get("is_final", False)}
+ audio_sample_list = load_audio_text_image_video(data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ **cfg,
+ )
+ _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
+
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ assert len(audio_sample_list) == 1, "batch_size must be set 1"
+
+ audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+
+ n = len(audio_sample) // chunk_stride_samples + int(_is_final)
+ m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))
+ tokens = []
+ for i in range(n):
+ kwargs["is_final"] = _is_final and i == n - 1
+ audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+
+ # extract fbank feats
+ speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend, cache=cache["frontend"],
+ is_final=kwargs["is_final"])
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ meta_data = {}
+ audio_sample_list = [data_in]
+ if isinstance(data_in, torch.Tensor): # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data[
+ "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+ # b. Forward Encoder streaming
+ t_offset = 0
+ feats = speech
+ feats_len = speech_lengths.max().item()
+ waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
+ cache = kwargs.get("cache", {})
+ batch_size = kwargs.get("batch_size", 1)
+ step = min(feats_len, 6000)
+ segments = [[]] * batch_size
+
+ for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
+ if t_offset + step >= feats_len - 1:
+ step = feats_len - t_offset
+ is_final = True
+ else:
+ is_final = False
+ batch = {
+ "feats": feats[:, t_offset:t_offset + step, :],
+ "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
+ "is_final": is_final,
+ "cache": cache
+ }
+
+
+ segments_part, cache = self.forward(**batch)
+ if segments_part:
+ for batch_num in range(0, batch_size):
+ segments[batch_num] += segments_part[batch_num]
+
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{1}best_recog"]
+
+ results = []
+ for i in range(batch_size):
+
+ if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+ results[i] = json.dumps(results[i])
+
+ if ibest_writer is not None:
+ ibest_writer["text"][key[i]] = segments[i]
+
+ result_i = {"key": key[i], "value": segments[i]}
+ results.append(result_i)
+
+ return results, meta_data
+
+
+ def DetectCommonFrames(self) -> int:
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+ return 0
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+ frame_state = FrameState.kFrameStateInvalid
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+
+ return 0
+
+ def DetectLastFrames(self) -> int:
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+ return 0
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+ frame_state = FrameState.kFrameStateInvalid
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
+ if i != 0:
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+ else:
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+
+ return 0
+
    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
        """Advance the VAD state machine by one classified frame.

        Feeds the frame into the sliding-window detector and, depending on the
        resulting transition (sil->speech, speech->sil, speech->speech,
        sil->sil), opens/extends/closes segments via OnVoiceStart,
        OnVoiceDetected, and OnVoiceEnd.
        """
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            # NOTE(review): math.fabs(1.0) is a constant, so with the usual
            # fe_prior_thres < 1.0 this always takes the speech branch — looks
            # like a placeholder for a frame-energy prior; confirm vs. C++ source.
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            # silence -> speech: possibly a new segment start
            silence_frame_count = self.continous_silence_frame_count
            self.continous_silence_frame_count = 0
            self.pre_end_silence_detected = False
            start_frame = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # Back-date the start point by the detector latency.
                start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
                self.OnVoiceStart(start_frame)
                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                for t in range(start_frame + 1, cur_frm_idx + 1):
                    self.OnVoiceDetected(t)
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.OnVoiceDetected(t)
                # Segment exceeded max length: force an end point.
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
            # speech -> silence: begin counting trailing silence
            self.continous_silence_frame_count = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                pass
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            # sustained speech
            self.continous_silence_frame_count = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.max_time_out = True
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
            # sustained silence
            self.continous_silence_frame_count += 1
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
                        self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
                        or (is_final_frame and self.number_end_time_detected == 0):
                    for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t)
                    # Emit a synthetic zero-length segment.
                    self.OnVoiceStart(0, True)
                    self.OnVoiceEnd(0, True, False);
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
                    # Enough trailing silence: close the segment (with lookback).
                    lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    if self.continous_silence_frame_count <= int(
                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
                        self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass

        # In multi-utterance mode, re-arm the detector after each end point.
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection()
+
+
diff --git a/funasr/models/fsmn_vad_streaming/template.yaml b/funasr/models/fsmn_vad_streaming/template.yaml
new file mode 100644
index 0000000..e8a3a4f
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/template.yaml
@@ -0,0 +1,62 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: FsmnVADStreaming
+model_conf:
+ sample_rate: 16000
+ detect_mode: 1
+ snr_mode: 0
+ max_end_silence_time: 800
+ max_start_silence_time: 3000
+ do_start_point_detection: True
+ do_end_point_detection: True
+ window_size_ms: 200
+ sil_to_speech_time_thres: 150
+ speech_to_sil_time_thres: 150
+ speech_2_noise_ratio: 1.0
+ do_extend: 1
+ lookback_time_start_point: 200
+ lookahead_time_end_point: 100
+ max_single_segment_time: 60000
+ snr_thres: -100.0
+ noise_frame_num_used_for_snr: 100
+ decibel_thres: -100.0
+ speech_noise_thres: 0.6
+ fe_prior_thres: 0.0001
+ silence_pdf_num: 1
+ sil_pdf_ids: [0]
+ speech_noise_thresh_low: -0.1
+ speech_noise_thresh_high: 0.3
+ output_frame_probs: False
+ frame_in_ms: 10
+ frame_length_ms: 25
+
+encoder: FSMN
+encoder_conf:
+ input_dim: 400
+ input_affine_dim: 140
+ fsmn_layers: 4
+ linear_dim: 250
+ proj_dim: 128
+ lorder: 20
+ rorder: 0
+ lstride: 1
+ rstride: 0
+ output_affine_dim: 140
+ output_dim: 248
+
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ dither: 0.0
+ lfr_m: 5
+ lfr_n: 1
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 927b091..fdc0c93 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -519,16 +519,23 @@
if len(cache) == 0:
self.init_cache(cache, **kwargs)
- _is_final = kwargs.get("is_final", False)
+
meta_data = {}
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
chunk_stride_samples = chunk_size[1] * 960 # 600ms
time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
+ cfg = {"is_final": kwargs.get("is_final", False)}
+ audio_sample_list = load_audio_text_image_video(data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ **cfg,
+ )
+ _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
+
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
assert len(audio_sample_list) == 1, "batch_size must be set 1"
diff --git a/funasr/models/paraformer_streaming/template.yaml b/funasr/models/paraformer_streaming/template.yaml
new file mode 100644
index 0000000..d1300ac
--- /dev/null
+++ b/funasr/models/paraformer_streaming/template.yaml
@@ -0,0 +1,143 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: ParaformerStreaming
+model_conf:
+ ctc_weight: 0.0
+ lsm_weight: 0.1
+ length_normalized_loss: true
+ predictor_weight: 1.0
+ predictor_bias: 1
+ sampling_ratio: 0.75
+
+# encoder
+encoder: SANMEncoderChunkOpt
+encoder_conf:
+ output_size: 512
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 50
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: pe_online
+ pos_enc_class: SinusoidalPositionEncoder
+ normalize_before: true
+ kernel_size: 11
+ sanm_shfit: 0
+ selfattention_layer_type: sanm
+ chunk_size:
+ - 12
+ - 15
+ stride:
+ - 8
+ - 10
+ pad_left:
+ - 0
+ - 0
+ encoder_att_look_back_factor:
+ - 4
+ - 4
+ decoder_att_look_back_factor:
+ - 1
+ - 1
+
+# decoder
+decoder: ParaformerSANMDecoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 16
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.1
+ src_attention_dropout_rate: 0.1
+ att_layer_num: 16
+ kernel_size: 11
+ sanm_shfit: 5
+
+predictor: CifPredictorV2
+predictor_conf:
+ idim: 512
+ threshold: 1.0
+ l_order: 1
+ r_order: 1
+ tail_threshold: 0.45
+
+# frontend related
+frontend: WavFrontendOnline
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 7
+ lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 6
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ val_scheduler_criterion:
+ - valid
+ - acc
+ best_model_criterion:
+ - - valid
+ - acc
+ - max
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: DynamicBatchLocalShuffleSampler
+ batch_type: example # example or length
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+ buffer_size: 500
+ shuffle: True
+ num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index bb9cf01..638e0ac 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -16,7 +16,7 @@
-def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None):
+def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
if isinstance(data_or_path_or_list, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
@@ -26,20 +26,29 @@
for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
- data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer)
+ data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
return data_or_path_or_list_ret
else:
- return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type) for audio in data_or_path_or_list]
- if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):
+ return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
+
+ if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
data_or_path_or_list = download_from_url(data_or_path_or_list)
- if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):
+
+ if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
if data_type is None or data_type == "sound":
data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
data_or_path_or_list = data_or_path_or_list[0, :]
- # elif data_type == "text" and tokenizer is not None:
- # data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+ elif data_type == "text" and tokenizer is not None:
+ data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+ elif data_type == "image": # undo
+ pass
+ elif data_type == "video": # undo
+ pass
+
+ # if data_in is a file or url, set is_final=True
+ kwargs["is_final"] = True
elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
--
Gitblit v1.9.1