游雁
2023-03-29 a030ff0f85fd6b1cc2a1d443d2fcfb11ccb1aa8f
export
3个文件已修改
1个文件已添加
314 ■■■■ 已修改文件
funasr/export/models/__init__.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/encoder/sanm_encoder.py 99 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/target_delay_transformer.py 132 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/vad_realtime_transformer.py 79 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/__init__.py
@@ -6,6 +6,8 @@
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
@@ -17,5 +19,7 @@
    elif isinstance(model, ESPnetPunctuationModel):
        if isinstance(model.punc_model, TargetDelayTransformer):
            return TargetDelayTransformer_export(model.punc_model, **export_config)
        elif isinstance(model.punc_model, VadRealtimeTransformer):
            return VadRealtimeTransformer_export(model.punc_model, **export_config)
    else:
        raise "Funasr does not support the given model type currently."
funasr/export/models/encoder/sanm_encoder.py
@@ -107,3 +107,102 @@
            }
        }
class SANMVadEncoder(nn.Module):
    """Export-oriented wrapper around a SANM VAD encoder.

    The constructor replaces the layers of the wrapped encoder, in place, with
    their ``*_export`` counterparts. ``forward`` scales the input, builds the
    padding masks, runs the layer stacks and applies ``after_norm``.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        """Wrap ``model`` and convert its layers for export.

        Args:
            model: Encoder to wrap. It must expose ``embed``, ``_output_size``,
                ``encoders`` and ``after_norm``; ``encoders0`` is optional.
            max_seq_len: Forwarded to the pad-mask builder.
            feats_dim: Last dimension of the tensor built by
                ``get_dummy_inputs``.
            model_name: Stored as ``self.model_name``.
            onnx: When true ``MakePadMask`` builds the padding mask, otherwise
                ``sequence_mask`` is used.
        """
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        # ``encoders0`` is treated as optional; ``encoders`` is always present.
        if hasattr(model, 'encoders0'):
            self._convert_layers(self.model.encoders0)
        self._convert_layers(self.model.encoders)

        self.model_name = model_name
        # Read from the first (already converted) layer of the main stack.
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    @staticmethod
    def _convert_layers(layers):
        """Replace every layer of ``layers`` in place with its export wrapper."""
        for i, d in enumerate(layers):
            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
            if isinstance(d.feed_forward, PositionwiseFeedForward):
                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
            layers[i] = EncoderLayerSANM_export(d)

    def prepare_mask(self, mask):
        """Derive the two mask tensors handed to the encoder layers.

        Args:
            mask: Rank-2 or rank-3 mask tensor in which 1 marks a kept position.

        Returns:
            A tuple ``(mask_3d_btd, mask_4d_bhlt)``. The first is ``mask`` with
            a new axis inserted at position 2. The second is ``1 - mask`` with
            broadcast axes inserted after the batch axis, multiplied by
            ``-10000.0`` so masked positions carry a large negative value.

        Raises:
            ValueError: If ``mask`` is neither rank 2 nor rank 3.
        """
        rank = len(mask.shape)
        if rank not in (2, 3):
            # Previously this case fell through both branches and failed with
            # an UnboundLocalError on ``mask_4d_bhlt``.
            raise ValueError(
                'mask must have 2 or 3 dimensions, got {}'.format(rank)
            )

        mask_3d_btd = mask[:, :, None]
        if rank == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        else:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                ):
        """Encode ``speech`` and return it together with its lengths.

        Args:
            speech: Input tensor; it is multiplied by ``_output_size ** 0.5``
                before anything else.
            speech_lengths: Lengths used to build the padding mask; returned
                unchanged.

        Returns:
            ``(encoder_out, speech_lengths)``.
        """
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)

        # Mirror the constructor: ``encoders0`` may be absent on the model.
        if hasattr(self.model, 'encoders0'):
            xs_pad = self.model.encoders0(xs_pad, mask)[0]
        xs_pad = self.model.encoders(xs_pad, mask)[0]

        xs_pad = self.model.after_norm(xs_pad)
        return xs_pad, speech_lengths

    def get_output_size(self):
        """Return the ``size`` attribute of the first layer in ``encoders``."""
        return self.model.encoders[0].size

    def get_dummy_inputs(self):
        """Return a random tensor of shape ``(1, 100, feats_dim)``.

        NOTE(review): the value is a bare tensor, not a tuple, and ``forward``
        also requires ``speech_lengths`` - confirm before exporting this module
        on its own.
        """
        feats = torch.randn(1, 100, self.feats_dim)
        return feats

    def get_input_names(self):
        """Return the input names used for export."""
        return ['feats']

    def get_output_names(self):
        """Return the output names used for export.

        NOTE(review): three names are listed while ``forward`` returns two
        values - confirm whether 'predictor_weight' belongs here.
        """
        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']

    def get_dynamic_axes(self):
        """Return the dynamic-axis mapping used for export."""
        return {
            'feats': {
                1: 'feats_length'
            },
            'encoder_out': {
                1: 'enc_out_length'
            },
            'predictor_weight': {
                1: 'pre_out_length'
            }
        }
funasr/export/models/target_delay_transformer.py
@@ -28,7 +28,7 @@
            onnx = kwargs["onnx"]
        self.embed = model.embed
        self.decoder = model.decoder
        self.model = model
        # self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
@@ -46,71 +46,71 @@
        from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
        from funasr.punctuation.abs_model import AbsPunctuation
        class TargetDelayTransformer(nn.Module):
            def __init__(
                    self,
                    model,
                    max_seq_len=512,
                    model_name='punc_model',
                    **kwargs,
            ):
                super().__init__()
                onnx = False
                if "onnx" in kwargs:
                    onnx = kwargs["onnx"]
                self.embed = model.embed
                self.decoder = model.decoder
                self.model = model
                self.feats_dim = self.embed.embedding_dim
                self.num_embeddings = self.embed.num_embeddings
                self.model_name = model_name
                if isinstance(model.encoder, SANMEncoder):
                    self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
                else:
                    assert False, "Only support samn encode."
            def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
                """Compute loss value from buffer sequences.
                Args:
                    input (torch.Tensor): Input ids. (batch, len)
                    hidden (torch.Tensor): Target ids. (batch, len)
                """
                x = self.embed(input)
                # mask = self._target_mask(input)
                h, _ = self.encoder(x, text_lengths)
                y = self.decoder(h)
                return y
            def get_dummy_inputs(self):
                length = 120
                text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
                text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
                return (text_indexes, text_lengths)
            def get_input_names(self):
                return ['input', 'text_lengths']
            def get_output_names(self):
                return ['logits']
            def get_dynamic_axes(self):
                return {
                    'input': {
                        0: 'batch_size',
                        1: 'feats_length'
                    },
                    'text_lengths': {
                        0: 'batch_size',
                    },
                    'logits': {
                        0: 'batch_size',
                        1: 'logits_length'
                    },
                }
        # class TargetDelayTransformer(nn.Module):
        #
        #     def __init__(
        #             self,
        #             model,
        #             max_seq_len=512,
        #             model_name='punc_model',
        #             **kwargs,
        #     ):
        #         super().__init__()
        #         onnx = False
        #         if "onnx" in kwargs:
        #             onnx = kwargs["onnx"]
        #         self.embed = model.embed
        #         self.decoder = model.decoder
        #         self.model = model
        #         self.feats_dim = self.embed.embedding_dim
        #         self.num_embeddings = self.embed.num_embeddings
        #         self.model_name = model_name
        #
        #         if isinstance(model.encoder, SANMEncoder):
        #             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        #         else:
        #             assert False, "Only support samn encode."
        #
        #     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        #         """Compute loss value from buffer sequences.
        #
        #         Args:
        #             input (torch.Tensor): Input ids. (batch, len)
        #             hidden (torch.Tensor): Target ids. (batch, len)
        #
        #         """
        #         x = self.embed(input)
        #         # mask = self._target_mask(input)
        #         h, _ = self.encoder(x, text_lengths)
        #         y = self.decoder(h)
        #         return y
        #
        #     def get_dummy_inputs(self):
        #         length = 120
        #         text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        #         text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        #         return (text_indexes, text_lengths)
        #
        #     def get_input_names(self):
        #         return ['input', 'text_lengths']
        #
        #     def get_output_names(self):
        #         return ['logits']
        #
        #     def get_dynamic_axes(self):
        #         return {
        #             'input': {
        #                 0: 'batch_size',
        #                 1: 'feats_length'
        #             },
        #             'text_lengths': {
        #                 0: 'batch_size',
        #             },
        #             'logits': {
        #                 0: 'batch_size',
        #                 1: 'logits_length'
        #             },
        #         }
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
funasr/export/models/vad_realtime_transformer.py
New file
@@ -0,0 +1,79 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.punctuation.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class VadRealtimeTransformer(AbsPunctuation):
    """Export wrapper for the VAD-aware realtime punctuation model.

    Reuses the wrapped model's embedding and decoder and replaces its SANM VAD
    encoder with the export variant.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        """Wrap ``model`` for export.

        Args:
            model: Punctuation model exposing ``embed``, ``encoder`` and
                ``decoder``.
            max_seq_len: Accepted for interface compatibility; not used by this
                constructor.
            model_name: Stored as ``self.model_name``.
            **kwargs: ``onnx`` (default ``False``) is forwarded to the export
                encoder.

        Raises:
            AssertionError: If ``model.encoder`` is not a ``SANMVadEncoder``.
        """
        super().__init__()
        # ``onnx`` was referenced without ever being defined, which made the
        # constructor fail with NameError; read it from kwargs instead.
        onnx = kwargs.get("onnx", False)
        self.embed = model.embed
        self.model_name = model_name
        if isinstance(model.encoder, SANMVadEncoder):
            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
        else:
            # An explicit raise survives ``python -O``; ``assert False`` would
            # be stripped and leave ``self.encoder`` unset.
            raise AssertionError("Only support samn encode.")
        self.decoder = model.decoder

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor) -> torch.Tensor:
        """Run embedding, encoder and decoder on a batch of token ids.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Lengths passed to the encoder.
            vad_indexes (torch.Tensor): Passed to the encoder unchanged.

        Returns:
            The decoder output for the encoded sequence.
        """
        x = self.embed(input)
        # NOTE(review): this call passes three arguments and unpacks three
        # results - confirm that the export encoder's forward matches.
        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
        y = self.decoder(h)
        return y

    def with_vad(self):
        """Report that this model takes VAD information."""
        return True

    def get_dummy_inputs(self):
        """Return ``(text_indexes, text_lengths)`` for a batch of two.

        ``text_indexes`` holds random ids of shape ``(2, 120)`` drawn below
        ``embed.num_embeddings``; ``text_lengths`` is the int32 tensor
        ``[100, 120]``.

        NOTE(review): ``forward`` also requires ``vad_indexes``, which is not
        produced here - confirm the expected shape and add it.
        """
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        """Return the input names used for export.

        NOTE(review): ``vad_indexes`` is a ``forward`` input but is not listed.
        """
        return ['input', 'text_lengths']

    def get_output_names(self):
        """Return the output names used for export."""
        return ['logits']

    def get_dynamic_axes(self):
        """Return the dynamic-axis mapping used for export."""
        return {
            'input': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }