九耳
2023-03-30 81d4f772700b9910121442ac985a0339adafe02b
Merge branch 'dev_cmz2' of github.com:alibaba-damo-academy/FunASR into dev_cmz2
6个文件已修改
1个文件已添加
371 ■■■■ 已修改文件
funasr/export/models/__init__.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/encoder/sanm_encoder.py 99 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/target_delay_transformer.py 132 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/vad_realtime_transformer.py 79 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/libtorch/funasr_torch/utils/utils.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py 53 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/__init__.py
@@ -6,6 +6,8 @@
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
@@ -17,5 +19,7 @@
    elif isinstance(model, ESPnetPunctuationModel):
        if isinstance(model.punc_model, TargetDelayTransformer):
            return TargetDelayTransformer_export(model.punc_model, **export_config)
        elif isinstance(model.punc_model, VadRealtimeTransformer):
            return VadRealtimeTransformer_export(model.punc_model, **export_config)
    else:
        raise "Funasr does not support the given model type currently."
funasr/export/models/encoder/sanm_encoder.py
@@ -107,3 +107,102 @@
            }
        }
class SANMVadEncoder(nn.Module):
    """Export wrapper around a SANM VAD encoder.

    The wrapped ``model``'s encoder layers (``encoders0`` when present, and
    ``encoders``) are replaced *in place* by their ``*_export`` counterparts,
    so the module passed in is mutated.

    NOTE(review): ``VadRealtimeTransformer`` in
    funasr/export/models/vad_realtime_transformer.py calls this module as
    ``self.encoder(x, text_lengths, vad_indexes)`` and unpacks three values,
    whereas ``forward`` below accepts two tensors and returns two. TODO
    reconcile the two signatures.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        # ``onnx`` selects MakePadMask; otherwise sequence_mask is used.
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        # ``encoders0`` is optional on the source model; ``encoders`` is not.
        if hasattr(model, 'encoders0'):
            self._convert_layers(self.model.encoders0)
        self._convert_layers(self.model.encoders)

        self.model_name = model_name
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    @staticmethod
    def _convert_layers(layers):
        """Replace each layer (and its attention / feed-forward) with its export variant, in place."""
        for i, layer in enumerate(layers):
            if isinstance(layer.self_attn, MultiHeadedAttentionSANM):
                layer.self_attn = MultiHeadedAttentionSANM_export(layer.self_attn)
            if isinstance(layer.feed_forward, PositionwiseFeedForward):
                layer.feed_forward = PositionwiseFeedForward_export(layer.feed_forward)
            layers[i] = EncoderLayerSANM_export(layer)

    def prepare_mask(self, mask):
        """Build the mask pair consumed by the export encoder layers.

        Args:
            mask: a 2-D or 3-D padding mask.

        Returns:
            ``(mask_3d_btd, mask_4d_bhlt)``. The first is ``mask`` with a
            trailing singleton axis. The second is ``(1 - mask)`` with
            singleton axes inserted, scaled by -10000.0, i.e. an additive
            bias that is 0 where ``mask`` is 1.

        Raises:
            ValueError: if ``mask`` is neither 2-D nor 3-D (previously this
                surfaced as an UnboundLocalError).
        """
        if mask.dim() == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif mask.dim() == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        else:
            raise ValueError("prepare_mask expects a 2-D or 3-D mask")
        mask_3d_btd = mask[:, :, None]
        return mask_3d_btd, mask_4d_bhlt * -10000.0

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                ):
        """Encode ``speech`` and return ``(encoder_out, speech_lengths)``.

        Args:
            speech: input features; scaled by ``sqrt(_output_size)`` before
                the optional embedding.
            speech_lengths: valid length of each sequence, used to build the
                padding mask. Returned unchanged as the second output.
        """
        speech = speech * self._output_size ** 0.5
        # prepare_mask yields a (3-D, 4-D) pair that the layers take as one value.
        mask = self.prepare_mask(self.make_pad_mask(speech_lengths))
        xs_pad = speech if self.embed is None else self.embed(speech)

        # __init__ tolerates models without ``encoders0``; mirror that here.
        if hasattr(self.model, 'encoders0'):
            xs_pad = self.model.encoders0(xs_pad, mask)[0]
        xs_pad = self.model.encoders(xs_pad, mask)[0]
        xs_pad = self.model.after_norm(xs_pad)
        return xs_pad, speech_lengths

    def get_output_size(self):
        """Return ``size`` of the first (export-converted) encoder layer."""
        return self.model.encoders[0].size

    def get_dummy_inputs(self):
        """Return example ``(feats, feats_lengths)`` matching ``forward``'s arguments."""
        feats = torch.randn(1, 100, self.feats_dim)
        feats_lengths = torch.tensor([feats.shape[1]], dtype=torch.int32)
        return (feats, feats_lengths)

    def get_input_names(self):
        """Names of the graph inputs, one per element of ``get_dummy_inputs()``."""
        return ['feats', 'feats_lengths']

    def get_output_names(self):
        """Names of the graph outputs, one per value returned by ``forward``."""
        return ['encoder_out', 'encoder_out_lens']

    def get_dynamic_axes(self):
        """Dynamic (time) axes for the variable-length input and output."""
        return {
            'feats': {
                1: 'feats_length'
            },
            'encoder_out': {
                1: 'enc_out_length'
            },
        }
funasr/export/models/target_delay_transformer.py
@@ -28,7 +28,7 @@
            onnx = kwargs["onnx"]
        self.embed = model.embed
        self.decoder = model.decoder
        self.model = model
        # self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
@@ -46,71 +46,71 @@
        from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
        from funasr.punctuation.abs_model import AbsPunctuation
        class TargetDelayTransformer(nn.Module):
            def __init__(
                    self,
                    model,
                    max_seq_len=512,
                    model_name='punc_model',
                    **kwargs,
            ):
                super().__init__()
                onnx = False
                if "onnx" in kwargs:
                    onnx = kwargs["onnx"]
                self.embed = model.embed
                self.decoder = model.decoder
                self.model = model
                self.feats_dim = self.embed.embedding_dim
                self.num_embeddings = self.embed.num_embeddings
                self.model_name = model_name
                if isinstance(model.encoder, SANMEncoder):
                    self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
                else:
                    assert False, "Only support samn encode."
            def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
                """Compute loss value from buffer sequences.
                Args:
                    input (torch.Tensor): Input ids. (batch, len)
                    hidden (torch.Tensor): Target ids. (batch, len)
                """
                x = self.embed(input)
                # mask = self._target_mask(input)
                h, _ = self.encoder(x, text_lengths)
                y = self.decoder(h)
                return y
            def get_dummy_inputs(self):
                length = 120
                text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
                text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
                return (text_indexes, text_lengths)
            def get_input_names(self):
                return ['input', 'text_lengths']
            def get_output_names(self):
                return ['logits']
            def get_dynamic_axes(self):
                return {
                    'input': {
                        0: 'batch_size',
                        1: 'feats_length'
                    },
                    'text_lengths': {
                        0: 'batch_size',
                    },
                    'logits': {
                        0: 'batch_size',
                        1: 'logits_length'
                    },
                }
        # class TargetDelayTransformer(nn.Module):
        #
        #     def __init__(
        #             self,
        #             model,
        #             max_seq_len=512,
        #             model_name='punc_model',
        #             **kwargs,
        #     ):
        #         super().__init__()
        #         onnx = False
        #         if "onnx" in kwargs:
        #             onnx = kwargs["onnx"]
        #         self.embed = model.embed
        #         self.decoder = model.decoder
        #         self.model = model
        #         self.feats_dim = self.embed.embedding_dim
        #         self.num_embeddings = self.embed.num_embeddings
        #         self.model_name = model_name
        #
        #         if isinstance(model.encoder, SANMEncoder):
        #             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        #         else:
        #             assert False, "Only support samn encode."
        #
        #     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        #         """Compute loss value from buffer sequences.
        #
        #         Args:
        #             input (torch.Tensor): Input ids. (batch, len)
        #             hidden (torch.Tensor): Target ids. (batch, len)
        #
        #         """
        #         x = self.embed(input)
        #         # mask = self._target_mask(input)
        #         h, _ = self.encoder(x, text_lengths)
        #         y = self.decoder(h)
        #         return y
        #
        #     def get_dummy_inputs(self):
        #         length = 120
        #         text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        #         text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        #         return (text_indexes, text_lengths)
        #
        #     def get_input_names(self):
        #         return ['input', 'text_lengths']
        #
        #     def get_output_names(self):
        #         return ['logits']
        #
        #     def get_dynamic_axes(self):
        #         return {
        #             'input': {
        #                 0: 'batch_size',
        #                 1: 'feats_length'
        #             },
        #             'text_lengths': {
        #                 0: 'batch_size',
        #             },
        #             'logits': {
        #                 0: 'batch_size',
        #                 1: 'logits_length'
        #             },
        #         }
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
funasr/export/models/vad_realtime_transformer.py
New file
@@ -0,0 +1,79 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.punctuation.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class VadRealtimeTransformer(AbsPunctuation):
    """Export wrapper for the VAD-aware realtime punctuation model.

    Reuses ``model.embed`` and ``model.decoder`` and swaps a
    ``SANMVadEncoder`` ``model.encoder`` for its export variant.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        super().__init__()
        # ``onnx`` was previously referenced without being defined (NameError).
        # Read it from kwargs, defaulting to False as the TargetDelayTransformer
        # export wrapper does.
        onnx = kwargs.get("onnx", False)
        self.embed = model.embed
        # Stored like the TargetDelayTransformer export wrapper does; presumably
        # read by the export driver -- confirm.
        self.model_name = model_name
        if isinstance(model.encoder, SANMVadEncoder):
            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
        else:
            # Explicit raise: ``assert`` is stripped under ``python -O``.
            raise AssertionError("Only support samn encode.")
        self.decoder = model.decoder

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor) -> torch.Tensor:
        """Return punctuation logits for a batch of token ids.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence.
            vad_indexes (torch.Tensor): Passed unchanged to the encoder.

        NOTE(review): the export ``SANMVadEncoder.forward`` in
        funasr/export/models/encoder/sanm_encoder.py takes
        ``(speech, speech_lengths)`` and returns two values, so this
        three-argument call and three-way unpack do not match it. Also,
        ``get_dummy_inputs`` does not supply ``vad_indexes``. TODO reconcile.
        """
        x = self.embed(input)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
        return self.decoder(h)

    def with_vad(self):
        return True

    def get_dummy_inputs(self):
        """Two sequences of random ids; the first is 20 tokens shorter than the padded length."""
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        return ['input', 'text_lengths']

    def get_output_names(self):
        return ['logits']

    def get_dynamic_axes(self):
        """Batch and sequence axes are dynamic for inputs and logits."""
        return {
            'input': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -134,7 +134,7 @@
@functools.lru_cache()
def get_logger(name='torch_paraformer'):
def get_logger(name='funasr_torch'):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -239,7 +239,7 @@
@functools.lru_cache()
def get_logger(name='rapdi_paraformer'):
def get_logger(name='funasr_onnx'):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -59,33 +59,38 @@
        
    
    def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        """Run the VAD model on ``audio_in`` and return the detected segments.

        NOTE(review): this diff rendering interleaves the removed (pre-commit)
        and added lines of hunk ``@@ -59,33 +59,38 @@`` without +/- markers.
        The body below is therefore not valid Python as shown: the ``for``
        header is followed by a line at the same indentation, and there are
        two ``return`` statements. The lines using ``waveform_list`` and
        ``asr_res`` appear to be the old batched path, and the ``param_dict``
        lines the new streaming path. Confirm against the actual file.
        """
        waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        # waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
        # NOTE(review): this looks up a key literally named 'kwargs', so e.g. a
        # caller passing is_final=True is not seen here; 'is_final' may have
        # been intended -- confirm.
        is_final = kwargs.get('kwargs', False)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
        # Streaming state is kept in the caller-supplied param_dict, which is
        # mutated in place; without one, a fresh dict is used and discarded.
        param_dict = kwargs.get('param_dict', dict())
        audio_in_cache = param_dict.get('audio_in_cache', None)
        audio_in_cum = audio_in
        # Append this chunk to all audio seen so far. Nothing visible here ever
        # trims the cache, so it grows with every call.
        if audio_in_cache is not None:
            audio_in_cum = np.concatenate((audio_in_cache, audio_in_cum))
        param_dict['audio_in_cache'] = audio_in_cum
        # Features are recomputed over the whole accumulated signal on each call.
        feats, feats_len = self.extract_feat([audio_in_cum])
        in_cache = param_dict.get('in_cache', list())
        in_cache = self.prepare_cache(in_cache)
        # Feed at most 8 feature frames starting at beg_idx, then advance
        # beg_idx by the number of frames actually taken.
        beg_idx = param_dict.get('beg_idx',0)
        feats = feats[:, beg_idx:beg_idx+8, :]
        param_dict['beg_idx'] = beg_idx + feats.shape[1]
        try:
            inputs = [feats]
            inputs.extend(in_cache)
            scores, out_caches = self.infer(inputs)
            param_dict['in_cache'] = out_caches
            # NOTE(review): the scorer receives only the current chunk
            # (audio_in), while feats came from the accumulated audio_in_cum --
            # confirm this is intended.
            segments = self.vad_scorer(scores, audio_in[None, :], is_final=is_final, max_end_sil=self.max_end_sil)
            # print(segments)
            # Reset the frontend when exactly one segment is returned and
            # segments[0][0][1] is not -1.
            if len(segments) == 1 and segments[0][0][1] != -1:
                self.frontend.reset_status()
            
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            waveform = waveform_list[beg_idx:end_idx]
            feats, feats_len = self.extract_feat(waveform)
            param_dict = kwargs.get('param_dict', dict())
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                inputs = [feats]
                inputs.extend(in_cache)
                scores, out_caches = self.infer(inputs)
                param_dict['in_cache'] = out_caches
                segments = self.vad_scorer(scores, waveform[0][None, :], is_final=is_final, max_end_sil=self.max_end_sil)
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                # NOTE(review): this fallback is '' while the outer handler
                # uses [] -- the two paths return different types.
                segments = ''
            asr_res.append(segments)
        except ONNXRuntimeError:
            logging.warning(traceback.format_exc())
            logging.warning("input wav is silence or noise")
            segments = []
    
        return asr_res
        return segments
    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: