From e5151e047479e3414ed2faa2890bc3e7e17259be Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 07 七月 2023 16:53:16 +0800
Subject: [PATCH] add language models
---
funasr/export/models/language_models/transformer.py | 110 +++++++
funasr/export/models/language_models/embed.py | 403 ++++++++++++++++++++++++++
funasr/export/models/e2e_asr_conformer.py | 103 ++++++
funasr/export/models/language_models/subsampling.py | 185 ++++++++++++
funasr/export/models/language_models/__init__.py | 0
funasr/export/models/language_models/seq_rnn.py | 84 +++++
6 files changed, 885 insertions(+), 0 deletions(-)
diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
new file mode 100644
index 0000000..49c9aae
--- /dev/null
+++ b/funasr/export/models/e2e_asr_conformer.py
@@ -0,0 +1,103 @@
+import logging
+import torch
+import torch.nn as nn
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
+from funasr.models.decoder.transformer_decoder import TransformerDecoder as TransformerDecoder_export
+
+
+class Conformer(nn.Module):
+ """
+ export conformer into onnx format
+ """
+
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ output_size=2048,
+ model_name='model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
+ if isinstance(model.encoder, ConformerEncoder):
+ self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+ elif isinstance(model.decoder, TransformerDecoder):
+ self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
+
+ self.feats_dim = feats_dim
+ self.output_size = output_size
+ self.model_name = model_name
+
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ):
+ # a. To device
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
+ # batch = to_device(batch, device=self.device)
+
+ enc, enc_len = self.encoder(**batch)
+ mask = self.make_pad_mask(enc_len)[:, None, :]
+
+ # fill the decoder input
+ enc_size = self.encoder.output_size
+ pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+ cache_num = len(self.model.decoder)
+ cache = [
+ torch.zeros((1, self.decoder.size, self.decoder.self_attn.kernel_size))
+ for _ in range(cache_num)
+ ]
+
+ decoder_out, olens = self.decoder(enc, enc_len, pre_acoustic_embeds, cache)
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ # sample_ids = decoder_out.argmax(dim=-1)
+
+ return decoder_out, olens
+
+ def get_dummy_inputs(self):
+ speech = torch.randn(2, 30, self.feats_dim)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+ def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
+ import numpy as np
+ fbank = np.loadtxt(txt_file)
+ fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
+ speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
+ speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
+ return (speech, speech_lengths)
+
+ def get_input_names(self):
+ return ['speech', 'speech_lengths']
+
+ def get_output_names(self):
+ return ['logits', 'token_num']
+
+ def get_dynamic_axes(self):
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ }
diff --git a/funasr/export/models/language_models/__init__.py b/funasr/export/models/language_models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/export/models/language_models/__init__.py
diff --git a/funasr/export/models/language_models/embed.py b/funasr/export/models/language_models/embed.py
new file mode 100644
index 0000000..57748f2
--- /dev/null
+++ b/funasr/export/models/language_models/embed.py
@@ -0,0 +1,403 @@
+"""Positional Encoding Module."""
+
+import math
+
+import torch
+import torch.nn as nn
+from funasr.modules.embedding import (
+ LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
+ ScaledPositionalEncoding, StreamPositionalEncoding)
+from funasr.modules.subsampling import (
+ Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
+ Conv2dSubsampling8)
+from funasr.modules.subsampling_without_posenc import \
+ Conv2dSubsamplingWOPosEnc
+
+from funasr.export.models.language_models.subsampling import (
+ OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
+ OnnxConv2dSubsampling8)
+
+
+def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
+ if isinstance(pos_emb, LegacyRelPositionalEncoding):
+ return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
+ elif isinstance(pos_emb, ScaledPositionalEncoding):
+ return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache)
+ elif isinstance(pos_emb, RelPositionalEncoding):
+ return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
+ elif isinstance(pos_emb, PositionalEncoding):
+ return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache)
+ elif isinstance(pos_emb, StreamPositionalEncoding):
+ return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache)
+ elif (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
+ isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
+ ):
+ return pos_emb
+ else:
+ raise ValueError("Embedding model is not supported.")
+
+
+class Embedding(nn.Module):
+ def __init__(self, model, max_seq_len=512, use_cache=True):
+ super().__init__()
+ self.model = model
+ if not isinstance(model, nn.Embedding):
+ if isinstance(model, Conv2dSubsampling):
+ self.model = OnnxConv2dSubsampling(model)
+ self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+ elif isinstance(model, Conv2dSubsampling2):
+ self.model = OnnxConv2dSubsampling2(model)
+ self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+ elif isinstance(model, Conv2dSubsampling6):
+ self.model = OnnxConv2dSubsampling6(model)
+ self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+ elif isinstance(model, Conv2dSubsampling8):
+ self.model = OnnxConv2dSubsampling8(model)
+ self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+ else:
+ self.model[-1] = get_pos_emb(model[-1], max_seq_len)
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ return self.model(x)
+ else:
+ return self.model(x, mask)
+
+
+def _pre_hook(
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+):
+ """Perform pre-hook in load_state_dict for backward compatibility.
+
+ Note:
+ We saved self.pe until v.0.5.2 but we have omitted it later.
+ Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
+
+ """
+ k = prefix + "pe"
+ if k in state_dict:
+ state_dict.pop(k)
+
+
+class OnnxPositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_seq_len (int): Maximum input length.
+ reverse (bool): Whether to reverse the input position. Only for
+ the class LegacyRelPositionalEncoding. We remove it in the current
+ class RelPositionalEncoding.
+ """
+
+ def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
+ """Construct an PositionalEncoding object."""
+ super(OnnxPositionalEncoding, self).__init__()
+ self.d_model = model.d_model
+ self.reverse = reverse
+ self.max_seq_len = max_seq_len
+ self.xscale = math.sqrt(self.d_model)
+ self._register_load_state_dict_pre_hook(_pre_hook)
+ self.pe = model.pe
+ self.use_cache = use_cache
+ self.model = model
+ if self.use_cache:
+ self.extend_pe()
+ else:
+ self.div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+
+ def extend_pe(self):
+ """Reset the positional encodings."""
+ pe_length = len(self.pe[0])
+ if self.max_seq_len < pe_length:
+ self.pe = self.pe[:, : self.max_seq_len]
+ else:
+ self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
+ self.pe = self.model.pe
+
+ def _add_pe(self, x):
+ """Computes positional encoding"""
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+
+ x = x * self.xscale
+ x[:, :, 0::2] += torch.sin(position * self.div_term)
+ x[:, :, 1::2] += torch.cos(position * self.div_term)
+ return x
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ if self.use_cache:
+ x = x * self.xscale + self.pe[:, : x.size(1)]
+ else:
+ x = self._add_pe(x)
+ return x
+
+
+class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
+ """Scaled positional encoding module.
+
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
+
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_seq_len (int): Maximum input length.
+
+ """
+
+ def __init__(self, model, max_seq_len=512, use_cache=True):
+ """Initialize class."""
+ super().__init__(model, max_seq_len, use_cache=use_cache)
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ self.alpha.data = torch.tensor(1.0)
+
+ def _add_pe(self, x):
+ """Computes positional encoding"""
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+
+ x = x * self.alpha
+ x[:, :, 0::2] += torch.sin(position * self.div_term)
+ x[:, :, 1::2] += torch.cos(position * self.div_term)
+ return x
+
+ def forward(self, x):
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+
+ """
+ if self.use_cache:
+ x = x + self.alpha * self.pe[:, : x.size(1)]
+ else:
+ x = self._add_pe(x)
+ return x
+
+
+class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
+ """Relative positional encoding module (old version).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_seq_len (int): Maximum input length.
+
+ """
+
+ def __init__(self, model, max_seq_len=512, use_cache=True):
+ """Initialize class."""
+ super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)
+
+ def _get_pe(self, x):
+ """Computes positional encoding"""
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+
+ pe = torch.zeros(x.shape)
+ pe[:, :, 0::2] += torch.sin(position * self.div_term)
+ pe[:, :, 1::2] += torch.cos(position * self.div_term)
+ return pe
+
+ def forward(self, x):
+ """Compute positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+
+ """
+ x = x * self.xscale
+ if self.use_cache:
+ pos_emb = self.pe[:, : x.size(1)]
+ else:
+ pos_emb = self._get_pe(x)
+ return x, pos_emb
+
+
+class OnnxRelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module (new implementation).
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_seq_len (int): Maximum input length.
+ """
+
+ def __init__(self, model, max_seq_len=512, use_cache=True):
+ """Construct an PositionalEncoding object."""
+ super(OnnxRelPositionalEncoding, self).__init__()
+ self.d_model = model.d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.pe = None
+ self.use_cache = use_cache
+ if self.use_cache:
+ self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
+ else:
+ self.div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None and self.pe.size(1) >= x.size(1) * 2 - 1:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` means to the position of query vecotr and `j` means the
+ # position of key vector. We use position relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reserve the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in https://arxiv.org/abs/1901.02860
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def _get_pe(self, x):
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ theta = (
+ torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) * self.div_term
+ )
+ pe_positive[:, 0::2] = torch.sin(theta)
+ pe_positive[:, 1::2] = torch.cos(theta)
+ pe_negative[:, 0::2] = -1 * torch.sin(theta)
+ pe_negative[:, 1::2] = torch.cos(theta)
+
+ # Reserve the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in https://arxiv.org/abs/1901.02860
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ return torch.cat([pe_positive, pe_negative], dim=1)
+
+ def forward(self, x: torch.Tensor, use_cache=True):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ x = x * self.xscale
+ if self.use_cache:
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
+ ]
+ else:
+ pos_emb = self._get_pe(x)
+ return x, pos_emb
+
+
+class OnnxStreamPositionalEncoding(torch.nn.Module):
+ """Streaming Positional encoding."""
+
+ def __init__(self, model, max_seq_len=5000, use_cache=True):
+ """Construct an PositionalEncoding object."""
+ super(StreamPositionalEncoding, self).__init__()
+ self.use_cache = use_cache
+ self.d_model = model.d_model
+ self.xscale = model.xscale
+ self.pe = model.pe
+ self.use_cache = use_cache
+ self.max_seq_len = max_seq_len
+ if self.use_cache:
+ self.extend_pe()
+ else:
+ self.div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ self._register_load_state_dict_pre_hook(_pre_hook)
+
+ def extend_pe(self):
+ """Reset the positional encodings."""
+ pe_length = len(self.pe[0])
+ if self.max_seq_len < pe_length:
+ self.pe = self.pe[:, : self.max_seq_len]
+ else:
+ self.model.extend_pe(self.max_seq_len)
+ self.pe = self.model.pe
+
+ def _add_pe(self, x, start_idx):
+ position = torch.arange(start_idx, x.size(1), dtype=torch.float32).unsqueeze(1)
+ x = x * self.xscale
+ x[:, :, 0::2] += torch.sin(position * self.div_term)
+ x[:, :, 1::2] += torch.cos(position * self.div_term)
+ return x
+
+ def forward(self, x: torch.Tensor, start_idx: int = 0):
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+
+ """
+ if self.use_cache:
+ return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
+ else:
+ return self._add_pe(x, start_idx)
diff --git a/funasr/export/models/language_models/seq_rnn.py b/funasr/export/models/language_models/seq_rnn.py
new file mode 100644
index 0000000..ecff4b8
--- /dev/null
+++ b/funasr/export/models/language_models/seq_rnn.py
@@ -0,0 +1,84 @@
+import os
+
+import torch
+import torch.nn as nn
+
+class SequentialRNNLM(nn.Module):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.encoder = model.encoder
+ self.rnn = model.rnn
+ self.rnn_type = model.rnn_type
+ self.decoder = model.decoder
+ self.nlayers = model.nlayers
+ self.nhid = model.nhid
+ self.model_name = "seq_rnnlm"
+
+ def forward(self, y, hidden1, hidden2=None):
+ # batch_score function.
+ emb = self.encoder(y)
+ if self.rnn_type == "LSTM":
+ output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
+ else:
+ output, hidden1 = self.rnn(emb, hidden1)
+
+ decoded = self.decoder(
+ output.contiguous().view(output.size(0) * output.size(1), output.size(2))
+ )
+ if self.rnn_type == "LSTM":
+ return (
+ decoded.view(output.size(0), output.size(1), decoded.size(1)),
+ hidden1,
+ hidden2,
+ )
+ else:
+ return (
+ decoded.view(output.size(0), output.size(1), decoded.size(1)),
+ hidden1,
+ )
+
+ def get_dummy_inputs(self):
+ tgt = torch.LongTensor([0, 1]).unsqueeze(0)
+ hidden = torch.randn(self.nlayers, 1, self.nhid)
+ if self.rnn_type == "LSTM":
+ return (tgt, hidden, hidden)
+ else:
+ return (tgt, hidden)
+
+ def get_input_names(self):
+ if self.rnn_type == "LSTM":
+ return ["x", "in_hidden1", "in_hidden2"]
+ else:
+ return ["x", "in_hidden1"]
+
+ def get_output_names(self):
+ if self.rnn_type == "LSTM":
+ return ["y", "out_hidden1", "out_hidden2"]
+ else:
+ return ["y", "out_hidden1"]
+
+ def get_dynamic_axes(self):
+ ret = {
+ "x": {0: "x_batch", 1: "x_length"},
+ "y": {0: "y_batch"},
+ "in_hidden1": {1: "hidden1_batch"},
+ "out_hidden1": {1: "out_hidden1_batch"},
+ }
+ if self.rnn_type == "LSTM":
+ ret.update(
+ {
+ "in_hidden2": {1: "hidden2_batch"},
+ "out_hidden2": {1: "out_hidden2_batch"},
+ }
+ )
+ return ret
+
+ def get_model_config(self, path):
+ return {
+ "use_lm": True,
+ "model_path": os.path.join(path, f"{self.model_name}.onnx"),
+ "lm_type": "SequentialRNNLM",
+ "rnn_type": self.rnn_type,
+ "nhid": self.nhid,
+ "nlayers": self.nlayers,
+ }
diff --git a/funasr/export/models/language_models/subsampling.py b/funasr/export/models/language_models/subsampling.py
new file mode 100644
index 0000000..e71e127
--- /dev/null
+++ b/funasr/export/models/language_models/subsampling.py
@@ -0,0 +1,185 @@
+"""Subsampling layer definition."""
+
+import torch
+
+
+class OnnxConv2dSubsampling(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+
+ """
+
+ def __init__(self, model):
+ """Construct an Conv2dSubsampling object."""
+ super().__init__()
+ self.conv = model.conv
+ self.out = model.out
+
+ def forward(self, x, x_mask):
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ if x_mask is None:
+ return x, None
+ return x, x_mask[:, :-2:2][:, :-2:2]
+
+ def __getitem__(self, key):
+ """Get item.
+
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
+ return the positioning encoding.
+
+ """
+ if key != -1:
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+ return self.out[key]
+
+
+class OnnxConv2dSubsampling2(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/2 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+
+ """
+
+ def __init__(self, model):
+ """Construct an Conv2dSubsampling object."""
+ super().__init__()
+ self.conv = model.conv
+ self.out = model.out
+
+ def forward(self, x, x_mask):
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 2.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 2.
+
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ if x_mask is None:
+ return x, None
+ return x, x_mask[:, :-2:2][:, :-2:1]
+
+ def __getitem__(self, key):
+ """Get item.
+
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
+ return the positioning encoding.
+
+ """
+ if key != -1:
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+ return self.out[key]
+
+
+class OnnxConv2dSubsampling6(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/6 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+
+ """
+
+ def __init__(self, model):
+ """Construct an Conv2dSubsampling object."""
+ super().__init__()
+ self.conv = model.conv
+ self.out = model.out
+
+ def forward(self, x, x_mask):
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 6.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 6.
+
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ if x_mask is None:
+ return x, None
+ return x, x_mask[:, :-2:2][:, :-4:3]
+
+
+class OnnxConv2dSubsampling8(torch.nn.Module):
+ """Convolutional 2D subsampling (to 1/8 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+
+ """
+
+ def __init__(self, model):
+ """Construct an Conv2dSubsampling object."""
+ super().__init__()
+ self.conv = model.conv
+ self.out = model.out
+
+ def forward(self, x, x_mask):
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 8.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 8.
+
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ if x_mask is None:
+ return x, None
+ return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]
diff --git a/funasr/export/models/language_models/transformer.py b/funasr/export/models/language_models/transformer.py
new file mode 100644
index 0000000..ebf0574
--- /dev/null
+++ b/funasr/export/models/language_models/transformer.py
@@ -0,0 +1,110 @@
+import os
+
+import torch
+import torch.nn as nn
+from funasr.modules.vgg2l import import VGG2L
+from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.subsampling import (
+ Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
+
+from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
+from funasr.export.models.language_models.embed import Embedding
+from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
+
+from funasr.export.utils.torch_function import MakePadMask
+
+class TransformerLM(nn.Module, AbsExportModel):
+ def __init__(self, model, max_seq_len=512, **kwargs):
+ super().__init__()
+ self.embed = Embedding(model.embed, max_seq_len)
+ self.encoder = model.encoder
+ self.decoder = model.decoder
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ # replace multihead attention module into customized module.
+ for i, d in enumerate(self.encoder.encoders):
+ # d is EncoderLayer
+ if isinstance(d.self_attn, MultiHeadedAttention):
+ d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
+ self.encoder.encoders[i] = OnnxEncoderLayer(d)
+
+ self.model_name = "transformer_lm"
+ self.num_heads = self.encoder.encoders[0].self_attn.h
+ self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features
+
+ def prepare_mask(self, mask):
+ if len(mask.shape) == 2:
+ mask = mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask = mask[:, None, :]
+ mask = 1 - mask
+ return mask * -10000.0
+
+ def forward(self, y, cache):
+ feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
+ mask = self.make_pad_mask(feats_length) # (B, T)
+ mask = (y != 0) * mask
+
+ xs = self.embed(y)
+ # forward_one_step of Encoder
+ if isinstance(
+ self.encoder.embed,
+ (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
+ ):
+ xs, mask = self.encoder.embed(xs, mask)
+ else:
+ xs = self.encoder.embed(xs)
+
+ new_cache = []
+ mask = self.prepare_mask(mask)
+ for c, e in zip(cache, self.encoder.encoders):
+ xs, mask = e(xs, mask, c)
+ new_cache.append(xs)
+
+ if self.encoder.normalize_before:
+ xs = self.encoder.after_norm(xs)
+
+ h = self.decoder(xs[:, -1])
+ return h, new_cache
+
+ def get_dummy_inputs(self):
+ tgt = torch.LongTensor([1]).unsqueeze(0)
+ cache = [
+ torch.zeros((1, 1, self.encoder.encoders[0].size))
+ for _ in range(len(self.encoder.encoders))
+ ]
+ return (tgt, cache)
+
+ def is_optimizable(self):
+ return True
+
+ def get_input_names(self):
+ return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]
+
+ def get_output_names(self):
+ return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]
+
+ def get_dynamic_axes(self):
+ ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
+ ret.update(
+ {
+ "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
+ for d in range(len(self.encoder.encoders))
+ }
+ )
+ ret.update(
+ {
+ "out_cache_%d"
+ % d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
+ for d in range(len(self.encoder.encoders))
+ }
+ )
+ return ret
+
+ def get_model_config(self, path):
+ return {
+ "use_lm": True,
+ "model_path": os.path.join(path, f"{self.model_name}.onnx"),
+ "lm_type": "TransformerLM",
+ "odim": self.encoder.encoders[0].size,
+ "nlayers": len(self.encoder.encoders),
+ }
--
Gitblit v1.9.1