From 0143122a4e2ee86cc27ba137b2bb0530577cbf12 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 12 一月 2024 10:27:36 +0800
Subject: [PATCH] funasr1.0 streaming demo

---
 funasr/models/fsmn_vad/model.py |   68 +++++++---------------------------
 1 files changed, 14 insertions(+), 54 deletions(-)

diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 16f21dc..1ed0773 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -8,8 +8,8 @@
 import math
 from typing import Optional
 import time
-from funasr.utils.register import register_class, registry_tables
-from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank
+from funasr.register import tables
+from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
 from funasr.utils.datadir_writer import DatadirWriter
 from torch.nn.utils.rnn import pad_sequence
 
@@ -218,7 +218,7 @@
         return int(self.frame_size_ms)
 
 
-@register_class("model_classes", "FsmnVAD")
+@tables.register("model_classes", "FsmnVAD")
 class FsmnVAD(nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -238,7 +238,7 @@
                                                self.vad_opts.speech_to_sil_time_thres,
                                                self.vad_opts.frame_in_ms)
         
-        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder.lower())
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder
         # init variables
@@ -333,8 +333,8 @@
                 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                 0.000001))
 
-    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
-        scores = self.encoder(feats, in_cache).to('cpu')  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, cache).to('cpu')  # return B * T * D
         assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
         self.vad_opts.nn_eval_block_size = scores.shape[1]
         self.frm_cnt += scores.shape[1]  # count total frames
@@ -493,14 +493,14 @@
 
         return frame_state
 
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
                 ):
-        if not in_cache:
+        if not cache:
             self.AllResetDetection()
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
-        self.ComputeScores(feats, in_cache)
+        self.ComputeScores(feats, cache)
         if not is_final:
             self.DetectCommonFrames()
         else:
@@ -521,7 +521,7 @@
         if is_final:
             # reset class variables and clear the dict for the next query
             self.AllResetDetection()
-        return segments, in_cache
+        return segments, cache
 
     def generate(self,
                  data_in,
@@ -544,7 +544,7 @@
         else:
             # extract fbank feats
             time1 = time.perf_counter()
-            audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
             speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
@@ -561,7 +561,7 @@
         feats = speech
         feats_len = speech_lengths.max().item()
         waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
-        in_cache = kwargs.get("in_cache", {})
+        cache = kwargs.get("cache", {})
         batch_size = kwargs.get("batch_size", 1)
         step = min(feats_len, 6000)
         segments = [[]] * batch_size
@@ -576,11 +576,11 @@
                 "feats": feats[:, t_offset:t_offset + step, :],
                 "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
                 "is_final": is_final,
-                "in_cache": in_cache
+                "cache": cache
             }
 
 
-            segments_part, in_cache = self.forward(**batch)
+            segments_part, cache = self.forward(**batch)
             if segments_part:
                 for batch_num in range(0, batch_size):
                     segments[batch_num] += segments_part[batch_num]
@@ -603,46 +603,6 @@
             results.append(result_i)
  
         return results, meta_data
-
-    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                       is_final: bool = False, max_end_sil: int = 800
-                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
-        if not in_cache:
-            self.AllResetDetection()
-        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
-        self.waveform = waveform  # compute decibel for each frame
-
-        self.ComputeScores(feats, in_cache)
-        self.ComputeDecibel()
-        if not is_final:
-            self.DetectCommonFrames()
-        else:
-            self.DetectLastFrames()
-        segments = []
-        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
-            segment_batch = []
-            if len(self.output_data_buf) > 0:
-                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point:
-                        continue
-                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
-                        continue
-                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
-                    if self.output_data_buf[i].contain_seg_end_point:
-                        end_ms = self.output_data_buf[i].end_ms
-                        self.next_seg = True
-                        self.output_data_buf_offset += 1
-                    else:
-                        end_ms = -1
-                        self.next_seg = False
-                    segment = [start_ms, end_ms]
-                    segment_batch.append(segment)
-            if segment_batch:
-                segments.append(segment_batch)
-        if is_final:
-            # reset class variables and clear the dict for the next query
-            self.AllResetDetection()
-        return segments, in_cache
 
     def DetectCommonFrames(self) -> int:
         if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:

--
Gitblit v1.9.1