From c8bae0ec85eee25d66de6b1e4502eff74d750b24 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 十二月 2023 13:29:37 +0800
Subject: [PATCH] funasr2

---
 funasr/models/fsmn_vad/model.py |  517 +++++++++++++++++++++++++++++++++-----------------------
 1 files changed, 302 insertions(+), 215 deletions(-)

diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index cc3c87e..16f21dc 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -1,33 +1,244 @@
 from enum import Enum
 from typing import List, Tuple, Dict, Any
-
+import logging
+import os
+import json
 import torch
 from torch import nn
 import math
 from typing import Optional
-from funasr.models.encoder.fsmn_encoder import FSMN
-from funasr.models.base_model import FunASRModel
-from funasr.models.model_class_factory import *
+import time
+from funasr.utils.register import register_class, registry_tables
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
+from torch.nn.utils.rnn import pad_sequence
+
+class VadStateMachine(Enum):
+    kVadInStateStartPointNotDetected = 1
+    kVadInStateInSpeechSegment = 2
+    kVadInStateEndPointDetected = 3
 
 
+class FrameState(Enum):
+    kFrameStateInvalid = -1
+    kFrameStateSpeech = 1
+    kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+    kChangeStateSpeech2Speech = 0
+    kChangeStateSpeech2Sil = 1
+    kChangeStateSil2Sil = 2
+    kChangeStateSil2Speech = 3
+    kChangeStateNoBegin = 4
+    kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+    kVadSingleUtteranceDetectMode = 0
+    kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(
+            self,
+            sample_rate: int = 16000,
+            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+            snr_mode: int = 0,
+            max_end_silence_time: int = 800,
+            max_start_silence_time: int = 3000,
+            do_start_point_detection: bool = True,
+            do_end_point_detection: bool = True,
+            window_size_ms: int = 200,
+            sil_to_speech_time_thres: int = 150,
+            speech_to_sil_time_thres: int = 150,
+            speech_2_noise_ratio: float = 1.0,
+            do_extend: int = 1,
+            lookback_time_start_point: int = 200,
+            lookahead_time_end_point: int = 100,
+            max_single_segment_time: int = 60000,
+            nn_eval_block_size: int = 8,
+            dcd_block_size: int = 4,
+            snr_thres: int = -100.0,
+            noise_frame_num_used_for_snr: int = 100,
+            decibel_thres: int = -100.0,
+            speech_noise_thres: float = 0.6,
+            fe_prior_thres: float = 1e-4,
+            silence_pdf_num: int = 1,
+            sil_pdf_ids: List[int] = [0],
+            speech_noise_thresh_low: float = -0.1,
+            speech_noise_thresh_high: float = 0.3,
+            output_frame_probs: bool = False,
+            frame_in_ms: int = 10,
+            frame_length_ms: int = 25,
+            **kwargs,
+    ):
+        self.sample_rate = sample_rate
+        self.detect_mode = detect_mode
+        self.snr_mode = snr_mode
+        self.max_end_silence_time = max_end_silence_time
+        self.max_start_silence_time = max_start_silence_time
+        self.do_start_point_detection = do_start_point_detection
+        self.do_end_point_detection = do_end_point_detection
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time_thres = sil_to_speech_time_thres
+        self.speech_to_sil_time_thres = speech_to_sil_time_thres
+        self.speech_2_noise_ratio = speech_2_noise_ratio
+        self.do_extend = do_extend
+        self.lookback_time_start_point = lookback_time_start_point
+        self.lookahead_time_end_point = lookahead_time_end_point
+        self.max_single_segment_time = max_single_segment_time
+        self.nn_eval_block_size = nn_eval_block_size
+        self.dcd_block_size = dcd_block_size
+        self.snr_thres = snr_thres
+        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+        self.decibel_thres = decibel_thres
+        self.speech_noise_thres = speech_noise_thres
+        self.fe_prior_thres = fe_prior_thres
+        self.silence_pdf_num = silence_pdf_num
+        self.sil_pdf_ids = sil_pdf_ids
+        self.speech_noise_thresh_low = speech_noise_thresh_low
+        self.speech_noise_thresh_high = speech_noise_thresh_high
+        self.output_frame_probs = output_frame_probs
+        self.frame_in_ms = frame_in_ms
+        self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+    def Reset(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+
+class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.noise_prob = 0.0
+        self.speech_prob = 0.0
+        self.score = 0.0
+        self.frame_id = 0
+        self.frm_state = 0
+
+
+class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+                 speech_to_sil_time: int, frame_size_ms: int):
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time = sil_to_speech_time
+        self.speech_to_sil_time = speech_to_sil_time
+        self.frame_size_ms = frame_size_ms
+
+        self.win_size_frame = int(window_size_ms / frame_size_ms)
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
+
+        self.cur_win_pos = 0
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def Reset(self) -> None:
+        self.cur_win_pos = 0
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def GetWinSize(self) -> int:
+        return int(self.win_size_frame)
+
+    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+        cur_frame_state = FrameState.kFrameStateSil
+        if frameState == FrameState.kFrameStateSpeech:
+            cur_frame_state = 1
+        elif frameState == FrameState.kFrameStateSil:
+            cur_frame_state = 0
+        else:
+            return AudioChangeState.kChangeStateInvalid
+        self.win_sum -= self.win_state[self.cur_win_pos]
+        self.win_sum += cur_frame_state
+        self.win_state[self.cur_win_pos] = cur_frame_state
+        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSpeech
+            return AudioChangeState.kChangeStateSil2Speech
+
+        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSil
+            return AudioChangeState.kChangeStateSpeech2Sil
+
+        if self.pre_frame_state == FrameState.kFrameStateSil:
+            return AudioChangeState.kChangeStateSil2Sil
+        if self.pre_frame_state == FrameState.kFrameStateSpeech:
+            return AudioChangeState.kChangeStateSpeech2Speech
+        return AudioChangeState.kChangeStateInvalid
+
+    def FrameSizeMs(self) -> int:
+        return int(self.frame_size_ms)
+
+
+@register_class("model_classes", "FsmnVAD")
 class FsmnVAD(nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
     https://arxiv.org/abs/1803.05030
     """
-    def __init__(self, encoder: str = None,
+    def __init__(self,
+                 encoder: str = None,
                  encoder_conf: Optional[Dict] = None,
                  vad_post_args: Dict[str, Any] = None,
-                 frontend=None):
+                 **kwargs,
+                 ):
         super().__init__()
-        self.vad_opts = VADXOptions(**vad_post_args)
+        self.vad_opts = VADXOptions(**kwargs)
         self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                                self.vad_opts.sil_to_speech_time_thres,
                                                self.vad_opts.speech_to_sil_time_thres,
                                                self.vad_opts.frame_in_ms)
         
-        encoder_class = encoder_classes.get_class(encoder)
+        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder
         # init variables
@@ -57,7 +268,6 @@
         self.data_buf = None
         self.data_buf_all = None
         self.waveform = None
-        self.frontend = frontend
         self.last_drop_frames = 0
 
     def AllResetDetection(self):
@@ -239,7 +449,7 @@
             vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
         return vad_latency
 
-    def GetFrameState(self, t: int) -> FrameState:
+    def GetFrameState(self, t: int):
         frame_state = FrameState.kFrameStateInvalid
         cur_decibel = self.decibel[t]
         cur_snr = cur_decibel - self.noise_average_decibel
@@ -285,7 +495,7 @@
 
     def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
-                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+                ):
         if not in_cache:
             self.AllResetDetection()
         self.waveform = waveform  # compute decibel for each frame
@@ -312,6 +522,87 @@
             # reset class variables and clear the dict for the next query
             self.AllResetDetection()
         return segments, in_cache
+
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+
+
+        meta_data = {}
+        audio_sample_list = [data_in]
+        if isinstance(data_in, torch.Tensor):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data[
+                "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+        # b. Forward Encoder streaming
+        t_offset = 0
+        feats = speech
+        feats_len = speech_lengths.max().item()
+        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
+        in_cache = kwargs.get("in_cache", {})
+        batch_size = kwargs.get("batch_size", 1)
+        step = min(feats_len, 6000)
+        segments = [[]] * batch_size
+
+        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
+            if t_offset + step >= feats_len - 1:
+                step = feats_len - t_offset
+                is_final = True
+            else:
+                is_final = False
+            batch = {
+                "feats": feats[:, t_offset:t_offset + step, :],
+                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
+                "is_final": is_final,
+                "in_cache": in_cache
+            }
+
+
+            segments_part, in_cache = self.forward(**batch)
+            if segments_part:
+                for batch_num in range(0, batch_size):
+                    segments[batch_num] += segments_part[batch_num]
+
+        ibest_writer = None
+        if ibest_writer is None and kwargs.get("output_dir") is not None:
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+
+        results = []
+        for i in range(batch_size):
+            
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                results[i] = json.dumps(results[i])
+                
+            if ibest_writer is not None:
+                ibest_writer["text"][key[i]] = segments[i]
+
+            result_i = {"key": key[i], "value": segments[i]}
+            results.append(result_i)
+ 
+        return results, meta_data
 
     def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                        is_final: bool = False, max_end_sil: int = 800
@@ -481,209 +772,5 @@
                 self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
             self.ResetDetection()
 
-
-
-class VadStateMachine(Enum):
-    kVadInStateStartPointNotDetected = 1
-    kVadInStateInSpeechSegment = 2
-    kVadInStateEndPointDetected = 3
-
-
-class FrameState(Enum):
-    kFrameStateInvalid = -1
-    kFrameStateSpeech = 1
-    kFrameStateSil = 0
-
-
-# final voice/unvoice state per frame
-class AudioChangeState(Enum):
-    kChangeStateSpeech2Speech = 0
-    kChangeStateSpeech2Sil = 1
-    kChangeStateSil2Sil = 2
-    kChangeStateSil2Speech = 3
-    kChangeStateNoBegin = 4
-    kChangeStateInvalid = 5
-
-
-class VadDetectMode(Enum):
-    kVadSingleUtteranceDetectMode = 0
-    kVadMutipleUtteranceDetectMode = 1
-
-
-class VADXOptions:
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(
-            self,
-            sample_rate: int = 16000,
-            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
-            snr_mode: int = 0,
-            max_end_silence_time: int = 800,
-            max_start_silence_time: int = 3000,
-            do_start_point_detection: bool = True,
-            do_end_point_detection: bool = True,
-            window_size_ms: int = 200,
-            sil_to_speech_time_thres: int = 150,
-            speech_to_sil_time_thres: int = 150,
-            speech_2_noise_ratio: float = 1.0,
-            do_extend: int = 1,
-            lookback_time_start_point: int = 200,
-            lookahead_time_end_point: int = 100,
-            max_single_segment_time: int = 60000,
-            nn_eval_block_size: int = 8,
-            dcd_block_size: int = 4,
-            snr_thres: int = -100.0,
-            noise_frame_num_used_for_snr: int = 100,
-            decibel_thres: int = -100.0,
-            speech_noise_thres: float = 0.6,
-            fe_prior_thres: float = 1e-4,
-            silence_pdf_num: int = 1,
-            sil_pdf_ids: List[int] = [0],
-            speech_noise_thresh_low: float = -0.1,
-            speech_noise_thresh_high: float = 0.3,
-            output_frame_probs: bool = False,
-            frame_in_ms: int = 10,
-            frame_length_ms: int = 25,
-    ):
-        self.sample_rate = sample_rate
-        self.detect_mode = detect_mode
-        self.snr_mode = snr_mode
-        self.max_end_silence_time = max_end_silence_time
-        self.max_start_silence_time = max_start_silence_time
-        self.do_start_point_detection = do_start_point_detection
-        self.do_end_point_detection = do_end_point_detection
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time_thres = sil_to_speech_time_thres
-        self.speech_to_sil_time_thres = speech_to_sil_time_thres
-        self.speech_2_noise_ratio = speech_2_noise_ratio
-        self.do_extend = do_extend
-        self.lookback_time_start_point = lookback_time_start_point
-        self.lookahead_time_end_point = lookahead_time_end_point
-        self.max_single_segment_time = max_single_segment_time
-        self.nn_eval_block_size = nn_eval_block_size
-        self.dcd_block_size = dcd_block_size
-        self.snr_thres = snr_thres
-        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
-        self.decibel_thres = decibel_thres
-        self.speech_noise_thres = speech_noise_thres
-        self.fe_prior_thres = fe_prior_thres
-        self.silence_pdf_num = silence_pdf_num
-        self.sil_pdf_ids = sil_pdf_ids
-        self.speech_noise_thresh_low = speech_noise_thresh_low
-        self.speech_noise_thresh_high = speech_noise_thresh_high
-        self.output_frame_probs = output_frame_probs
-        self.frame_in_ms = frame_in_ms
-        self.frame_length_ms = frame_length_ms
-
-
-class E2EVadSpeechBufWithDoa(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
-
-    def Reset(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
-
-
-class E2EVadFrameProb(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.noise_prob = 0.0
-        self.speech_prob = 0.0
-        self.score = 0.0
-        self.frame_id = 0
-        self.frm_state = 0
-
-
-class WindowDetector(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
-                 speech_to_sil_time: int, frame_size_ms: int):
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time = sil_to_speech_time
-        self.speech_to_sil_time = speech_to_sil_time
-        self.frame_size_ms = frame_size_ms
-
-        self.win_size_frame = int(window_size_ms / frame_size_ms)
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
-
-        self.cur_win_pos = 0
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
-        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
-
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def Reset(self) -> None:
-        self.cur_win_pos = 0
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def GetWinSize(self) -> int:
-        return int(self.win_size_frame)
-
-    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
-        cur_frame_state = FrameState.kFrameStateSil
-        if frameState == FrameState.kFrameStateSpeech:
-            cur_frame_state = 1
-        elif frameState == FrameState.kFrameStateSil:
-            cur_frame_state = 0
-        else:
-            return AudioChangeState.kChangeStateInvalid
-        self.win_sum -= self.win_state[self.cur_win_pos]
-        self.win_sum += cur_frame_state
-        self.win_state[self.cur_win_pos] = cur_frame_state
-        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
-
-        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSpeech
-            return AudioChangeState.kChangeStateSil2Speech
-
-        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSil
-            return AudioChangeState.kChangeStateSpeech2Sil
-
-        if self.pre_frame_state == FrameState.kFrameStateSil:
-            return AudioChangeState.kChangeStateSil2Sil
-        if self.pre_frame_state == FrameState.kFrameStateSpeech:
-            return AudioChangeState.kChangeStateSpeech2Speech
-        return AudioChangeState.kChangeStateInvalid
-
-    def FrameSizeMs(self) -> int:
-        return int(self.frame_size_ms)
 
 

--
Gitblit v1.9.1