From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/export/models/e2e_asr_paraformer.py |  151 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 149 insertions(+), 2 deletions(-)

diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 52ad320..5697b77 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -4,7 +4,7 @@
 
 from funasr.export.utils.torch_function import MakePadMask
 from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.sanm_encoder import SANMEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
 from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
 from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
@@ -15,6 +15,7 @@
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
 from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
+from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export
 
 
 class Paraformer(nn.Module):
@@ -216,4 +217,150 @@
                 0: 'batch_size',
                 1: 'alphas_length'
             },
-        }
\ No newline at end of file
+        }
+
+
+class ParaformerOnline_encoder_predictor(nn.Module):
+    """
+    Author: Speech Lab, Alibaba Group, China
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+    
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='model',
+        **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+        if isinstance(model.encoder, SANMEncoder) or isinstance(model.encoder, SANMEncoderChunkOpt):
+            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+        elif isinstance(model.encoder, ConformerEncoder):
+            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+        if isinstance(model.predictor, CifPredictorV2):
+            self.predictor = CifPredictorV2_export(model.predictor)
+        
+        self.feats_dim = feats_dim
+        self.model_name = model_name
+        
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+    
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
+        # batch = to_device(batch, device=self.device)
+        
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        alphas, _ = self.predictor.forward_cnn(enc, mask)
+        
+        return enc, enc_len, alphas
+    
+    def get_dummy_inputs(self):
+        speech = torch.randn(2, 30, self.feats_dim)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        return (speech, speech_lengths)
+    
+    def get_input_names(self):
+        return ['speech', 'speech_lengths']
+    
+    def get_output_names(self):
+        return ['enc', 'enc_len', 'alphas']
+    
+    def get_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'enc': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'enc_len': {
+                0: 'batch_size',
+            },
+            'alphas': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+        }
+
+
+class ParaformerOnline_decoder(nn.Module):
+    """
+    Author: Speech Lab, Alibaba Group, China
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+    
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='model',
+        **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+
+        if isinstance(model.decoder, ParaformerDecoderSAN):
+            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
+        elif isinstance(model.decoder, ParaformerSANMDecoder):
+            self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
+        
+        self.feats_dim = feats_dim
+        self.model_name = model_name
+        self.enc_size = model.encoder._output_size
+        
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+    
+    def forward(
+        self,
+        enc: torch.Tensor,
+        enc_len: torch.Tensor,
+        acoustic_embeds: torch.Tensor,
+        acoustic_embeds_len: torch.Tensor,
+        *args,
+    ):
+        decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
+        sample_ids = decoder_out.argmax(dim=-1)
+        
+        return decoder_out, sample_ids, out_caches
+    
+    def get_dummy_inputs(self, ):
+        dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.enc_size)
+        return dummy_inputs
+
+    def get_input_names(self):
+        
+        return self.decoder.get_input_names()
+
+    def get_output_names(self):
+        
+        return self.decoder.get_output_names()
+
+    def get_dynamic_axes(self):
+        return self.decoder.get_dynamic_axes()

--
Gitblit v1.9.1