From a030ff0f85fd6b1cc2a1d443d2fcfb11ccb1aa8f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 29 Mar 2023 21:15:55 +0800
Subject: [PATCH] export: add ONNX export support for VadRealtimeTransformer
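
Wire the VAD-aware realtime punctuation model (VadRealtimeTransformer)
into the ONNX export path: register it in
funasr.export.models.get_model(), add an exportable SANMVadEncoder
wrapper, and delete the dead duplicate TargetDelayTransformer class from
target_delay_transformer.py.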
---
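Reviewer note (this section is ignored by `git am`): the sketch below shows
how the new export path is exercised end to end. Only get_model() and the
get_*() helpers are defined by this patch; load_punc_model() is a
hypothetical loader stub, and the output file name and opset version are
illustrative assumptions, not values fixed by the patch.

    import torch
    from funasr.export.models import get_model

    model = load_punc_model()  # hypothetical: yields an ESPnetPunctuationModel
    export_model = get_model(model, export_config={"onnx": True})
    torch.onnx.export(
        export_model,
        export_model.get_dummy_inputs(),   # (input, text_lengths, vad_indexes)
        "punc_model.onnx",                 # illustrative output path
        input_names=export_model.get_input_names(),
        output_names=export_model.get_output_names(),
        dynamic_axes=export_model.get_dynamic_axes(),
        opset_version=13,                  # assumed opset
    )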
 funasr/export/models/vad_realtime_transformer.py  |  83 ++++++++++++++
 funasr/export/models/encoder/sanm_encoder.py      | 104 ++++++++++++++++
 funasr/export/models/target_delay_transformer.py  |  66 ------------
 funasr/export/models/__init__.py                  |   4 ++
 4 files changed, 191 insertions(+), 66 deletions(-)
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index a341338..62ee723 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -6,6 +6,8 @@
 from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
 from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
 from funasr.punctuation.espnet_model import ESPnetPunctuationModel
+from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.export.models.vad_realtime_transformer import VadRealtimeTransformer as VadRealtimeTransformer_export
 
 def get_model(model, export_config=None):
     if isinstance(model, BiCifParaformer):
@@ -17,5 +19,7 @@
     elif isinstance(model, ESPnetPunctuationModel):
         if isinstance(model.punc_model, TargetDelayTransformer):
             return TargetDelayTransformer_export(model.punc_model, **export_config)
+        elif isinstance(model.punc_model, VadRealtimeTransformer):
+            return VadRealtimeTransformer_export(model.punc_model, **export_config)
         else:
             raise "Funasr does not support the given model type currently."
diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py
index 8a50538..3b7b414 100644
--- a/funasr/export/models/encoder/sanm_encoder.py
+++ b/funasr/export/models/encoder/sanm_encoder.py
@@ -107,3 +107,107 @@
             }
         }
+
+
+class SANMVadEncoder(nn.Module):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='encoder',
+        onnx: bool = True,
+    ):
+        super().__init__()
+        self.embed = model.embed
+        self.model = model
+        self.feats_dim = feats_dim
+        self._output_size = model._output_size
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        # Swap attention / feed-forward modules for their ONNX-exportable
+        # counterparts, mirroring the SANMEncoder export wrapper above.
+        if hasattr(model, 'encoders0'):
+            for i, d in enumerate(self.model.encoders0):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                    d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
+                if isinstance(d.feed_forward, PositionwiseFeedForward):
+                    d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
+                self.model.encoders0[i] = EncoderLayerSANM_export(d)
+
+        for i, d in enumerate(self.model.encoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
+            if isinstance(d.feed_forward, PositionwiseFeedForward):
+                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
+            self.model.encoders[i] = EncoderLayerSANM_export(d)
+
+        self.model_name = model_name
+        self.num_heads = model.encoders[0].self_attn.h
+        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+
+    def prepare_mask(self, mask):
+        # Build a (B, T, 1) multiplicative pad mask and an additive
+        # (large negative) attention bias from the (B, T) pad mask.
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(self,
+                speech: torch.Tensor,
+                speech_lengths: torch.Tensor,
+                vad_indexes: torch.Tensor = None,
+                ):
+        # NOTE: vad_indexes is accepted so the signature matches the call in
+        # VadRealtimeTransformer.forward(); this patch does not specify how it
+        # should condition the attention masks, so it is unused here.
+        speech = speech * self._output_size ** 0.5
+        mask = self.make_pad_mask(speech_lengths)
+        mask = self.prepare_mask(mask)
+        if self.embed is None:
+            xs_pad = speech
+        else:
+            xs_pad = self.embed(speech)
+
+        if hasattr(self.model, 'encoders0'):
+            encoder_outs = self.model.encoders0(xs_pad, mask)
+            xs_pad, mask = encoder_outs[0], encoder_outs[1]
+
+        encoder_outs = self.model.encoders(xs_pad, mask)
+        xs_pad = encoder_outs[0]
+
+        xs_pad = self.model.after_norm(xs_pad)
+        return xs_pad, speech_lengths
+
+    def get_output_size(self):
+        return self.model.encoders[0].size
+
+    def get_dummy_inputs(self):
+        feats = torch.randn(1, 100, self.feats_dim)
+        feats_lengths = torch.tensor([100], dtype=torch.int32)
+        return (feats, feats_lengths)
+
+    def get_input_names(self):
+        return ['feats', 'feats_lengths']
+
+    def get_output_names(self):
+        # forward() returns two tensors, so only two outputs are exposed.
+        return ['encoder_out', 'encoder_out_lens']
+
+    def get_dynamic_axes(self):
+        return {
+            'feats': {
+                1: 'feats_length'
+            },
+            'encoder_out': {
+                1: 'enc_out_length'
+            },
+        }
diff --git a/funasr/export/models/target_delay_transformer.py b/funasr/export/models/target_delay_transformer.py
index 0a2586c..fd90835 100644
--- a/funasr/export/models/target_delay_transformer.py
+++ b/funasr/export/models/target_delay_transformer.py
@@ -28,7 +28,6 @@
             onnx = kwargs["onnx"]
         self.embed = model.embed
         self.decoder = model.decoder
-        self.model = model
         self.feats_dim = self.embed.embedding_dim
         self.num_embeddings = self.embed.num_embeddings
         self.model_name = model_name
@@ -46,71 +46,6 @@
 from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
 from funasr.punctuation.abs_model import AbsPunctuation
 
-class TargetDelayTransformer(nn.Module):
-
-    def __init__(
-        self,
-        model,
-        max_seq_len=512,
-        model_name='punc_model',
-        **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-        self.embed = model.embed
-        self.decoder = model.decoder
-        self.model = model
-        self.feats_dim = self.embed.embedding_dim
-        self.num_embeddings = self.embed.num_embeddings
-        self.model_name = model_name
-
-        if isinstance(model.encoder, SANMEncoder):
-            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
-        else:
-            assert False, "Only support samn encode."
-
-    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-
-        """
-        x = self.embed(input)
-        # mask = self._target_mask(input)
-        h, _ = self.encoder(x, text_lengths)
-        y = self.decoder(h)
-        return y
-
-    def get_dummy_inputs(self):
-        length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
-        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
-        return (text_indexes, text_lengths)
-
-    def get_input_names(self):
-        return ['input', 'text_lengths']
-
-    def get_output_names(self):
-        return ['logits']
-
-    def get_dynamic_axes(self):
-        return {
-            'input': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'text_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
 
         if isinstance(model.encoder, SANMEncoder):
             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py
new file mode 100644
index 0000000..44583d8
--- /dev/null
+++ b/funasr/export/models/vad_realtime_transformer.py
@@ -0,0 +1,83 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.punctuation.sanm_encoder import SANMVadEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
+
+
+class VadRealtimeTransformer(AbsPunctuation):
+
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        model_name='punc_model',
+        **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+
+        self.embed = model.embed
+        if isinstance(model.encoder, SANMVadEncoder):
+            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
+        else:
+            assert False, "Only the SANM encoder is supported."
+        self.decoder = model.decoder
+        self.model_name = model_name
+
+    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
+                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Compute punctuation logits from padded token id sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            text_lengths (torch.Tensor): Valid lengths. (batch,)
+            vad_indexes (torch.Tensor): VAD split indexes (shape assumed,
+                see get_dummy_inputs).
+
+        """
+        x = self.embed(input)
+        h, _ = self.encoder(x, text_lengths, vad_indexes)
+        y = self.decoder(h)
+        return y
+
+    def with_vad(self):
+        return True
+
+    def get_dummy_inputs(self):
+        length = 120
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
+        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
+        # Assumption: one VAD split index per sample; adjust this dummy
+        # to the format the runtime actually feeds in.
+        vad_indexes = torch.zeros(2, 1, dtype=torch.int32)
+        return (text_indexes, text_lengths, vad_indexes)
+
+    def get_input_names(self):
+        return ['input', 'text_lengths', 'vad_indexes']
+
+    def get_output_names(self):
+        return ['logits']
+
+    def get_dynamic_axes(self):
+        return {
+            'input': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'text_lengths': {
+                0: 'batch_size',
+            },
+            'vad_indexes': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
--
Gitblit v1.9.1