From 012903e42ec890ab5c50137beb365c3d94e731d1 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 30 六月 2023 11:21:28 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/models/e2e_vad.py | 23 +++++++++++++++++------
1 files changed, 17 insertions(+), 6 deletions(-)
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 71ed2cf..7c55b2e 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -5,6 +5,7 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.base_model import FunASRModel
class VadStateMachine(Enum):
@@ -211,7 +212,7 @@
return int(self.frame_size_ms)
-class E2EVadModel(nn.Module):
+class E2EVadModel(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
@@ -252,8 +253,8 @@
self.data_buf = None
self.data_buf_all = None
self.waveform = None
- self.ResetDetection()
self.frontend = frontend
+ self.last_drop_frames = 0
def AllResetDetection(self):
self.data_buf_start_frame = 0
@@ -282,7 +283,8 @@
self.data_buf = None
self.data_buf_all = None
self.waveform = None
- self.ResetDetection()
+ self.last_drop_frames = 0
+ self.windows_detector.Reset()
def ResetDetection(self):
self.continous_silence_frame_count = 0
@@ -294,6 +296,15 @@
self.windows_detector.Reset()
self.sil_frame = 0
self.frame_probs = []
+
+ if self.output_data_buf:
+ assert self.output_data_buf[-1].contain_seg_end_point == True
+ drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+ real_drop_frames = drop_frames - self.last_drop_frames
+ self.last_drop_frames = drop_frames
+ self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+ self.decibel = self.decibel[real_drop_frames:]
+ self.scores = self.scores[:, real_drop_frames:, :]
def ComputeDecibel(self) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
@@ -322,7 +333,7 @@
while self.data_buf_start_frame < frame_idx:
if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
self.data_buf_start_frame += 1
- self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
+ self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
@@ -543,7 +554,7 @@
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
- frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
return 0
@@ -553,7 +564,7 @@
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
- frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
if i != 0:
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
else:
--
Gitblit v1.9.1