zhifu gao
2023-02-27 4e506305270c68180ab3c63087c8ac29c78a3c62
Merge pull request #155 from alibaba-damo-academy/dev_zly

in_cache & support soundfile read
4 files changed, 200 lines changed
funasr/bin/asr_inference_paraformer_vad_punc.py | 96
funasr/bin/vad_inference.py | 26
funasr/models/e2e_vad.py | 34
funasr/models/encoder/fsmn_encoder.py | 44
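Taken together, the patch moves Speech2VadSegment into funasr/bin/vad_inference.py, threads an explicit in_cache dict through the FSMN encoder for streaming, and lets VAD inference consume audio read with soundfile. A minimal usage sketch of the new calling convention (the file paths, "vad.cmvn" name, and the single-batch shape handling are illustrative, not taken from the patch):

import soundfile
import torch
from funasr.bin.vad_inference import Speech2VadSegment

# illustrative local paths; any compatible VAD config/model/CMVN files work
speech2segment = Speech2VadSegment(
    vad_infer_config="vad_config.yml",
    vad_model_file="vad.pt",
    vad_cmvn_file="vad.cmvn",
)
audio, rate = soundfile.read("speech.wav")          # 1-D float64 ndarray
speech = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)  # B x T
lengths = torch.tensor([speech.shape[1]], dtype=torch.int32)
fbanks, segments = speech2segment(speech, lengths)  # note the new 2-tuple return
print(segments)  # per-batch lists of [start_ms, end_ms], e.g. [[[10, 230], [245, 450]]]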
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -43,6 +43,7 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
+from funasr.bin.vad_inference import Speech2VadSegment
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
@@ -363,101 +364,6 @@
        else:
            hotword_list = None
        return hotword_list
-class Speech2VadSegment:
-    """Speech2VadSegment class
-    Examples:
-        >>> import soundfile
-        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
-        >>> speech2segment(audio)
-        [[10, 230], [245, 450], ...]
-    """
-    def __init__(
-            self,
-            vad_infer_config: Union[Path, str] = None,
-            vad_model_file: Union[Path, str] = None,
-            vad_cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            **kwargs,
-    ):
-        assert check_argument_types()
-        # 1. Build vad model
-        vad_model, vad_infer_args = VADTask.build_model_from_file(
-            vad_infer_config, vad_model_file, device
-        )
-        frontend = None
-        if vad_infer_args.frontend is not None:
-            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-        # logging.info("vad_model: {}".format(vad_model))
-        # logging.info("vad_infer_args: {}".format(vad_infer_args))
-        vad_model.to(dtype=getattr(torch, dtype)).eval()
-        self.vad_model = vad_model
-        self.vad_infer_args = vad_infer_args
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.batch_size = batch_size
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
-        """Inference
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-        """
-        assert check_argument_types()
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-        if self.frontend is not None:
-            self.frontend.filter_length_max = math.inf
-            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
-            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
-            fbanks = to_device(fbanks, device=self.device)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-        else:
-            raise Exception("Need to extract feats first, please configure frontend configuration")
-        # b. Forward Encoder streaming
-        t_offset = 0
-        step = min(feats_len, 6000)
-        segments = [[]] * self.batch_size
-        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
-            if t_offset + step >= feats_len - 1:
-                step = feats_len - t_offset
-                is_final_send = True
-            else:
-                is_final_send = False
-            batch = {
-                "feats": feats[:, t_offset:t_offset + step, :],
-                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
-            }
-            # a. To device
-            batch = to_device(batch, device=self.device)
-            segments_part = self.vad_model(**batch)
-            if segments_part:
-                for batch_num in range(0, self.batch_size):
-                    segments[batch_num] += segments_part[batch_num]
-        return fbanks, segments
def inference(
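The Speech2VadSegment block deleted above is not lost: it now lives in funasr/bin/vad_inference.py (patched below) and is pulled in through the new import at the top of this file, so the combined ASR + VAD + punctuation pipeline keeps a single VAD implementation.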
funasr/bin/vad_inference.py
@@ -11,6 +11,7 @@
from typing import Union
from typing import Dict
+import math
import numpy as np
import torch
from typeguard import check_argument_types
@@ -86,7 +87,7 @@
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
+    ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
        """Inference
        Args:
@@ -102,7 +103,10 @@
            speech = torch.tensor(speech)
        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            self.frontend.filter_length_max = math.inf
+            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
+            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
+            fbanks = to_device(fbanks, device=self.device)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
        else:
@@ -110,18 +114,18 @@
        # b. Forward Encoder streaming
        t_offset = 0
-        step = min(feats_len, 6000)
+        step = min(feats_len.max(), 6000)
        segments = [[]] * self.batch_size
        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
-                is_final_send = True
+                is_final = True
            else:
-                is_final_send = False
+                is_final = False
            batch = {
                "feats": feats[:, t_offset:t_offset + step, :],
                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
                "is_final_send": is_final_send
                "is_final": is_final
            }
            # a. To device
            batch = to_device(batch, device=self.device)
@@ -129,7 +133,7 @@
            if segments_part:
                for batch_num in range(0, self.batch_size):
                    segments[batch_num] += segments_part[batch_num]
-        return segments
+        return fbanks, segments
def inference(
@@ -219,9 +223,13 @@
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
-            param_dict: dict = None,
+            param_dict: dict = None
    ):
        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
@@ -254,7 +262,7 @@
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # do vad segment
-            results = speech2vadsegment(**batch)
+            _, results = speech2vadsegment(**batch)
            for i, _ in enumerate(keys):
                results[i] = json.dumps(results[i])
                item = {'key': keys[i], 'value': results[i]}
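Besides the new fbanks return, inference() now also accepts raw_inputs as a numpy array or torch.Tensor instead of a wav scp entry. The waveform slice inside the streaming loop follows the usual 16 kHz fbank geometry: a 10 ms frame shift (160 samples) and a 25 ms window (400 samples), so the samples backing frames t_offset .. t_offset + step - 1 run from t_offset * 160 to (t_offset + step - 1) * 160 + 400. A self-contained sketch of that chunking, written as a generator for clarity (the function name and while-loop form are illustrative; the patch uses an equivalent for-loop inside __call__):

import torch

def iter_chunks(feats, waveform, max_step=6000):
    # feats: B x T x D log-mel features; waveform: B x S samples at 16 kHz
    feats_len = feats.shape[1]
    step = min(feats_len, max_step)
    t_offset = 0
    while t_offset < feats_len:
        if t_offset + step >= feats_len - 1:   # last chunk of the utterance
            step = feats_len - t_offset
            is_final = True
        else:
            is_final = False
        wav_end = min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)
        yield (feats[:, t_offset:t_offset + step, :],
               waveform[:, t_offset * 160:wav_end],
               is_final)
        t_offset += step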
funasr/models/e2e_vad.py
@@ -201,7 +201,7 @@
                                               self.vad_opts.frame_in_ms)
        self.encoder = encoder
        # init variables
-        self.is_final_send = False
+        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
@@ -230,8 +230,7 @@
        self.ResetDetection()
    def AllResetDetection(self):
-        self.encoder.cache_reset()  # reset the in_cache in self.encoder for next query or next long sentence
-        self.is_final_send = False
+        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
@@ -283,8 +282,8 @@
                10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
-    def ComputeScores(self, feats: torch.Tensor) -> None:
-        scores = self.encoder(feats)  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, in_cache)  # return B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        self.frm_cnt += scores.shape[1]  # count total frames
@@ -306,7 +305,7 @@
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:
            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
-                               self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(self.data_buf))
@@ -443,11 +442,13 @@
        return frame_state
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]:
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+                is_final: bool = False
+                ) -> List[List[List[int]]]:
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
-        self.ComputeScores(feats)
-        if not is_final_send:
+        self.ComputeScores(feats, in_cache)
+        if not is_final:
            self.DetectCommonFrames()
        else:
            self.DetectLastFrames()
@@ -456,15 +457,18 @@
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
+                    if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
                        i].contain_seg_end_point:
-                        segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
-                        segment_batch.append(segment)
-                        self.output_data_buf_offset += 1  # need update this parameter
+                        continue
+                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+                    segment_batch.append(segment)
+                    self.output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
-        if is_final_send:
-            self.AllResetDetection()
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+            in_cache.clear()
        return segments
    def DetectCommonFrames(self) -> int:
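The new forward signature establishes the cache contract: the caller owns a plain dict that the FSMN layers fill lazily, and forward(..., is_final=True) both resets the detector state (AllResetDetection) and empties the dict for the next query. Note that in_cache defaults to a mutable dict(), which Python evaluates only once; when Speech2VadSegment omits the argument, that shared default is what persists across chunks, so the in_cache.clear() on the final chunk is what separates consecutive utterances. A hedged caller-side sketch (vad_model stands for the e2e_vad model instance; iter_chunks is the illustrative generator sketched above):

# feats, waveform prepared by the frontend, as in the Speech2VadSegment sketch
in_cache = {}            # owned by the caller, filled lazily per FSMN layer
all_segments = []
for feats_chunk, wav_chunk, is_final in iter_chunks(feats, waveform):
    segments = vad_model(feats_chunk, wav_chunk, in_cache=in_cache, is_final=is_final)
    if segments:
        all_segments += segments[0]   # batch 0: list of [start_ms, end_ms]
# after the is_final call, in_cache is empty and the model is fully reset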
funasr/models/encoder/fsmn_encoder.py
@@ -79,14 +79,12 @@
        else:
            self.conv_right = None
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, cache: torch.Tensor):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)  # B D T C
-        if in_cache is None:  # offline
-            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-        else:
-            y_left = torch.cat((in_cache, x_per), dim=2)
-            in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+        y_left = torch.cat((cache, x_per), dim=2)
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
        y_left = self.conv_left(y_left)
        out = x_per + y_left
@@ -100,7 +98,7 @@
        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)
-        return output, in_cache
+        return output, cache
class BasicBlock(nn.Sequential):
@@ -124,28 +122,25 @@
        self.affine = AffineTransform(proj_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
        x1 = self.linear(input)  # B T D
-        if in_cache is not None:  # Dict[str, tensor.Tensor]
-            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
-            if cache_layer_name not in in_cache:
-                in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
-            x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
-        else:
-            x2, _ = self.fsmn_block(x1)
+        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        if cache_layer_name not in in_cache:
+            in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
        x3 = self.affine(x2)
        x4 = self.relu(x3)
-        return x4, in_cache
+        return x4
class FsmnStack(nn.Sequential):
    def __init__(self, *args):
        super(FsmnStack, self).__init__(*args)
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
        x = input
        for module in self._modules.values():
-            x, in_cache = module(x, in_cache)
+            x = module(x, in_cache)
        return x
@@ -174,8 +169,7 @@
            lstride: int,
            rstride: int,
            output_affine_dim: int,
-            output_dim: int,
-            streaming=False
+            output_dim: int
    ):
        super(FSMN, self).__init__()
@@ -186,8 +180,6 @@
        self.proj_dim = proj_dim
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim
-        self.in_cache_original = dict() if streaming else None
-        self.in_cache = copy.deepcopy(self.in_cache_original)
        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@@ -201,12 +193,10 @@
    def fuse_modules(self):
        pass
-    def cache_reset(self):
-        self.in_cache = copy.deepcopy(self.in_cache_original)
    def forward(
            self,
            input: torch.Tensor,
            in_cache: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
@@ -218,7 +208,7 @@
        x1 = self.in_linear1(input)
        x2 = self.in_linear2(x1)
        x3 = self.relu(x2)
-        x4 = self.fsmn(x3, self.in_cache)  # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
+        x4 = self.fsmn(x3, in_cache)  # in_cache is updated in place by self.fsmn
        x5 = self.out_linear1(x4)
        x6 = self.out_linear2(x5)
        x7 = self.softmax(x6)
@@ -307,4 +297,4 @@
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))
-    print(fsmn.to_kaldi_net())
+    print(fsmn.to_kaldi_net())
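For reference, the cache that each BasicBlock seeds is a zero tensor of shape (B, D, (lorder - 1) * lstride, 1), keyed as 'cache_layer_{stack_layer}'; the FSMN block then prepends it to the incoming frames and keeps the trailing (lorder - 1) * lstride frames for the next chunk. A reduced sketch of that FIFO update, detached from the full classes (the function name is illustrative):

import torch

def fsmn_cache_step(x_per, cache, lorder, lstride):
    # x_per: B x D x T x 1 new frames; cache: B x D x (lorder-1)*lstride x 1 history
    y_left = torch.cat((cache, x_per), dim=2)               # prepend history
    new_cache = y_left[:, :, -(lorder - 1) * lstride:, :]   # keep only the tail
    return y_left, new_cache

# lazy init on the first chunk, one entry per stacked layer:
# in_cache['cache_layer_0'] = torch.zeros(B, D, (lorder - 1) * lstride, 1)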