From 31eed1834f9ff17d6246008f64d3e061f58ef80a Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: 星期一, 27 二月 2023 13:33:55 +0800
Subject: [PATCH] in_cache & support soundfile read

---
 funasr/bin/vad_inference.py                     |   26 ++++--
 funasr/models/e2e_vad.py                        |   34 ++++---
 funasr/models/encoder/fsmn_encoder.py           |   44 ++++------
 funasr/bin/asr_inference_paraformer_vad_punc.py |   96 -----------------------
 4 files changed, 54 insertions(+), 146 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 96f70ef..1320877 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -43,6 +43,7 @@
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
+from funasr.bin.vad_inference import Speech2VadSegment
 from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
 from funasr.bin.punctuation_infer import Text2Punc
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
@@ -363,101 +364,6 @@
         else:
             hotword_list = None
         return hotword_list
-
-class Speech2VadSegment:
-    """Speech2VadSegment class
-
-    Examples:
-        >>> import soundfile
-        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
-        >>> speech2segment(audio)
-        [[10, 230], [245, 450], ...]
-
-    """
-
-    def __init__(
-            self,
-            vad_infer_config: Union[Path, str] = None,
-            vad_model_file: Union[Path, str] = None,
-            vad_cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            **kwargs,
-    ):
-        assert check_argument_types()
-
-        # 1. Build vad model
-        vad_model, vad_infer_args = VADTask.build_model_from_file(
-            vad_infer_config, vad_model_file, device
-        )
-        frontend = None
-        if vad_infer_args.frontend is not None:
-            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-
-        # logging.info("vad_model: {}".format(vad_model))
-        # logging.info("vad_infer_args: {}".format(vad_infer_args))
-        vad_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.vad_model = vad_model
-        self.vad_infer_args = vad_infer_args
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.batch_size = batch_size
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-        assert check_argument_types()
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            self.frontend.filter_length_max = math.inf
-            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
-            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
-            fbanks = to_device(fbanks, device=self.device)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-        else:
-            raise Exception("Need to extract feats first, please configure frontend configuration")
-
-        # b. Forward Encoder streaming
-        t_offset = 0
-        step = min(feats_len, 6000)
-        segments = [[]] * self.batch_size
-        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
-            if t_offset + step >= feats_len - 1:
-                step = feats_len - t_offset
-                is_final_send = True
-            else:
-                is_final_send = False
-            batch = {
-                "feats": feats[:, t_offset:t_offset + step, :],
-                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
-            }
-            # a. To device
-            batch = to_device(batch, device=self.device)
-            segments_part = self.vad_model(**batch)
-            if segments_part:
-                for batch_num in range(0, self.batch_size):
-                    segments[batch_num] += segments_part[batch_num]
-
-        return fbanks, segments
 
 
 def inference(
diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
index 607f131..258b38b 100644
--- a/funasr/bin/vad_inference.py
+++ b/funasr/bin/vad_inference.py
@@ -11,6 +11,7 @@
 from typing import Union
 from typing import Dict
 
+import math
 import numpy as np
 import torch
 from typeguard import check_argument_types
@@ -86,7 +87,7 @@
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
+    ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
         """Inference
 
         Args:
@@ -102,7 +103,10 @@
             speech = torch.tensor(speech)
 
         if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            self.frontend.filter_length_max = math.inf
+            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
+            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
+            fbanks = to_device(fbanks, device=self.device)
             feats = to_device(feats, device=self.device)
             feats_len = feats_len.int()
         else:
@@ -110,18 +114,18 @@
 
         # b. Forward Encoder streaming
         t_offset = 0
-        step = min(feats_len, 6000)
+        step = min(feats_len.max(), 6000)
         segments = [[]] * self.batch_size
         for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
             if t_offset + step >= feats_len - 1:
                 step = feats_len - t_offset
-                is_final_send = True
+                is_final = True
             else:
-                is_final_send = False
+                is_final = False
             batch = {
                 "feats": feats[:, t_offset:t_offset + step, :],
                 "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
+                "is_final": is_final
             }
             # a. To device
             batch = to_device(batch, device=self.device)
@@ -129,7 +133,7 @@
             if segments_part:
                 for batch_num in range(0, self.batch_size):
                     segments[batch_num] += segments_part[batch_num]
-        return segments
+        return fbanks, segments
 
 
 def inference(
@@ -219,9 +223,13 @@
             raw_inputs: Union[np.ndarray, torch.Tensor] = None,
             output_dir_v2: Optional[str] = None,
             fs: dict = None,
-            param_dict: dict = None,
+            param_dict: dict = None
     ):
         # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
         loader = VADTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
@@ -254,7 +262,7 @@
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
 
             # do vad segment
-            results = speech2vadsegment(**batch)
+            _, results = speech2vadsegment(**batch)
             for i, _ in enumerate(keys):
                 results[i] = json.dumps(results[i])
                 item = {'key': keys[i], 'value': results[i]}
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index b64c677..c21be1b 100755
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -201,7 +201,7 @@
                                                self.vad_opts.frame_in_ms)
         self.encoder = encoder
         # init variables
-        self.is_final_send = False
+        self.is_final = False
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -230,8 +230,7 @@
         self.ResetDetection()
 
     def AllResetDetection(self):
-        self.encoder.cache_reset()  # reset the in_cache in self.encoder for next query or next long sentence
-        self.is_final_send = False
+        self.is_final = False
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -283,8 +282,8 @@
                 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                 0.000001))
 
-    def ComputeScores(self, feats: torch.Tensor) -> None:
-        scores = self.encoder(feats)  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, in_cache)  # return B * T * D
         assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
         self.vad_opts.nn_eval_block_size = scores.shape[1]
         self.frm_cnt += scores.shape[1]  # count total frames
@@ -306,7 +305,7 @@
         expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
         if last_frm_is_end_point:
             extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
-                               self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
             expected_sample_number += int(extra_sample)
         if end_point_is_sent_end:
             expected_sample_number = max(expected_sample_number, len(self.data_buf))
@@ -443,11 +442,13 @@
 
         return frame_state
 
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]:
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+                is_final: bool = False
+                ) -> List[List[List[int]]]:
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
-        self.ComputeScores(feats)
-        if not is_final_send:
+        self.ComputeScores(feats, in_cache)
+        if not is_final:
             self.DetectCommonFrames()
         else:
             self.DetectLastFrames()
@@ -456,15 +457,18 @@
             segment_batch = []
             if len(self.output_data_buf) > 0:
                 for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
+                    if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
                         i].contain_seg_end_point:
-                        segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
-                        segment_batch.append(segment)
-                        self.output_data_buf_offset += 1  # need update this parameter
+                        continue
+                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+                    segment_batch.append(segment)
+                    self.output_data_buf_offset += 1  # need update this parameter
             if segment_batch:
                 segments.append(segment_batch)
-        if is_final_send:
-            self.AllResetDetection() 
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+            in_cache.clear()
         return segments
 
     def DetectCommonFrames(self) -> int:
diff --git a/funasr/models/encoder/fsmn_encoder.py b/funasr/models/encoder/fsmn_encoder.py
index 54a113d..c749dc4 100755
--- a/funasr/models/encoder/fsmn_encoder.py
+++ b/funasr/models/encoder/fsmn_encoder.py
@@ -79,14 +79,12 @@
         else:
             self.conv_right = None
 
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, cache: torch.Tensor):
         x = torch.unsqueeze(input, 1)
         x_per = x.permute(0, 3, 2, 1)  # B D T C
-        if in_cache is None:  # offline
-            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-        else:
-            y_left = torch.cat((in_cache, x_per), dim=2)
-            in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+
+        y_left = torch.cat((cache, x_per), dim=2)
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
         y_left = self.conv_left(y_left)
         out = x_per + y_left
 
@@ -100,7 +98,7 @@
         out_per = out.permute(0, 3, 2, 1)
         output = out_per.squeeze(1)
 
-        return output, in_cache
+        return output, cache
 
 
 class BasicBlock(nn.Sequential):
@@ -124,28 +122,25 @@
         self.affine = AffineTransform(proj_dim, linear_dim)
         self.relu = RectifiedLinear(linear_dim, linear_dim)
 
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
         x1 = self.linear(input)  # B T D
-        if in_cache is not None:  # Dict[str, tensor.Tensor]
-            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
-            if cache_layer_name not in in_cache:
-                in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
-            x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
-        else:
-            x2, _ = self.fsmn_block(x1)
+        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        if cache_layer_name not in in_cache:
+            in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
         x3 = self.affine(x2)
         x4 = self.relu(x3)
-        return x4, in_cache
+        return x4
 
 
 class FsmnStack(nn.Sequential):
     def __init__(self, *args):
         super(FsmnStack, self).__init__(*args)
 
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
         x = input
         for module in self._modules.values():
-            x, in_cache = module(x, in_cache)
+            x = module(x, in_cache)
         return x
 
 
@@ -174,8 +169,7 @@
             lstride: int,
             rstride: int,
             output_affine_dim: int,
-            output_dim: int,
-            streaming=False
+            output_dim: int
     ):
         super(FSMN, self).__init__()
 
@@ -186,8 +180,6 @@
         self.proj_dim = proj_dim
         self.output_affine_dim = output_affine_dim
         self.output_dim = output_dim
-        self.in_cache_original = dict() if streaming else None
-        self.in_cache = copy.deepcopy(self.in_cache_original)
 
         self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
         self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@@ -201,12 +193,10 @@
     def fuse_modules(self):
         pass
 
-    def cache_reset(self):
-        self.in_cache = copy.deepcopy(self.in_cache_original)
-
     def forward(
             self,
             input: torch.Tensor,
+            in_cache: Dict[str, torch.Tensor]
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Args:
@@ -218,7 +208,7 @@
         x1 = self.in_linear1(input)
         x2 = self.in_linear2(x1)
         x3 = self.relu(x2)
-        x4 = self.fsmn(x3, self.in_cache)  # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
+        x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
         x5 = self.out_linear1(x4)
         x6 = self.out_linear2(x5)
         x7 = self.softmax(x6)
@@ -307,4 +297,4 @@
     print('input shape: {}'.format(x.shape))
     print('output shape: {}'.format(y.shape))
 
-    print(fsmn.to_kaldi_net())
+    print(fsmn.to_kaldi_net())
\ No newline at end of file

--
Gitblit v1.9.1