From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Sep 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/export/models/e2e_asr_paraformer.py | 151 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 149 insertions(+), 2 deletions(-)
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 52ad320..5697b77 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -4,7 +4,7 @@
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.sanm_encoder import SANMEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
@@ -15,6 +15,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
+from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export
class Paraformer(nn.Module):
@@ -216,4 +217,150 @@
0: 'batch_size',
1: 'alphas_length'
},
- }
\ No newline at end of file
+ }
+
+
+class ParaformerOnline_encoder_predictor(nn.Module):
+ """
+ Author: Speech Lab, Alibaba Group, China
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
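+        # Pick the export-friendly wrapper matching the trained module.
+        # SANMEncoderChunkOpt is the chunk-wise (streaming) SANM encoder used by the
+        # online model; it is wrapped with the same SANMEncoder_export class.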
+ if isinstance(model.encoder, SANMEncoder) or isinstance(model.encoder, SANMEncoderChunkOpt):
+ self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+ elif isinstance(model.encoder, ConformerEncoder):
+ self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+ if isinstance(model.predictor, CifPredictorV2):
+ self.predictor = CifPredictorV2_export(model.predictor)
+
+ self.feats_dim = feats_dim
+ self.model_name = model_name
+
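+        # MakePadMask is the ONNX-traceable mask builder; sequence_mask is the plain
+        # PyTorch variant (presumably used for the non-ONNX export path).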
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ):
+ # a. To device
+ batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
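+        # "online": True presumably selects the encoder's chunk/streaming forward so
+        # the exported graph matches streaming inference.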
+ # batch = to_device(batch, device=self.device)
+
+ enc, enc_len = self.encoder(**batch)
+ mask = self.make_pad_mask(enc_len)[:, None, :]
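+        # forward_cnn returns only the per-frame CIF alphas (firing weights); the
+        # integrate-and-fire step itself presumably runs in the downstream runtime.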
+ alphas, _ = self.predictor.forward_cnn(enc, mask)
+
+ return enc, enc_len, alphas
+
+ def get_dummy_inputs(self):
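+        # Dummy tensors only drive ONNX tracing; the concrete shapes are placeholders
+        # that the dynamic axes declared below relax at runtime.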
+ speech = torch.randn(2, 30, self.feats_dim)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+ def get_input_names(self):
+ return ['speech', 'speech_lengths']
+
+ def get_output_names(self):
+ return ['enc', 'enc_len', 'alphas']
+
+ def get_dynamic_axes(self):
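+        # Mark batch and time axes as dynamic so the exported graph accepts
+        # variable-length inputs.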
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'enc': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'enc_len': {
+ 0: 'batch_size',
+ },
+ 'alphas': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ }
+
+
+class ParaformerOnline_decoder(nn.Module):
+ """
+ Author: Speech Lab, Alibaba Group, China
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
+
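+        # The online decoder uses a dedicated export wrapper which presumably exposes
+        # the per-layer cache tensors as extra ONNX inputs/outputs for streaming.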
+ if isinstance(model.decoder, ParaformerDecoderSAN):
+ self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
+ elif isinstance(model.decoder, ParaformerSANMDecoder):
+ self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
+
+ self.feats_dim = feats_dim
+ self.model_name = model_name
+ self.enc_size = model.encoder._output_size
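+        # encoder output width, used to shape the decoder's dummy export inputs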
+
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ def forward(
+ self,
+ enc: torch.Tensor,
+ enc_len: torch.Tensor,
+ acoustic_embeds: torch.Tensor,
+ acoustic_embeds_len: torch.Tensor,
+ *args,
+ ):
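+        # *args carries the streaming cache tensors (assumption: one per decoder layer);
+        # they are forwarded to the decoder and returned updated as out_caches.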
+ decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
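+        # greedy selection: argmax over the vocabulary dimension yields the token ids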
+ sample_ids = decoder_out.argmax(dim=-1)
+
+ return decoder_out, sample_ids, out_caches
+
+    def get_dummy_inputs(self):
+ dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.enc_size)
+ return dummy_inputs
+
+    def get_input_names(self):
+ return self.decoder.get_input_names()
+
+    def get_output_names(self):
+ return self.decoder.get_output_names()
+
+ def get_dynamic_axes(self):
+ return self.decoder.get_dynamic_axes()
--
Gitblit v1.9.1