From f2d8ded57f6403696001d39dd07a1396e5a03658 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 01:24:43 +0800
Subject: [PATCH] export onnx (#1455)
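
Add ONNX export wrappers for the Paraformer SANM decoder:
DecoderLayerSANMExport re-expresses a trained DecoderLayerSANM with an
export-friendly forward pass, and ParaformerSANMDecoderExport wraps the
full decoder stack, swapping in export variants of the attention modules
and exposing the dummy-input, input/output-name, and dynamic-axes helpers
that torch.onnx.export consumes.

A minimal export sketch (illustrative, not part of this patch; it assumes
a trained model object whose `decoder` attribute is a ParaformerSANMDecoder,
and an encoder output size of 512):

    import torch

    dec = ParaformerSANMDecoderExport(model.decoder, onnx=True)
    dec.eval()
    dummy_inputs = dec.get_dummy_inputs(enc_size=512)
    torch.onnx.export(
        dec,
        dummy_inputs,
        "decoder.onnx",
        input_names=dec.get_input_names(),
        output_names=dec.get_output_names(),
        dynamic_axes=dec.get_dynamic_axes(),
        opset_version=13,
    )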
---
funasr/models/paraformer/decoder.py | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 171 insertions(+), 0 deletions(-)
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index ad321e4..ce018f4 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -581,6 +581,177 @@
return y, new_cache
+class DecoderLayerSANMExport(torch.nn.Module):
+
+    def __init__(self, model):
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.src_attn = model.src_attn
+ self.feed_forward = model.feed_forward
+ self.norm1 = model.norm1
+ self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
+ self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
+ self.size = model.size
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
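+        # SANM (FSMN-style) self-attention with a streaming cache; layers
+        # built without a self_attn module skip this block.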
+ x = tgt
+ if self.self_attn is not None:
+ tgt = self.norm2(tgt)
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + x
+
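+        # Cross-attention over the encoder memory, when the layer has one.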
+ if self.src_attn is not None:
+ residual = x
+ x = self.norm3(x)
+ x = residual + self.src_attn(x, memory, memory_mask)
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+
+@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
+class ParaformerSANMDecoderExport(torch.nn.Module):
+    def __init__(self, model, max_seq_len=512, model_name='decoder', onnx: bool = True):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+ from funasr.utils.torch_function import MakePadMask
+ from funasr.utils.torch_function import sequence_mask
+
+ self.model = model
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
+ from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport
+
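+        # Swap each trained attention module for its export-friendly variant,
+        # then rewrap every layer so tracing sees the export forward pass.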
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+ d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+ self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+ if self.model.decoders2 is not None:
+ for i, d in enumerate(self.model.decoders2):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+ for i, d in enumerate(self.model.decoders3):
+ self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
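+        # Turn a (B, T) pad mask into a multiplicative (B, T, 1) feature mask
+        # and an additive (B, 1, 1, T) attention bias (-10000.0 on padding).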
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
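+        # ys_in_pad carries the predictor's acoustic embeddings (B, T, D);
+        # they enter the decoder stack directly, with no token embedding.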
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.model.decoders2 is not None:
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+ x, tgt_mask, memory, memory_mask
+ )
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, ys_in_lens
+
+    def get_dummy_inputs(self, enc_size):
+        # Dummy inputs matching forward(hs_pad, hlens, ys_in_pad, ys_in_lens):
+        # a 100-frame encoder output and a 10-step acoustic-embedding sequence.
+        hs_pad = torch.randn(1, 100, enc_size)
+        hlens = torch.tensor([100], dtype=torch.int32)
+        ys_in_pad = torch.randn(1, 10, enc_size)
+        ys_in_lens = torch.tensor([10], dtype=torch.int32)
+        return (hs_pad, hlens, ys_in_pad, ys_in_lens)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        return ['hs_pad', 'hlens', 'ys_in_pad', 'ys_in_lens']
+
+    def get_output_names(self):
+        return ['logits', 'token_num']
+
+    def get_dynamic_axes(self):
+        return {
+            'hs_pad': {0: 'batch_size', 1: 'feats_length'},
+            'hlens': {0: 'batch_size'},
+            'ys_in_pad': {0: 'batch_size', 1: 'token_length'},
+            'ys_in_lens': {0: 'batch_size'},
+            'logits': {0: 'batch_size', 1: 'token_length'},
+            'token_num': {0: 'batch_size'},
+        }
+
+
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
--
Gitblit v1.9.1