From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/models/e2e_vad.py |   18 ++++++++++--------
 1 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 846341d..7c55b2e 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -5,6 +5,7 @@
 from torch import nn
 import math
 from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.base_model import FunASRModel
 
 
 class VadStateMachine(Enum):
@@ -211,7 +212,7 @@
         return int(self.frame_size_ms)
 
 
-class E2EVadModel(nn.Module):
+class E2EVadModel(FunASRModel):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
@@ -296,13 +297,14 @@
         self.sil_frame = 0
         self.frame_probs = []
 
-        assert self.output_data_buf[-1].contain_seg_end_point == True
-        drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
-        real_drop_frames = drop_frames - self.last_drop_frames
-        self.last_drop_frames = drop_frames
-        self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-        self.decibel = self.decibel[real_drop_frames:]
-        self.scores = self.scores[:, real_drop_frames:, :]
+        if self.output_data_buf:
+            assert self.output_data_buf[-1].contain_seg_end_point == True
+            drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+            real_drop_frames = drop_frames - self.last_drop_frames
+            self.last_drop_frames = drop_frames
+            self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+            self.decibel = self.decibel[real_drop_frames:]
+            self.scores = self.scores[:, real_drop_frames:, :]
 
     def ComputeDecibel(self) -> None:
         frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)

--
Gitblit v1.9.1