| New file |
| | |
| | | import logging |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.export.utils.torch_function import MakePadMask |
| | | from funasr.export.utils.torch_function import sequence_mask |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder as TransformerDecoder_export |
| | | |
| | | |
class Conformer(nn.Module):
    """Export wrapper for a Conformer model (intended for ONNX export).

    NOTE(review): ``__init__`` wraps *either* the encoder *or* the decoder
    (``if``/``elif``), yet ``forward`` reads both ``self.encoder`` and
    ``self.decoder`` -- confirm the intended control flow.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        output_size=2048,
        model_name='model',
        **kwargs,
    ):
        """Build the export wrapper around ``model``.

        Args:
            model: Source model; its ``encoder`` and ``decoder`` attributes
                are inspected here.
            max_seq_len: Passed to the padding-mask helper.
            feats_dim: Feature dimension used by ``get_dummy_inputs``.
            output_size: Stored on the instance; not read elsewhere in this
                class.
            model_name: Stored on the instance as ``model_name``.
            **kwargs: Only the ``onnx`` key is read (treated as ``False``
                when absent).
        """
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        # NOTE(review): `TransformerDecoder` is not bound by the imports shown
        # for this file (only the alias `TransformerDecoder_export` is), so
        # this branch would raise NameError when reached -- confirm the import.
        elif isinstance(model.decoder, TransformerDecoder):
            self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)

        self.feats_dim = feats_dim
        self.output_size = output_size
        self.model_name = model_name

        # Pick the padding-mask helper according to the `onnx` flag.
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Run the encoder and decoder and return log-probabilities.

        Args:
            speech: Input features; ``get_dummy_inputs`` builds them as
                (batch, time, feats_dim).
            speech_lengths: Per-utterance lengths.

        Returns:
            Tuple ``(decoder_out, olens)`` where ``decoder_out`` has had
            ``log_softmax`` applied over its last dimension.
        """
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        # NOTE(review): `mask` is computed but not used below.
        mask = self.make_pad_mask(enc_len)[:, None, :]

        # fill the decoder input
        enc_size = self.encoder.output_size
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        # NOTE(review): `self.model` is not assigned anywhere in this class,
        # so this lookup is expected to fail -- confirm what was intended.
        cache_num = len(self.model.decoder)
        cache = [
            torch.zeros((1, self.decoder.size, self.decoder.self_attn.kernel_size))
            for _ in range(cache_num)
        ]

        decoder_out, olens = self.decoder(enc, enc_len, pre_acoustic_embeds, cache)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)

        return decoder_out, olens

    def get_dummy_inputs(self):
        """Return random ``(speech, speech_lengths)`` for a batch of two."""
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        """Load features from a text file and return them as a batch of one.

        The file is read with ``numpy.loadtxt``; the number of rows becomes
        the single entry of ``speech_lengths``.
        """
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        """Return the names of the exported graph inputs."""
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        """Return the names of the exported graph outputs."""
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        """Return the axes declared dynamic, keyed by input/output name."""
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
| New file |
| | |
| | | """Positional Encoding Module.""" |
| | | |
| | | import math |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | from funasr.modules.embedding import ( |
| | | LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding, |
| | | ScaledPositionalEncoding, StreamPositionalEncoding) |
| | | from funasr.modules.subsampling import ( |
| | | Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, |
| | | Conv2dSubsampling8) |
| | | from funasr.modules.subsampling_without_posenc import \ |
| | | Conv2dSubsamplingWOPosEnc |
| | | |
| | | from funasr.export.models.language_models.subsampling import ( |
| | | OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6, |
| | | OnnxConv2dSubsampling8) |
| | | |
| | | |
def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
    """Return the export-friendly counterpart of a positional-encoding module.

    Args:
        pos_emb: Positional-encoding module taken from the source model.
        max_seq_len: Maximum sequence length handed to the export wrapper.
        use_cache: Whether the wrapper should use a precomputed table
            (``True``) or compute the encoding on the fly (``False``).

    Returns:
        The matching ``Onnx*PositionalEncoding`` wrapper, or ``pos_emb``
        itself when it is an empty ``nn.Sequential`` or a
        ``Conv2dSubsamplingWOPosEnc``.

    Raises:
        ValueError: If ``pos_emb`` is of an unsupported type.
    """
    # The checks run in this fixed order; presumably some of the source
    # classes derive from PositionalEncoding, so keep the specific ones first.
    wrappers = (
        (LegacyRelPositionalEncoding, OnnxLegacyRelPositionalEncoding),
        (ScaledPositionalEncoding, OnnxScaledPositionalEncoding),
        (RelPositionalEncoding, OnnxRelPositionalEncoding),
        (PositionalEncoding, OnnxPositionalEncoding),
        (StreamPositionalEncoding, OnnxStreamPositionalEncoding),
    )
    for source_cls, export_cls in wrappers:
        if isinstance(pos_emb, source_cls):
            # `use_cache` must go by keyword: the third positional parameter
            # of OnnxPositionalEncoding is `reverse`, not `use_cache`.
            return export_cls(pos_emb, max_seq_len, use_cache=use_cache)

    is_empty_sequential = isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0
    if is_empty_sequential or isinstance(pos_emb, Conv2dSubsamplingWOPosEnc):
        # Nothing to convert: hand the module back untouched.
        return pos_emb

    raise ValueError("Embedding model is not supported.")
| | | |
| | | |
class Embedding(nn.Module):
    """Wrap an input-embedding module so that it can be exported.

    A plain ``nn.Embedding`` is used as is. A ``Conv2dSubsampling*`` module is
    replaced by its ``OnnxConv2dSubsampling*`` counterpart, and the last entry
    of its ``out`` container is converted with ``get_pos_emb``. For any other
    module, its last element (``model[-1]``) is converted with ``get_pos_emb``.

    Args:
        model: Embedding module taken from the source model.
        max_seq_len: Maximum sequence length for the positional encoding.
        use_cache: Forwarded to ``get_pos_emb`` (previously accepted but
            ignored).
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        super().__init__()
        self.model = model
        if isinstance(model, nn.Embedding):
            return

        subsampling_wrappers = (
            (Conv2dSubsampling, OnnxConv2dSubsampling),
            (Conv2dSubsampling2, OnnxConv2dSubsampling2),
            (Conv2dSubsampling6, OnnxConv2dSubsampling6),
            (Conv2dSubsampling8, OnnxConv2dSubsampling8),
        )
        for source_cls, export_cls in subsampling_wrappers:
            if isinstance(model, source_cls):
                self.model = export_cls(model)
                # The wrapper shares `out` with `model`, so this replaces the
                # positional encoding in place.
                self.model.out[-1] = get_pos_emb(
                    model.out[-1], max_seq_len, use_cache=use_cache
                )
                return

        # Any other container: convert its last element.
        self.model[-1] = get_pos_emb(model[-1], max_seq_len, use_cache=use_cache)

    def forward(self, x, mask=None):
        """Apply the wrapped module, passing ``mask`` only when it is given."""
        if mask is None:
            return self.model(x)
        return self.model(x, mask)
| | | |
| | | |
| | | def _pre_hook( |
| | | state_dict, |
| | | prefix, |
| | | local_metadata, |
| | | strict, |
| | | missing_keys, |
| | | unexpected_keys, |
| | | error_msgs, |
| | | ): |
| | | """Perform pre-hook in load_state_dict for backward compatibility. |
| | | |
| | | Note: |
| | | We saved self.pe until v.0.5.2 but we have omitted it later. |
| | | Therefore, we remove the item "pe" from `state_dict` for backward compatibility. |
| | | |
| | | """ |
| | | k = prefix + "pe" |
| | | if k in state_dict: |
| | | state_dict.pop(k) |
| | | |
| | | |
class OnnxPositionalEncoding(torch.nn.Module):
    """Export wrapper around a positional-encoding module.

    Args:
        model: Source positional encoding; its ``d_model`` and ``pe``
            attributes are read, and its ``extend_pe`` may be called.
        max_seq_len (int): Maximum input length covered by the cached table.
        reverse (bool): Whether to reverse the input position. Only for
            the class LegacyRelPositionalEncoding.
        use_cache (bool): Use the stored table ``pe`` (``True``) or compute
            the sinusoids at call time (``False``).
    """

    def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = model.d_model
        self.reverse = reverse
        self.max_seq_len = max_seq_len
        self.xscale = math.sqrt(self.d_model)
        self._register_load_state_dict_pre_hook(_pre_hook)
        self.pe = model.pe
        self.use_cache = use_cache
        self.model = model
        if not self.use_cache:
            # Frequencies are only needed when the table is not cached.
            exponent = torch.arange(0, self.d_model, 2, dtype=torch.float32)
            self.div_term = torch.exp(exponent * -(math.log(10000.0) / self.d_model))
        else:
            self.extend_pe()

    def extend_pe(self):
        """Trim or extend the stored table to ``max_seq_len`` entries."""
        table_len = self.pe.shape[1]
        if self.max_seq_len < table_len:
            self.pe = self.pe[:, : self.max_seq_len]
            return
        # Let the source module rebuild its table, then adopt it.
        self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
        self.pe = self.model.pe

    def _add_pe(self, x):
        """Scale ``x`` and add sinusoids computed from ``div_term``."""
        length = x.size(1)
        if self.reverse:
            position = torch.arange(length - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, length, dtype=torch.float32)
        angles = position.unsqueeze(1) * self.div_term

        # The multiplication yields a new tensor, so the in-place additions
        # below do not touch the caller's input.
        x = x * self.xscale
        x[:, :, 0::2] += torch.sin(angles)
        x[:, :, 1::2] += torch.cos(angles)
        return x

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if not self.use_cache:
            return self._add_pe(x)
        return x * self.xscale + self.pe[:, : x.size(1)]
| | | |
| | | |
class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
    """Scaled positional encoding module (export wrapper).

    See Sec. 3.2 https://arxiv.org/abs/1809.08895

    Computes ``x + alpha * pe`` in both the cached and the non-cached mode.

    Args:
        model: Source positional encoding. Its ``alpha`` attribute, when
            present, initialises this wrapper's ``alpha``.
        max_seq_len (int): Maximum input length.
        use_cache (bool): Use the stored table (``True``) or compute the
            sinusoids at call time (``False``).

    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class."""
        super().__init__(model, max_seq_len, use_cache=use_cache)
        # Carry over the source module's scale so that the wrapper reproduces
        # its output; fall back to 1.0 when the source has no `alpha`.
        source_alpha = getattr(model, "alpha", None)
        if source_alpha is None:
            initial_alpha = torch.tensor(1.0)
        else:
            initial_alpha = torch.as_tensor(source_alpha).detach().clone()
        self.alpha = torch.nn.Parameter(initial_alpha)

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def _add_pe(self, x):
        """Compute the sinusoids at call time and return ``x + alpha * pe``."""
        length = x.size(1)
        if self.reverse:
            position = torch.arange(length - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, length, dtype=torch.float32)
        angles = position.unsqueeze(1) * self.div_term

        pe = torch.zeros_like(x)
        pe[:, :, 0::2] += torch.sin(angles)
        pe[:, :, 1::2] += torch.cos(angles)
        # alpha scales the encoding, not the input -- same as the cached path.
        return x + self.alpha * pe

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        if self.use_cache:
            return x + self.alpha * self.pe[:, : x.size(1)]
        return self._add_pe(x)
| | | |
| | | |
class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
    """Relative positional encoding module (old version), export wrapper.

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        model: Source positional-encoding module.
        max_seq_len (int): Maximum input length.
        use_cache (bool): Use the stored table (``True``) or compute the
            sinusoids at call time (``False``).

    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class (always with ``reverse=True``)."""
        super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)

    def _get_pe(self, x):
        """Build a sinusoid tensor with the same shape as ``x``."""
        length = x.size(1)
        if self.reverse:
            position = torch.arange(length - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, length, dtype=torch.float32)
        angles = position.unsqueeze(1) * self.div_term

        pe = torch.zeros(x.shape)
        pe[:, :, 0::2] += torch.sin(angles)
        pe[:, :, 1::2] += torch.cos(angles)
        return pe

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Input scaled by ``xscale`` (batch, time, `*`).
            torch.Tensor: Positional embedding tensor; a slice of the stored
                table in cached mode, otherwise a tensor shaped like ``x``.

        """
        scaled = x * self.xscale
        if self.use_cache:
            pos_emb = self.pe[:, : scaled.size(1)]
        else:
            pos_emb = self._get_pe(scaled)
        return scaled, pos_emb
| | | |
| | | |
class OnnxRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation), export wrapper.

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        model: Source module; only its ``d_model`` attribute is read.
        max_seq_len (int): Maximum input length.
        use_cache (bool): Precompute the table once (``True``) or build it
            on every call (``False``).
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = model.d_model
        self.xscale = math.sqrt(self.d_model)
        self.pe = None
        self.use_cache = use_cache
        if not self.use_cache:
            exponent = torch.arange(0, self.d_model, 2, dtype=torch.float32)
            self.div_term = torch.exp(exponent * -(math.log(10000.0) / self.d_model))
            return
        # Only the second dimension of this dummy tensor matters.
        self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length * 2 - 1:
            # self.pe contains both positive and negative parts;
            # the length of self.pe is 2 * input_len - 1.
            if not (self.pe.dtype == x.dtype and self.pe.device == x.device):
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        # Suppose `i` is the position of the query vector and `j` the
        # position of the key vector. Positive relative positions are used
        # when keys are to the left (i>j), negative ones otherwise (i<j).
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive = torch.zeros(length, self.d_model)
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative = torch.zeros(length, self.d_model)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of the positive indices and concatenate both
        # halves. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        halves = [
            torch.flip(pe_positive, [0]).unsqueeze(0),
            pe_negative[1:].unsqueeze(0),
        ]
        self.pe = torch.cat(halves, dim=1).to(device=x.device, dtype=x.dtype)

    def _get_pe(self, x):
        """Build the (1, 2 * time - 1, d_model) table for the length of ``x``."""
        length = x.size(1)
        theta = (
            torch.arange(0, length, dtype=torch.float32).unsqueeze(1) * self.div_term
        )
        sin_theta = torch.sin(theta)
        cos_theta = torch.cos(theta)

        pe_positive = torch.zeros(length, self.d_model)
        pe_positive[:, 0::2] = sin_theta
        pe_positive[:, 1::2] = cos_theta
        pe_negative = torch.zeros(length, self.d_model)
        pe_negative[:, 0::2] = -1 * sin_theta
        pe_negative[:, 1::2] = cos_theta

        # Same layout as in extend_pe: flipped positive half, then the
        # negative half without its first (zero-offset) row.
        flipped_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        trimmed_negative = pe_negative[1:].unsqueeze(0)
        return torch.cat([flipped_positive, trimmed_negative], dim=1)

    def forward(self, x: torch.Tensor, use_cache=True):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            use_cache: Unused; the mode is taken from ``self.use_cache``.

        Returns:
            torch.Tensor: Input scaled by ``xscale`` (batch, time, `*`).
            torch.Tensor: Positional embedding (1, 2 * time - 1, `*`).
        """
        x = x * self.xscale
        if not self.use_cache:
            return x, self._get_pe(x)
        # Take the window of width 2 * time - 1 centred on the zero offset.
        center = self.pe.size(1) // 2
        pos_emb = self.pe[:, center - x.size(1) + 1 : center + x.size(1)]
        return x, pos_emb
| | | |
| | | |
class OnnxStreamPositionalEncoding(torch.nn.Module):
    """Streaming positional encoding (export wrapper).

    Args:
        model: Source module; its ``d_model``, ``xscale`` and ``pe``
            attributes are read.
        max_seq_len (int): Length of the cached table.
        use_cache (bool): Use the stored table (``True``) or compute the
            sinusoids at call time (``False``).
    """

    def __init__(self, model, max_seq_len=5000, use_cache=True):
        """Construct an PositionalEncoding object."""
        # Fixed: the original called super() with StreamPositionalEncoding,
        # which is not a base of this class and raises TypeError.
        super().__init__()
        self.d_model = model.d_model
        self.xscale = model.xscale
        self.pe = model.pe
        self.use_cache = use_cache
        self.max_seq_len = max_seq_len
        if self.use_cache:
            self.extend_pe()
        else:
            self.div_term = self._make_div_term()
        self._register_load_state_dict_pre_hook(_pre_hook)

    def _make_div_term(self):
        """Return the sinusoid frequencies for ``d_model`` channels."""
        return torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )

    def _build_pe(self, length):
        """Build a (1, length, d_model) table for positions 0..length-1."""
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        angles = position * self._make_div_term()
        pe = torch.zeros(length, self.d_model)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        return pe.unsqueeze(0)

    def extend_pe(self):
        """Trim or extend the stored table to ``max_seq_len`` entries."""
        pe_length = self.pe.size(1)
        if self.max_seq_len < pe_length:
            self.pe = self.pe[:, : self.max_seq_len]
        elif self.max_seq_len > pe_length:
            # The original delegated to `self.model.extend_pe`, but
            # `self.model` was never assigned. The table is rebuilt here with
            # the same sinusoid formula `_add_pe` uses.
            # NOTE(review): assumes the source table follows that formula.
            self.pe = self._build_pe(self.max_seq_len).to(
                device=self.pe.device, dtype=self.pe.dtype
            )

    def _add_pe(self, x, start_idx):
        """Scale ``x`` and add sinusoids for positions starting at ``start_idx``."""
        # Fixed: the range must cover x.size(1) positions beginning at
        # start_idx (the original stopped at x.size(1)).
        position = torch.arange(
            start_idx, start_idx + x.size(1), dtype=torch.float32
        ).unsqueeze(1)
        angles = position * self.div_term
        # The multiplication yields a new tensor, so the caller's input is
        # left untouched by the in-place additions.
        x = x * self.xscale
        x[:, :, 0::2] += torch.sin(angles)
        x[:, :, 1::2] += torch.cos(angles)
        return x

    def forward(self, x: torch.Tensor, start_idx: int = 0):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            start_idx (int): Position of the first frame of ``x``.

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        if self.use_cache:
            return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
        return self._add_pe(x, start_idx)
| New file |
| | |
| | | import os |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
class SequentialRNNLM(nn.Module):
    """Export wrapper around a sequential RNN language model.

    The wrapper reuses the source model's ``encoder``, ``rnn`` and
    ``decoder`` modules and copies ``rnn_type``, ``nlayers`` and ``nhid``.
    When ``rnn_type`` is ``"LSTM"`` a second hidden state is carried along.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.encoder = model.encoder
        self.rnn = model.rnn
        self.rnn_type = model.rnn_type
        self.decoder = model.decoder
        self.nlayers = model.nlayers
        self.nhid = model.nhid
        self.model_name = "seq_rnnlm"

    def forward(self, y, hidden1, hidden2=None):
        """Score ``y`` and return the scores with the updated hidden state(s).

        Returns:
            ``(scores, hidden1, hidden2)`` for an LSTM, otherwise
            ``(scores, hidden1)``.
        """
        # batch_score function.
        is_lstm = self.rnn_type == "LSTM"
        embedded = self.encoder(y)
        if is_lstm:
            output, (hidden1, hidden2) = self.rnn(embedded, (hidden1, hidden2))
        else:
            output, hidden1 = self.rnn(embedded, hidden1)

        # Flatten the first two dimensions for the decoder, then restore them.
        dim0, dim1, feat = output.size(0), output.size(1), output.size(2)
        decoded = self.decoder(output.contiguous().view(dim0 * dim1, feat))
        scores = decoded.view(dim0, dim1, decoded.size(1))

        if is_lstm:
            return (scores, hidden1, hidden2)
        return (scores, hidden1)

    def get_dummy_inputs(self):
        """Return example inputs: two tokens and random hidden state(s)."""
        tgt = torch.LongTensor([0, 1]).unsqueeze(0)
        hidden = torch.randn(self.nlayers, 1, self.nhid)
        if self.rnn_type != "LSTM":
            return (tgt, hidden)
        # The same tensor is used for both LSTM states.
        return (tgt, hidden, hidden)

    def get_input_names(self):
        """Return the names of the exported graph inputs."""
        names = ["x", "in_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("in_hidden2")
        return names

    def get_output_names(self):
        """Return the names of the exported graph outputs."""
        names = ["y", "out_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("out_hidden2")
        return names

    def get_dynamic_axes(self):
        """Return the axes declared dynamic, keyed by input/output name."""
        axes = {
            "x": {0: "x_batch", 1: "x_length"},
            "y": {0: "y_batch"},
            "in_hidden1": {1: "hidden1_batch"},
            "out_hidden1": {1: "out_hidden1_batch"},
        }
        if self.rnn_type == "LSTM":
            axes["in_hidden2"] = {1: "hidden2_batch"}
            axes["out_hidden2"] = {1: "out_hidden2_batch"}
        return axes

    def get_model_config(self, path):
        """Return the configuration entry describing the exported model."""
        model_path = os.path.join(path, f"{self.model_name}.onnx")
        return {
            "use_lm": True,
            "model_path": model_path,
            "lm_type": "SequentialRNNLM",
            "rnn_type": self.rnn_type,
            "nhid": self.nhid,
            "nlayers": self.nlayers,
        }
| New file |
| | |
| | | """Subsampling layer definition.""" |
| | | |
| | | import torch |
| | | |
| | | |
class OnnxConv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length), export wrapper.

    Args:
        model: Source subsampling module; its ``conv`` and ``out``
            sub-modules are reused.

    """

    def __init__(self, model):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask, or ``None``. The mask is
                sliced along dimension 1.

        Returns:
            torch.Tensor: Output of ``out`` applied to the convolved,
                flattened features.
            torch.Tensor: Mask with ``[:, :-2:2]`` applied twice, or ``None``
                when no mask was given.

        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        n_batch, n_chan, n_time, n_freq = feats.size()
        # Merge channel and frequency axes into one feature axis.
        merged = feats.transpose(1, 2).contiguous().view(n_batch, n_time, n_chan * n_freq)
        x = self.out(merged)
        sub_mask = None if x_mask is None else x_mask[:, :-2:2][:, :-2:2]
        return x, sub_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.

        """
        if key == -1:
            return self.out[key]
        raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
| | | |
| | | |
class OnnxConv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length), export wrapper.

    Args:
        model: Source subsampling module; its ``conv`` and ``out``
            sub-modules are reused.

    """

    def __init__(self, model):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask, or ``None``. The mask is
                sliced along dimension 1.

        Returns:
            torch.Tensor: Output of ``out`` applied to the convolved,
                flattened features.
            torch.Tensor: Mask after ``[:, :-2:2]`` followed by
                ``[:, :-2:1]``, or ``None`` when no mask was given.

        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        n_batch, n_chan, n_time, n_freq = feats.size()
        # Merge channel and frequency axes into one feature axis.
        merged = feats.transpose(1, 2).contiguous().view(n_batch, n_time, n_chan * n_freq)
        x = self.out(merged)
        sub_mask = None if x_mask is None else x_mask[:, :-2:2][:, :-2:1]
        return x, sub_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.

        """
        if key == -1:
            return self.out[key]
        raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
| | | |
| | | |
class OnnxConv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length), export wrapper.

    Args:
        model: Source subsampling module; its ``conv`` and ``out``
            sub-modules are reused.

    """

    def __init__(self, model):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask, or ``None``. The mask is
                sliced along dimension 1.

        Returns:
            torch.Tensor: Output of ``out`` applied to the convolved,
                flattened features.
            torch.Tensor: Mask after ``[:, :-2:2]`` followed by
                ``[:, :-4:3]``, or ``None`` when no mask was given.

        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        n_batch, n_chan, n_time, n_freq = feats.size()
        # Merge channel and frequency axes into one feature axis.
        merged = feats.transpose(1, 2).contiguous().view(n_batch, n_time, n_chan * n_freq)
        x = self.out(merged)
        sub_mask = None if x_mask is None else x_mask[:, :-2:2][:, :-4:3]
        return x, sub_mask
| | | |
| | | |
class OnnxConv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length), export wrapper.

    Args:
        model: Source subsampling module; its ``conv`` and ``out``
            sub-modules are reused.

    """

    def __init__(self, model):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask, or ``None``. The mask is
                sliced along dimension 1.

        Returns:
            torch.Tensor: Output of ``out`` applied to the convolved,
                flattened features.
            torch.Tensor: Mask with ``[:, :-2:2]`` applied three times, or
                ``None`` when no mask was given.

        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        n_batch, n_chan, n_time, n_freq = feats.size()
        # Merge channel and frequency axes into one feature axis.
        merged = feats.transpose(1, 2).contiguous().view(n_batch, n_time, n_chan * n_freq)
        x = self.out(merged)
        sub_mask = None if x_mask is None else x_mask[:, :-2:2][:, :-2:2][:, :-2:2]
        return x, sub_mask
| New file |
| | |
import os

import torch
import torch.nn as nn
from funasr.modules.vgg2l import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
    Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)

from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention

from funasr.export.utils.torch_function import MakePadMask
| | | |
class TransformerLM(nn.Module, AbsExportModel):
    """Export wrapper around a Transformer language model.

    The wrapper converts the embedding with ``Embedding``, reuses the source
    ``encoder`` and ``decoder``, and replaces every encoder layer with an
    ``OnnxEncoderLayer`` (swapping ``MultiHeadedAttention`` for
    ``OnnxMultiHeadedAttention`` first).

    NOTE(review): ``AbsExportModel`` is not bound by the import block shown
    for this file -- confirm where it should be imported from.
    """

    def __init__(self, model, max_seq_len=512, **kwargs):
        super().__init__()
        self.embed = Embedding(model.embed, max_seq_len)
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        # replace multihead attention module into customized module.
        for idx, layer in enumerate(self.encoder.encoders):
            # layer is EncoderLayer
            if isinstance(layer.self_attn, MultiHeadedAttention):
                layer.self_attn = OnnxMultiHeadedAttention(layer.self_attn)
            self.encoder.encoders[idx] = OnnxEncoderLayer(layer)

        self.model_name = "transformer_lm"
        first_attn = self.encoder.encoders[0].self_attn
        self.num_heads = first_attn.h
        self.hidden_size = first_attn.linear_out.out_features

    def prepare_mask(self, mask):
        """Expand a 2-D or 3-D mask and turn it into an additive mask.

        Entries equal to 1 become 0 and entries equal to 0 become -10000.
        """
        rank = len(mask.shape)
        if rank == 2:
            mask = mask[:, None, None, :]
        elif rank == 3:
            mask = mask[:, None, :]
        return (1 - mask) * -10000.0

    def forward(self, y, cache):
        """Run one step and return the decoder output with the new cache.

        Args:
            y: Token ids; positions equal to 0 are masked out.
            cache: One entry per encoder layer, passed to that layer.

        Returns:
            ``(h, new_cache)`` where ``h`` is the decoder applied to the last
            position and ``new_cache`` holds each layer's output.
        """
        feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
        mask = self.make_pad_mask(feats_length)  # (B, T)
        mask = (y != 0) * mask

        xs = self.embed(y)
        # forward_one_step of Encoder
        subsampling_types = (
            Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L,
        )
        if isinstance(self.encoder.embed, subsampling_types):
            xs, mask = self.encoder.embed(xs, mask)
        else:
            xs = self.encoder.embed(xs)

        mask = self.prepare_mask(mask)
        new_cache = []
        for layer_cache, layer in zip(cache, self.encoder.encoders):
            xs, mask = layer(xs, mask, layer_cache)
            new_cache.append(xs)

        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)

        last_step = xs[:, -1]
        return self.decoder(last_step), new_cache

    def get_dummy_inputs(self):
        """Return example inputs: one token and a zero cache per layer."""
        tgt = torch.LongTensor([1]).unsqueeze(0)
        layers = self.encoder.encoders
        cache = [torch.zeros((1, 1, layers[0].size)) for _ in range(len(layers))]
        return (tgt, cache)

    def is_optimizable(self):
        """Return ``True``."""
        return True

    def get_input_names(self):
        """Return the names of the exported graph inputs."""
        n_layers = len(self.encoder.encoders)
        names = ["tgt"]
        names.extend("cache_%d" % i for i in range(n_layers))
        return names

    def get_output_names(self):
        """Return the names of the exported graph outputs."""
        n_layers = len(self.encoder.encoders)
        names = ["y"]
        names.extend("out_cache_%d" % i for i in range(n_layers))
        return names

    def get_dynamic_axes(self):
        """Return the axes declared dynamic, keyed by input/output name."""
        n_layers = len(self.encoder.encoders)
        axes = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
        # Input caches first, then output caches (insertion order is kept).
        for d in range(n_layers):
            axes["cache_%d" % d] = {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
        for d in range(n_layers):
            axes["out_cache_%d" % d] = {
                0: "out_cache_%d_batch" % d,
                1: "out_cache_%d_length" % d,
            }
        return axes

    def get_model_config(self, path):
        """Return the configuration entry describing the exported model."""
        layers = self.encoder.encoders
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "TransformerLM",
            "odim": layers[0].size,
            "nlayers": len(layers),
        }