From f2d8ded57f6403696001d39dd07a1396e5a03658 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 01:24:43 +0800
Subject: [PATCH] export onnx (#1455)
---
funasr/models/sanm/encoder.py | 132 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 132 insertions(+), 0 deletions(-)
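
Usage sketch (notes only, not part of the commit): one way the wrapper's helper
methods can be wired into torch.onnx.export. The `sanm_encoder` variable, the
output file name "sanm_encoder.onnx", the extra input name 'feats_lengths', and
opset 14 are illustrative assumptions, not defined by this patch.

    import torch
    from funasr.models.sanm.encoder import SANMEncoderExport

    # `sanm_encoder` is assumed to be an already-built, trained SANMEncoder.
    export_model = SANMEncoderExport(sanm_encoder, feats_dim=560, onnx=True).eval()

    # forward() takes (speech, speech_lengths); get_dummy_inputs() only covers the
    # features, so a matching length tensor is built by hand here.
    dummy_speech = export_model.get_dummy_inputs()               # (1, 100, feats_dim)
    dummy_lengths = torch.tensor([dummy_speech.shape[1]], dtype=torch.int32)

    # forward() returns (encoder_out, encoder_out_lens); the 'predictor_weight'
    # entries from the getters are dropped because this wrapper does not emit it.
    output_names = export_model.get_output_names()[:2]
    dynamic_axes = {k: v for k, v in export_model.get_dynamic_axes().items()
                    if k == 'feats' or k in output_names}

    with torch.no_grad():
        torch.onnx.export(
            export_model,
            (dummy_speech, dummy_lengths),
            "sanm_encoder.onnx",
            input_names=export_model.get_input_names() + ['feats_lengths'],
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=14,
        )
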
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index 069c527..561179b 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -456,3 +456,135 @@
return (xs_pad, intermediate_outs), None, None
return xs_pad, ilens, None
+class EncoderLayerSANMExport(nn.Module):
+    def __init__(
+        self,
+        model,
+    ):
+        """Construct an EncoderLayer object."""
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.in_size = model.in_size
+        self.size = model.size
+
+    def forward(self, x, mask):
+
+        residual = x
+        x = self.norm1(x)
+        x = self.self_attn(x, mask)
+        if self.in_size == self.size:
+            x = x + residual
+        residual = x
+        x = self.norm2(x)
+        x = self.feed_forward(x)
+        x = x + residual
+
+        return x, mask
+
+
+@tables.register("encoder_classes", "SANMEncoderExport")
+class SANMEncoderExport(nn.Module):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='encoder',
+        onnx: bool = True,
+    ):
+        super().__init__()
+        self.embed = model.embed
+        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
+            self.embed = None
+        self.model = model
+        self.feats_dim = feats_dim
+        self._output_size = model._output_size
+
+        from funasr.utils.torch_function import MakePadMask
+        from funasr.utils.torch_function import sequence_mask
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
+        if hasattr(model, 'encoders0'):
+            for i, d in enumerate(self.model.encoders0):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+                    self.model.encoders0[i] = EncoderLayerSANMExport(d)
+
+        for i, d in enumerate(self.model.encoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+                self.model.encoders[i] = EncoderLayerSANMExport(d)
+
+        self.model_name = model_name
+        self.num_heads = model.encoders[0].self_attn.h
+        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0  # additive attention bias: large negative at padded positions
+
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(self,
+                speech: torch.Tensor,
+                speech_lengths: torch.Tensor,
+                online: bool = False
+                ):
+        if not online:
+            speech = speech * self._output_size ** 0.5  # scale inputs by sqrt(output size) on the offline path
+        mask = self.make_pad_mask(speech_lengths)
+        mask = self.prepare_mask(mask)
+        if self.embed is None:
+            xs_pad = speech
+        else:
+            xs_pad = self.embed(speech)
+
+        encoder_outs = self.model.encoders0(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        encoder_outs = self.model.encoders(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        xs_pad = self.model.after_norm(xs_pad)
+
+        return xs_pad, speech_lengths
+
+    def get_output_size(self):
+        return self.model.encoders[0].size
+
+    def get_dummy_inputs(self):
+        feats = torch.randn(1, 100, self.feats_dim)
+        return feats
+
+    def get_input_names(self):
+        return ['feats']
+
+    def get_output_names(self):
+        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
+
+    def get_dynamic_axes(self):
+        return {
+            'feats': {
+                1: 'feats_length'
+            },
+            'encoder_out': {
+                1: 'enc_out_length'
+            },
+            'predictor_weight': {
+                1: 'pre_out_length'
+            }
+
+        }
+
--
Gitblit v1.9.1