nichongjia-2007
2023-07-07 e5151e047479e3414ed2faa2890bc3e7e17259be
add language models
6个文件已添加
885 ■■■■■ 已修改文件
funasr/export/models/e2e_asr_conformer.py 103 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/language_models/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/language_models/embed.py 403 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/language_models/seq_rnn.py 84 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/language_models/subsampling.py 185 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/language_models/transformer.py 110 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/e2e_asr_conformer.py
New file
@@ -0,0 +1,103 @@
import logging
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr.models.decoder.transformer_decoder import TransformerDecoder as TransformerDecoder_export
class Conformer(nn.Module):
    """
    Export a FunASR Conformer ASR model into ONNX format.

    Args:
        model: Source ASR model; its ``encoder`` and ``decoder`` attributes
            are inspected and wrapped for export.
        max_seq_len (int): Maximum length handled by the padding-mask helper.
        feats_dim (int): Feature dimension used by ``get_dummy_inputs``.
        output_size (int): Stored as ``self.output_size``; not read here.
        model_name (str): Stored as ``self.model_name``.
        **kwargs: ``onnx`` (bool, default False) selects ``MakePadMask`` over
            ``sequence_mask`` and is forwarded to the wrapped modules.
    """
    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            output_size=2048,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        # The module-level import binds this class only under the alias
        # `TransformerDecoder_export`, so the bare name used in the
        # isinstance check below raised NameError.
        from funasr.models.decoder.transformer_decoder import TransformerDecoder

        onnx = kwargs.get("onnx", False)
        if isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        # NOTE(review): because this is `elif`, the decoder is never wrapped
        # when the encoder is a ConformerEncoder, and forward() then fails on
        # `self.decoder`. `TransformerDecoder_export` is also bound to the
        # training-time TransformerDecoder class rather than an export
        # wrapper. Confirm the intended export decoder before changing this.
        elif isinstance(model.decoder, TransformerDecoder):
            self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
        # forward() reads `self.model.decoder`; this attribute was never set.
        self.model = model
        self.feats_dim = feats_dim
        self.output_size = output_size
        self.model_name = model_name
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
    ):
        """Run the wrapped encoder and decoder.

        Args:
            speech: Input features passed to the encoder wrapper.
            speech_lengths: Per-utterance feature lengths.

        Returns:
            tuple: ``(log_softmax(decoder_out, dim=-1), olens)``.
        """
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_len = self.encoder(**batch)
        # NOTE(review): `mask` is computed but not consumed below.
        mask = self.make_pad_mask(enc_len)[:, None, :]
        # Fill the decoder inputs.
        enc_size = self.encoder.output_size
        # NOTE(review): a random decoder input makes the traced graph
        # nondeterministic; this looks like a placeholder.
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoder)
        cache = [
            torch.zeros((1, self.decoder.size, self.decoder.self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        decoder_out, olens = self.decoder(enc, enc_len, pre_acoustic_embeds, cache)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, olens

    def get_dummy_inputs(self):
        """Return a random ``(speech, speech_lengths)`` pair for tracing."""
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        """Load a single utterance of features from a text file for tracing."""
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        """ONNX input names, in the order of ``forward``'s arguments."""
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        """ONNX output names, in the order ``forward`` returns them."""
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        """Dynamic-axis spec for ``torch.onnx.export``."""
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
funasr/export/models/language_models/__init__.py
funasr/export/models/language_models/embed.py
New file
@@ -0,0 +1,403 @@
"""Positional Encoding Module."""
import math
import torch
import torch.nn as nn
from funasr.modules.embedding import (
    LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
    ScaledPositionalEncoding, StreamPositionalEncoding)
from funasr.modules.subsampling import (
    Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
    Conv2dSubsampling8)
from funasr.modules.subsampling_without_posenc import \
    Conv2dSubsamplingWOPosEnc
from funasr.export.models.language_models.subsampling import (
    OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
    OnnxConv2dSubsampling8)
def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
    """Map a positional-encoding module to its ONNX-exportable counterpart.

    Args:
        pos_emb: Positional-encoding module taken from a source model.
        max_seq_len (int): Length of the cached positional table.
        use_cache (bool): Whether the wrapper slices a precomputed table
            (True) or computes sin/cos at run time (False).

    Returns:
        The matching ``Onnx*PositionalEncoding`` wrapper. ``pos_emb`` is
        returned unchanged when it is an empty ``nn.Sequential`` or a
        ``Conv2dSubsamplingWOPosEnc``.

    Raises:
        ValueError: If ``pos_emb`` is of an unsupported type.
    """
    # Checked in order, so keep the more specific encodings ahead of
    # PositionalEncoding.
    wrappers = (
        (LegacyRelPositionalEncoding, OnnxLegacyRelPositionalEncoding),
        (ScaledPositionalEncoding, OnnxScaledPositionalEncoding),
        (RelPositionalEncoding, OnnxRelPositionalEncoding),
        (PositionalEncoding, OnnxPositionalEncoding),
        (StreamPositionalEncoding, OnnxStreamPositionalEncoding),
    )
    for source_cls, onnx_cls in wrappers:
        if isinstance(pos_emb, source_cls):
            # Pass use_cache by keyword. OnnxPositionalEncoding's third
            # positional parameter is `reverse`, so passing it positionally
            # silently set `reverse` and ignored the cache flag.
            return onnx_cls(pos_emb, max_seq_len, use_cache=use_cache)
    if (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
        isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
    ):
        return pos_emb
    raise ValueError("Embedding model is not supported.")
class Embedding(nn.Module):
    """Wrap an input-embedding stack so its positional encoding is exportable.

    ``nn.Embedding`` is used as-is. Conv2d subsampling front-ends are wrapped
    in their ``OnnxConv2dSubsampling*`` counterparts. For any other module,
    the last element (``model[-1]``) is replaced by the exportable positional
    encoding.

    Note that the replacement writes into the source module's ``out`` /
    ``model[-1]`` slot, so the original model is modified in place.

    Args:
        model: Source embedding module.
        max_seq_len (int): Length of the cached positional table.
        use_cache (bool): Forwarded to ``get_pos_emb`` to choose between a
            cached table and run-time sin/cos computation.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        super().__init__()
        self.model = model
        if isinstance(model, nn.Embedding):
            return
        subsampling_wrappers = (
            (Conv2dSubsampling, OnnxConv2dSubsampling),
            (Conv2dSubsampling2, OnnxConv2dSubsampling2),
            (Conv2dSubsampling6, OnnxConv2dSubsampling6),
            (Conv2dSubsampling8, OnnxConv2dSubsampling8),
        )
        for source_cls, onnx_cls in subsampling_wrappers:
            if isinstance(model, source_cls):
                self.model = onnx_cls(model)
                # use_cache was previously accepted but never forwarded.
                self.model.out[-1] = get_pos_emb(
                    model.out[-1], max_seq_len, use_cache=use_cache
                )
                return
        self.model[-1] = get_pos_emb(model[-1], max_seq_len, use_cache=use_cache)

    def forward(self, x, mask=None):
        """Embed ``x``; ``mask`` is passed through only when given."""
        if mask is None:
            return self.model(x)
        return self.model(x, mask)
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Perform pre-hook in load_state_dict for backward compatibility.
    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)
class OnnxPositionalEncoding(torch.nn.Module):
    """Export-friendly wrapper around a sinusoidal ``PositionalEncoding``.

    ``forward`` returns ``x * sqrt(d_model) + pe`` where ``pe`` is sliced
    from a cached table (``use_cache=True``) or computed on the fly.

    Args:
        model: Source module. Its ``d_model`` and ``pe`` are reused, and its
            ``extend_pe`` is called when the cached table must grow.
        max_seq_len (int): Number of positions kept in the cached table.
        reverse (bool): In the on-the-fly path, count positions from
            ``T - 1`` down to 0 instead of from 0 upward.
        use_cache (bool): Slice a precomputed table instead of computing
            sin/cos at run time.
    """

    def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
        """Construct an OnnxPositionalEncoding object."""
        super().__init__()
        self.d_model = model.d_model
        self.xscale = math.sqrt(self.d_model)
        self.max_seq_len = max_seq_len
        self.reverse = reverse
        self._register_load_state_dict_pre_hook(_pre_hook)
        self.pe = model.pe
        self.use_cache = use_cache
        self.model = model
        if not self.use_cache:
            # Frequencies for the on-the-fly sin/cos computation.
            self.div_term = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float32)
                * -(math.log(10000.0) / self.d_model)
            )
            return
        self.extend_pe()

    def extend_pe(self):
        """Make the cached table cover exactly ``max_seq_len`` positions."""
        cached_len = self.pe.size(1)
        if self.max_seq_len >= cached_len:
            # Let the source module grow its own table, then adopt it.
            self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
            self.pe = self.model.pe
        else:
            self.pe = self.pe[:, : self.max_seq_len]

    def _add_pe(self, x):
        """Scale ``x`` and add sin/cos encodings computed at run time."""
        length = x.size(1)
        if self.reverse:
            steps = torch.arange(length - 1, -1, -1.0, dtype=torch.float32)
        else:
            steps = torch.arange(0, length, dtype=torch.float32)
        angle = steps.unsqueeze(1) * self.div_term
        scaled = x * self.xscale
        scaled[:, :, 0::2] += torch.sin(angle)
        scaled[:, :, 1::2] += torch.cos(angle)
        return scaled

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if not self.use_cache:
            return self._add_pe(x)
        return x * self.xscale + self.pe[:, : x.size(1)]
class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
    """Export-friendly wrapper around ``ScaledPositionalEncoding``.

    Computes ``x + alpha * pe`` with a learnable scale ``alpha``.
    See Sec. 3.2 of https://arxiv.org/abs/1809.08895.

    Args:
        model: Source scaled positional-encoding module. Its ``alpha`` is
            reused when present, so the trained scale survives export.
        max_seq_len (int): Number of positions kept in the cached table.
        use_cache (bool): Slice a precomputed table instead of computing
            sin/cos at run time.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class."""
        super().__init__(model, max_seq_len, use_cache=use_cache)
        # Reuse the trained scale. Creating a fresh Parameter(1.0) here
        # discarded the value learned by the source model.
        alpha = getattr(model, "alpha", None)
        if alpha is None:
            alpha = torch.nn.Parameter(torch.tensor(1.0))
        self.alpha = alpha

    def reset_parameters(self):
        """Reset ``alpha`` to 1.0 in place.

        When ``alpha`` is shared with the source model, this resets it there too.
        """
        self.alpha.data = torch.tensor(1.0)

    def _add_pe(self, x):
        """Return ``x + alpha * pe`` with ``pe`` computed at run time."""
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        angle = position * self.div_term
        pe = torch.zeros_like(x)
        pe[:, :, 0::2] = torch.sin(angle)
        pe[:, :, 1::2] = torch.cos(angle)
        # Same formula as the cached path. The previous `x * alpha + pe`
        # scaled the input instead of the encoding.
        return x + self.alpha * pe

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if self.use_cache:
            x = x + self.alpha * self.pe[:, : x.size(1)]
        else:
            x = self._add_pe(x)
        return x
class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
    """Export-friendly relative positional encoding (old ESPnet variant).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See Appendix B in https://arxiv.org/abs/1901.02860.

    Built on ``OnnxPositionalEncoding`` with ``reverse=True``. ``forward``
    returns the scaled input together with a separate positional table
    instead of adding them.

    Args:
        model: Source legacy relative positional-encoding module.
        max_seq_len (int): Number of positions kept in the cached table.
        use_cache (bool): Slice a precomputed table instead of computing
            sin/cos at run time.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class."""
        super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)

    def _get_pe(self, x):
        """Compute a ``(1, time, d)`` positional table for ``x`` at run time.

        NOTE(review): the cached table comes from the source module and may
        start at a different position than ``T - 1`` used here; confirm
        which the attention layer expects when ``use_cache=False``.
        """
        length = x.size(1)
        if self.reverse:
            position = torch.arange(
                length - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        angle = position * self.div_term
        # One row shared across the batch, matching the (1, time, d) shape
        # returned by the cached path. Previously this was x.shape, i.e.
        # batch-sized, always float32 and on CPU.
        pe = torch.zeros(1, length, x.size(2), dtype=x.dtype, device=x.device)
        pe[:, :, 0::2] = torch.sin(angle)
        pe[:, :, 1::2] = torch.cos(angle)
        return pe

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        x = x * self.xscale
        if self.use_cache:
            pos_emb = self.pe[:, : x.size(1)]
        else:
            pos_emb = self._get_pe(x)
        return x, pos_emb
class OnnxRelPositionalEncoding(torch.nn.Module):
    """Export-friendly relative positional encoding (new ESPnet variant).

    ``forward`` scales the input by ``sqrt(d_model)`` and returns a relative
    positional table with ``2*T - 1`` rows, ordered from offset ``T - 1``
    down to ``-(T - 1)``. See https://github.com/espnet/espnet/pull/2816
    and Appendix B of https://arxiv.org/abs/1901.02860.

    Args:
        model: Source module; only its ``d_model`` attribute is read.
        max_seq_len (int): Longest input length the cached table supports.
        use_cache (bool): Slice a precomputed table instead of computing
            sin/cos at run time.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Construct an OnnxRelPositionalEncoding object."""
        super().__init__()
        self.d_model = model.d_model
        self.xscale = math.sqrt(self.d_model)
        self.pe = None
        self.use_cache = use_cache
        if not self.use_cache:
            self.div_term = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float32)
                * -(math.log(10000.0) / self.d_model)
            )
            return
        self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))

    @staticmethod
    def _join_halves(pe_positive, pe_negative):
        """Flip the non-negative half and append the strictly negative half.

        Produces the layout used by the relative-shift trick.
        """
        left = torch.flip(pe_positive, [0]).unsqueeze(0)
        right = pe_negative[1:].unsqueeze(0)
        return torch.cat([left, right], dim=1)

    def extend_pe(self, x):
        """(Re)build ``self.pe`` so it covers inputs as long as ``x``."""
        length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length * 2 - 1:
            # Long enough already; only align dtype/device.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        steps = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pos_half = torch.zeros(length, self.d_model)
        neg_half = torch.zeros(length, self.d_model)
        pos_half[:, 0::2] = torch.sin(steps * freqs)
        pos_half[:, 1::2] = torch.cos(steps * freqs)
        neg_half[:, 0::2] = torch.sin(-1 * steps * freqs)
        neg_half[:, 1::2] = torch.cos(-1 * steps * freqs)
        table = self._join_halves(pos_half, neg_half)
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def _get_pe(self, x):
        """Compute the relative table for ``x`` at run time."""
        length = x.size(1)
        theta = (
            torch.arange(0, length, dtype=torch.float32).unsqueeze(1) * self.div_term
        )
        pos_half = torch.zeros(length, self.d_model)
        neg_half = torch.zeros(length, self.d_model)
        pos_half[:, 0::2] = torch.sin(theta)
        pos_half[:, 1::2] = torch.cos(theta)
        neg_half[:, 0::2] = -1 * torch.sin(theta)
        neg_half[:, 1::2] = torch.cos(theta)
        return self._join_halves(pos_half, neg_half)

    def forward(self, x: torch.Tensor, use_cache=True):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            use_cache: Ignored; ``self.use_cache`` decides the path.

        Returns:
            torch.Tensor: Scaled input (batch, time, `*`).
            torch.Tensor: Relative positional table (1, 2*time-1, `*`).
        """
        x = x * self.xscale
        if not self.use_cache:
            return x, self._get_pe(x)
        center = self.pe.size(1) // 2
        length = x.size(1)
        return x, self.pe[:, center - length + 1 : center + length]
class OnnxStreamPositionalEncoding(torch.nn.Module):
    """Export-friendly wrapper around ``StreamPositionalEncoding``.

    Positions start at a caller-supplied offset ``start_idx``, so successive
    chunks of a stream can receive consecutive positions.

    Args:
        model: Source module; its ``d_model``, ``xscale`` and ``pe`` are reused.
        max_seq_len (int): Number of positions kept in the cached table.
        use_cache (bool): Slice a precomputed table instead of computing
            sin/cos at run time.
    """

    def __init__(self, model, max_seq_len=5000, use_cache=True):
        """Construct an OnnxStreamPositionalEncoding object."""
        # Was `super(StreamPositionalEncoding, self)`, which raises TypeError
        # because this class does not derive from StreamPositionalEncoding.
        super().__init__()
        self.d_model = model.d_model
        self.xscale = model.xscale
        self.pe = model.pe
        self.use_cache = use_cache
        self.max_seq_len = max_seq_len
        # extend_pe() previously referenced self.model, which was never set.
        self.model = model
        self.div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        if self.use_cache:
            self.extend_pe()
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self):
        """Make ``self.pe`` hold exactly ``max_seq_len`` positions."""
        if self.pe is not None and self.pe.size(1) >= self.max_seq_len:
            self.pe = self.pe[:, : self.max_seq_len]
            return
        # Build the table here rather than through the source module's
        # extend_pe, so this does not depend on that method's signature.
        position = torch.arange(
            0, self.max_seq_len, dtype=torch.float32
        ).unsqueeze(1)
        pe = torch.zeros(self.max_seq_len, self.d_model)
        pe[:, 0::2] = torch.sin(position * self.div_term)
        pe[:, 1::2] = torch.cos(position * self.div_term)
        pe = pe.unsqueeze(0)
        if self.pe is not None:
            pe = pe.to(dtype=self.pe.dtype, device=self.pe.device)
        self.pe = pe

    def _add_pe(self, x, start_idx):
        """Scale ``x`` and add encodings for positions ``start_idx ..``."""
        # The upper bound must be offset by start_idx, as in the cached path.
        # Previously this was `arange(start_idx, T)`, which is too short.
        position = torch.arange(
            start_idx, start_idx + x.size(1), dtype=torch.float32
        ).unsqueeze(1)
        x = x * self.xscale
        x[:, :, 0::2] += torch.sin(position * self.div_term)
        x[:, :, 1::2] += torch.cos(position * self.div_term)
        return x

    def forward(self, x: torch.Tensor, start_idx: int = 0):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            start_idx (int): Position assigned to the first frame of ``x``.

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if self.use_cache:
            return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
        else:
            return self._add_pe(x, start_idx)
funasr/export/models/language_models/seq_rnn.py
New file
@@ -0,0 +1,84 @@
import os
import torch
import torch.nn as nn
class SequentialRNNLM(nn.Module):
    """ONNX-export wrapper around a sequential RNN language model.

    Reuses the source model's token embedding (``encoder``), recurrent layer
    (``rnn``) and output projection (``decoder``). When ``rnn_type`` is
    ``"LSTM"``, the graph takes and returns both LSTM state tensors; for other
    RNN types it takes and returns a single hidden state.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.encoder = model.encoder
        self.rnn = model.rnn
        self.rnn_type = model.rnn_type
        self.decoder = model.decoder
        self.nlayers = model.nlayers
        self.nhid = model.nhid
        self.model_name = "seq_rnnlm"

    def forward(self, y, hidden1, hidden2=None):
        """Score a batch of token sequences (``batch_score``-style step).

        Args:
            y: Token ids, laid out as the source ``rnn`` expects.
            hidden1: First recurrent state.
            hidden2: Second state; used only when ``rnn_type == "LSTM"``.

        Returns:
            ``(logits, hidden1, hidden2)`` for LSTM, else ``(logits, hidden1)``.
            ``logits`` has shape ``(output.size(0), output.size(1), vocab)``.
        """
        emb = self.encoder(y)
        if self.rnn_type == "LSTM":
            output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
        else:
            output, hidden1 = self.rnn(emb, hidden1)
        decoded = self.decoder(
            output.contiguous().view(output.size(0) * output.size(1), output.size(2))
        )
        logits = decoded.view(output.size(0), output.size(1), decoded.size(1))
        if self.rnn_type == "LSTM":
            return logits, hidden1, hidden2
        return logits, hidden1

    def get_dummy_inputs(self):
        """Return tracing inputs: a 2-token sequence plus zero-batch-1 states."""
        tgt = torch.LongTensor([0, 1]).unsqueeze(0)
        hidden = torch.randn(self.nlayers, 1, self.nhid)
        if self.rnn_type == "LSTM":
            return (tgt, hidden, hidden)
        else:
            return (tgt, hidden)

    def get_input_names(self):
        """ONNX input names, in the order of ``forward``'s arguments."""
        if self.rnn_type == "LSTM":
            return ["x", "in_hidden1", "in_hidden2"]
        else:
            return ["x", "in_hidden1"]

    def get_output_names(self):
        """ONNX output names, in the order ``forward`` returns them."""
        if self.rnn_type == "LSTM":
            return ["y", "out_hidden1", "out_hidden2"]
        else:
            return ["y", "out_hidden1"]

    def get_dynamic_axes(self):
        """Dynamic-axis spec for ``torch.onnx.export``."""
        ret = {
            "x": {0: "x_batch", 1: "x_length"},
            # Dim 1 of `y` follows the input length, so it must be dynamic too.
            "y": {0: "y_batch", 1: "y_length"},
            "in_hidden1": {1: "hidden1_batch"},
            "out_hidden1": {1: "out_hidden1_batch"},
        }
        if self.rnn_type == "LSTM":
            ret.update(
                {
                    "in_hidden2": {1: "hidden2_batch"},
                    "out_hidden2": {1: "out_hidden2_batch"},
                }
            )
        return ret

    def get_model_config(self, path):
        """Describe the exported model for the runtime configuration."""
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "SequentialRNNLM",
            "rnn_type": self.rnn_type,
            "nhid": self.nhid,
            "nlayers": self.nlayers,
        }
funasr/export/models/language_models/subsampling.py
New file
@@ -0,0 +1,185 @@
"""Subsampling layer definition."""
import torch
class OnnxConv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length), export wrapper.

    Reuses the source module's ``conv`` stack and ``out`` projection.

    Args:
        model: Source ``Conv2dSubsampling`` module (provides ``conv``, ``out``).
    """

    def __init__(self, model):
        """Construct an OnnxConv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor, optional): Input mask (#batch, time). Dim 1
                is the one subsampled. ``None`` (the default) skips the mask;
                callers such as ``Embedding.forward`` invoke this without one.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor or None: Subsampled mask (#batch, time').
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Mirrors two kernel-3, stride-2 convolutions along time.
        return x, x_mask[:, :-2:2][:, :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class OnnxConv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length), export wrapper.

    Reuses the source module's ``conv`` stack and ``out`` projection.

    Args:
        model: Source ``Conv2dSubsampling2`` module (provides ``conv``, ``out``).
    """

    def __init__(self, model):
        """Construct an OnnxConv2dSubsampling2 object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor, optional): Input mask (#batch, time). Dim 1
                is the one subsampled. ``None`` (the default) skips the mask.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor or None: Subsampled mask (#batch, time').
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Mirrors a kernel-3/stride-2 conv followed by a kernel-3/stride-1 conv.
        return x, x_mask[:, :-2:2][:, :-2:1]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class OnnxConv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length), export wrapper.

    Reuses the source module's ``conv`` stack and ``out`` projection.

    Args:
        model: Source ``Conv2dSubsampling6`` module (provides ``conv``, ``out``).
    """

    def __init__(self, model):
        """Construct an OnnxConv2dSubsampling6 object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor, optional): Input mask (#batch, time). Dim 1
                is the one subsampled. ``None`` (the default) skips the mask.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor or None: Subsampled mask (#batch, time').
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Mirrors a kernel-3/stride-2 conv followed by a kernel-5/stride-3 conv.
        return x, x_mask[:, :-2:2][:, :-4:3]

    def __getitem__(self, key):
        """Return the last ``out`` element (consistent with the 1/4 and 1/2 wrappers)."""
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class OnnxConv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length), export wrapper.

    Reuses the source module's ``conv`` stack and ``out`` projection.

    Args:
        model: Source ``Conv2dSubsampling8`` module (provides ``conv``, ``out``).
    """

    def __init__(self, model):
        """Construct an OnnxConv2dSubsampling8 object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask=None):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor, optional): Input mask (#batch, time). Dim 1
                is the one subsampled. ``None`` (the default) skips the mask.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor or None: Subsampled mask (#batch, time').
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Mirrors three kernel-3, stride-2 convolutions along time.
        return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]

    def __getitem__(self, key):
        """Return the last ``out`` element (consistent with the 1/4 and 1/2 wrappers)."""
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
funasr/export/models/language_models/transformer.py
New file
@@ -0,0 +1,110 @@
import os
import torch
import torch.nn as nn
from funasr.modules.vgg2l import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
    Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr.export.utils.torch_function import MakePadMask
class TransformerLM(nn.Module):
    """ONNX-export wrapper around a Transformer language model.

    Wraps the token embedding for export and swaps each encoder layer's
    ``MultiHeadedAttention`` (and the layer itself) for its ONNX counterpart.
    Note that ``self.encoder`` is the source model's encoder, so those swaps
    modify the source model in place.

    Args:
        model: Source LM providing ``embed``, ``encoder`` and ``decoder``.
        max_seq_len (int): Maximum length handled by the embedding and the
            padding-mask helper.
    """

    # The base list previously included `AbsExportModel`, which is neither
    # defined nor imported here, so defining the class raised NameError.

    def __init__(self, model, max_seq_len=512, **kwargs):
        super().__init__()
        self.embed = Embedding(model.embed, max_seq_len)
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        # Replace the multi-head attention modules with export-friendly ones.
        for i, d in enumerate(self.encoder.encoders):
            # d is an EncoderLayer.
            if isinstance(d.self_attn, MultiHeadedAttention):
                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
            self.encoder.encoders[i] = OnnxEncoderLayer(d)
        self.model_name = "transformer_lm"
        self.num_heads = self.encoder.encoders[0].self_attn.h
        self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask):
        """Turn a 0/1 keep-mask into an additive mask (0 keep, -10000 drop).

        2-D masks become ``mask[:, None, None, :]``; 3-D masks become
        ``mask[:, None, :]``.
        """
        if len(mask.shape) == 2:
            mask = mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask = mask[:, None, :]
        mask = 1 - mask
        return mask * -10000.0

    def forward(self, y, cache):
        """Run one scoring step.

        Args:
            y: Token ids; id 0 is treated as padding by the mask.
            cache: One tensor per encoder layer, passed to that layer.

        Returns:
            tuple: ``(h, new_cache)``, where ``h`` is the decoder output for
            the last position and ``new_cache`` holds each layer's output.
        """
        feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
        mask = self.make_pad_mask(feats_length)  # (B, T)
        mask = (y != 0) * mask
        xs = self.embed(y)
        # forward_one_step of Encoder
        if isinstance(
            self.encoder.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
        ):
            xs, mask = self.encoder.embed(xs, mask)
        else:
            xs = self.encoder.embed(xs)
        new_cache = []
        mask = self.prepare_mask(mask)
        for c, e in zip(cache, self.encoder.encoders):
            xs, mask = e(xs, mask, c)
            new_cache.append(xs)
        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        h = self.decoder(xs[:, -1])
        return h, new_cache

    def get_dummy_inputs(self):
        """Return tracing inputs: one token plus a zero cache per layer."""
        tgt = torch.LongTensor([1]).unsqueeze(0)
        cache = [
            torch.zeros((1, 1, self.encoder.encoders[0].size))
            for _ in range(len(self.encoder.encoders))
        ]
        return (tgt, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        """ONNX input names: ``tgt`` followed by one cache per layer."""
        return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]

    def get_output_names(self):
        """ONNX output names: ``y`` followed by one updated cache per layer."""
        return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]

    def get_dynamic_axes(self):
        """Dynamic-axis spec for ``torch.onnx.export``."""
        ret = {
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            # `y` carries the batch dimension of `tgt`.
            "y": {0: "y_batch"},
        }
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
                for d in range(len(self.encoder.encoders))
            }
        )
        ret.update(
            {
                "out_cache_%d"
                % d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
                for d in range(len(self.encoder.encoders))
            }
        )
        return ret

    def get_model_config(self, path):
        """Describe the exported model for the runtime configuration."""
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "TransformerLM",
            "odim": self.encoder.encoders[0].size,
            "nlayers": len(self.encoder.encoders),
        }