From 0143122a4e2ee86cc27ba137b2bb0530577cbf12 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 12 Jan 2024 10:27:36 +0800
Subject: [PATCH] funasr1.0 streaming demo
---
funasr/models/fsmn_vad/model.py | 68 +++++++---------------------------
1 file changed, 14 insertions(+), 54 deletions(-)
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 16f21dc..1ed0773 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -8,8 +8,8 @@
import math
from typing import Optional
import time
-from funasr.utils.register import register_class, registry_tables
-from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank
+from funasr.register import tables
+from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from torch.nn.utils.rnn import pad_sequence
@@ -218,7 +218,7 @@
return int(self.frame_size_ms)
-@register_class("model_classes", "FsmnVAD")
+@tables.register("model_classes", "FsmnVAD")
class FsmnVAD(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -238,7 +238,7 @@
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
- encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+ encoder_class = tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
# init variables
@@ -333,8 +333,8 @@
10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
0.000001))
- def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
- scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D
+ def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+ scores = self.encoder(feats, cache).to('cpu') # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
@@ -493,14 +493,14 @@
return frame_state
- def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
):
- if not in_cache:
+ if not cache:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
- self.ComputeScores(feats, in_cache)
+ self.ComputeScores(feats, cache)
if not is_final:
self.DetectCommonFrames()
else:
@@ -521,7 +521,7 @@
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
- return segments, in_cache
+ return segments, cache
def generate(self,
data_in,
@@ -544,7 +544,7 @@
else:
# extract fbank feats
time1 = time.perf_counter()
- audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
@@ -561,7 +561,7 @@
feats = speech
feats_len = speech_lengths.max().item()
waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
- in_cache = kwargs.get("in_cache", {})
+ cache = kwargs.get("cache", {})
batch_size = kwargs.get("batch_size", 1)
step = min(feats_len, 6000)
segments = [[]] * batch_size
@@ -576,11 +576,11 @@
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final": is_final,
- "in_cache": in_cache
+ "cache": cache
}
- segments_part, in_cache = self.forward(**batch)
+ segments_part, cache = self.forward(**batch)
if segments_part:
for batch_num in range(0, batch_size):
segments[batch_num] += segments_part[batch_num]
@@ -603,46 +603,6 @@
results.append(result_i)
return results, meta_data
-
- def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
- is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
- if not in_cache:
- self.AllResetDetection()
- self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
- self.waveform = waveform # compute decibel for each frame
-
- self.ComputeScores(feats, in_cache)
- self.ComputeDecibel()
- if not is_final:
- self.DetectCommonFrames()
- else:
- self.DetectLastFrames()
- segments = []
- for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
- segment_batch = []
- if len(self.output_data_buf) > 0:
- for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
- if not self.output_data_buf[i].contain_seg_start_point:
- continue
- if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
- continue
- start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
- if self.output_data_buf[i].contain_seg_end_point:
- end_ms = self.output_data_buf[i].end_ms
- self.next_seg = True
- self.output_data_buf_offset += 1
- else:
- end_ms = -1
- self.next_seg = False
- segment = [start_ms, end_ms]
- segment_batch.append(segment)
- if segment_batch:
- segments.append(segment_batch)
- if is_final:
- # reset class variables and clear the dict for the next query
- self.AllResetDetection()
- return segments, in_cache
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
--
Gitblit v1.9.1