From 57f2a51f9ae2c7c9951f137f3d247cff47100944 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 27 二月 2023 16:55:06 +0800
Subject: [PATCH] onnx supports tiny and bicif paraformer

---
 funasr/export/models/e2e_asr_paraformer.py                            |  123 +++++++++
 funasr/export/models/predictor/cif.py                                 |  115 ++++++++
 funasr/export/models/modules/decoder_layer.py                         |   27 ++
 funasr/export/models/modules/encoder_layer.py                         |   54 ++++
 funasr/runtime/python/onnxruntime/demo.py                             |    6 
 funasr/utils/timestamp_tools.py                                       |    1 
 funasr/export/models/modules/multihead_att.py                         |  108 ++++++++
 funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py        |    1 
 funasr/export/models/decoder/transformer_decoder.py                   |  143 +++++++++++
 funasr/export/models/encoder/conformer_encoder.py                     |  106 ++++++++
 funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py |   63 ++++
 funasr/export/models/__init__.py                                      |   11 
 12 files changed, 742 insertions(+), 16 deletions(-)

diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 27a65af..0012377 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,10 +1,13 @@
-from funasr.models.e2e_asr_paraformer import Paraformer
+from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
+from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
 from funasr.models.e2e_uni_asr import UniASR
 
-def get_model(model, export_config=None):
 
-    if isinstance(model, Paraformer):
+def get_model(model, export_config=None):
+    if isinstance(model, BiCifParaformer):
+        return BiCifParaformer_export(model, **export_config)
+    elif isinstance(model, Paraformer):
         return Paraformer_export(model, **export_config)
     else:
-        raise "The model is not exist!"
\ No newline at end of file
+        raise "Funasr does not support the given model type currently."
\ No newline at end of file
diff --git a/funasr/export/models/decoder/transformer_decoder.py b/funasr/export/models/decoder/transformer_decoder.py
new file mode 100644
index 0000000..d70a3c7
--- /dev/null
+++ b/funasr/export/models/decoder/transformer_decoder.py
@@ -0,0 +1,143 @@
+import os
+from funasr.export import models
+
+import torch
+import torch.nn as nn
+
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
+from funasr.modules.attention import MultiHeadedAttentionCrossAtt, MultiHeadedAttention
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
+from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
+from funasr.export.models.modules.decoder_layer import DecoderLayer as DecoderLayer_export
+
+
+class ParaformerDecoderSAN(nn.Module):
+    def __init__(self, model,
+                 max_seq_len=512,
+                 model_name='decoder',
+                 onnx: bool = True,):
+        super().__init__()
+        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+        self.model = model
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+            # if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+            #     d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
+            if isinstance(d.src_attn, MultiHeadedAttention):
+                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
+            self.model.decoders[i] = DecoderLayer_export(d)
+        
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+        
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+    
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask = self.model.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        x = self.after_norm(x)
+        x = self.output_layer(x)
+
+        return x, ys_in_lens
+
+
+    def get_dummy_inputs(self, enc_size):
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        memory = torch.randn(1, 100, enc_size)
+        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        cache = [
+            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+            for _ in range(cache_num)
+        ]
+        return (tgt, memory, pre_acoustic_embeds, cache)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+               + ['cache_%d' % i for i in range(cache_num)]
+
+    def get_output_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['y'] \
+               + ['out_cache_%d' % i for i in range(cache_num)]
+
+    def get_dynamic_axes(self):
+        ret = {
+            'tgt': {
+                0: 'tgt_batch',
+                1: 'tgt_length'
+            },
+            'memory': {
+                0: 'memory_batch',
+                1: 'memory_length'
+            },
+            'pre_acoustic_embeds': {
+                0: 'acoustic_embeds_batch',
+                1: 'acoustic_embeds_length',
+            }
+        }
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        ret.update({
+            'cache_%d' % d: {
+                0: 'cache_%d_batch' % d,
+                2: 'cache_%d_length' % d
+            }
+            for d in range(cache_num)
+        })
+        return ret
+
+    def get_model_config(self, path):
+        return {
+            "dec_type": "XformerDecoder",
+            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
+            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
+            "odim": self.model.decoders[0].size
+        }
\ No newline at end of file
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 5424a0a..0db61e0 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -1,17 +1,21 @@
 import logging
-
-
 import torch
 import torch.nn as nn
 
 from funasr.export.utils.torch_function import MakePadMask
 from funasr.export.utils.torch_function import sequence_mask
 from funasr.models.encoder.sanm_encoder import SANMEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
 from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.models.predictor.cif import CifPredictorV2
+from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
+from funasr.models.predictor.cif import CifPredictorV2, CifPredictorV3
 from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
+from funasr.export.models.predictor.cif import CifPredictorV3 as CifPredictorV3_export
 from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
+from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
+
 
 class Paraformer(nn.Module):
     """
@@ -34,10 +38,14 @@
             onnx = kwargs["onnx"]
         if isinstance(model.encoder, SANMEncoder):
             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+        elif isinstance(model.encoder, ConformerEncoder):
+            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
         if isinstance(model.predictor, CifPredictorV2):
             self.predictor = CifPredictorV2_export(model.predictor)
         if isinstance(model.decoder, ParaformerSANMDecoder):
             self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
+        elif isinstance(model.decoder, ParaformerDecoderSAN):
+            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
         
         self.feats_dim = feats_dim
         self.model_name = model_name
@@ -99,4 +107,113 @@
                 0: 'batch_size',
                 1: 'logits_length'
             },
+        }
+
+
+class BiCifParaformer(nn.Module):
+    """
+    Author: Speech Lab, Alibaba Group, China
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+            self,
+            model,
+            max_seq_len=512,
+            feats_dim=560,
+            model_name='model',
+            **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+        if isinstance(model.encoder, SANMEncoder):
+            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+        elif isinstance(model.encoder, ConformerEncoder):
+            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+        else:
+            logging.warning("Unsupported encoder type to export.")
+        if isinstance(model.predictor, CifPredictorV3):
+            self.predictor = CifPredictorV3_export(model.predictor)
+        else:
+            logging.warning("Wrong predictor type to export.")
+        if isinstance(model.decoder, ParaformerSANMDecoder):
+            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
+        elif isinstance(model.decoder, ParaformerDecoderSAN):
+            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
+        else:
+            logging.warning("Unsupported decoder type to export.")
+        
+        self.feats_dim = feats_dim
+        self.model_name = model_name
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+        
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        # batch = to_device(batch, device=self.device)
+    
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+        pre_token_length = pre_token_length.round().type(torch.int32)
+
+        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        
+        # get predicted timestamps
+        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+
+        return decoder_out, pre_token_length, us_alphas, us_cif_peak
+
+    def get_dummy_inputs(self):
+        speech = torch.randn(2, 30, self.feats_dim)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        return (speech, speech_lengths)
+
+    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
+        import numpy as np
+        fbank = np.loadtxt(txt_file)
+        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
+        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
+        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
+        return (speech, speech_lengths)
+
+    def get_input_names(self):
+        return ['speech', 'speech_lengths']
+
+    def get_output_names(self):
+        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
+
+    def get_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+            'us_alphas': {
+                0: 'batch_size',
+                1: 'alphas_length'
+            },
+            'us_cif_peak': {
+                0: 'batch_size',
+                1: 'alphas_length'
+            },
         }
\ No newline at end of file
diff --git a/funasr/export/models/encoder/conformer_encoder.py b/funasr/export/models/encoder/conformer_encoder.py
new file mode 100644
index 0000000..9f22574
--- /dev/null
+++ b/funasr/export/models/encoder/conformer_encoder.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.modules.attention import MultiHeadedAttentionSANM
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
+from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
+from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as EncoderLayerConformer_export
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
+from funasr.export.models.encoder.sanm_encoder import SANMEncoder
+from funasr.modules.attention import RelPositionMultiHeadedAttention
+# from funasr.export.models.modules.multihead_att import RelPositionMultiHeadedAttention as RelPositionMultiHeadedAttention_export
+from funasr.export.models.modules.multihead_att import OnnxRelPosMultiHeadedAttention as RelPositionMultiHeadedAttention_export
+
+
+class ConformerEncoder(nn.Module):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='encoder',
+        onnx: bool = True,
+    ):
+        super().__init__()
+        self.embed = model.embed
+        self.model = model
+        self.feats_dim = feats_dim
+        self._output_size = model._output_size
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        for i, d in enumerate(self.model.encoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
+            if isinstance(d.self_attn, RelPositionMultiHeadedAttention):
+                d.self_attn = RelPositionMultiHeadedAttention_export(d.self_attn)
+            if isinstance(d.feed_forward, PositionwiseFeedForward):
+                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
+            self.model.encoders[i] = EncoderLayerConformer_export(d)
+        
+        self.model_name = model_name
+        self.num_heads = model.encoders[0].self_attn.h
+        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+
+    
+    def prepare_mask(self, mask):
+        if len(mask.shape) == 2:
+            mask = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask = 1 - mask[:, None, :]
+        
+        return mask * -10000.0
+
+    def forward(self,
+                speech: torch.Tensor,
+                speech_lengths: torch.Tensor,
+                ):
+        speech = speech * self._output_size ** 0.5
+        mask = self.make_pad_mask(speech_lengths)
+        mask = self.prepare_mask(mask)
+        if self.embed is None:
+            xs_pad = speech
+        else:
+            xs_pad = self.embed(speech)
+
+        encoder_outs = self.model.encoders(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        if isinstance(xs_pad, tuple):
+            xs_pad = xs_pad[0]
+        xs_pad = self.model.after_norm(xs_pad)
+
+        return xs_pad, speech_lengths
+
+    def get_output_size(self):
+        return self.model.encoders[0].size
+
+    def get_dummy_inputs(self):
+        feats = torch.randn(1, 100, self.feats_dim)
+        return (feats)
+
+    def get_input_names(self):
+        return ['feats']
+
+    def get_output_names(self):
+        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
+
+    def get_dynamic_axes(self):
+        return {
+            'feats': {
+                1: 'feats_length'
+            },
+            'encoder_out': {
+                1: 'enc_out_length'
+            },
+            'predictor_weight':{
+                1: 'pre_out_length'
+            }
+
+        }
diff --git a/funasr/export/models/modules/decoder_layer.py b/funasr/export/models/modules/decoder_layer.py
index bc306b1..f539452 100644
--- a/funasr/export/models/modules/decoder_layer.py
+++ b/funasr/export/models/modules/decoder_layer.py
@@ -41,3 +41,30 @@
 
         return x, tgt_mask, memory, memory_mask, cache
 
+
+class DecoderLayer(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.src_attn = model.src_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.norm3 = model.norm3
+    
+    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+        residual = tgt
+        tgt_q = tgt
+        tgt_q_mask = tgt_mask
+        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
+
+        residual = x
+        x = self.norm2(x)
+        
+        x = residual + self.src_attn(x, memory, memory, memory_mask)
+
+        residual = x
+        x = self.norm3(x)
+        x = residual + self.feed_forward(x)
+
+        return x, tgt_mask, memory, memory_mask
diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
index 800a4f7..622b109 100644
--- a/funasr/export/models/modules/encoder_layer.py
+++ b/funasr/export/models/modules/encoder_layer.py
@@ -34,4 +34,58 @@
         return x, mask
 
 
+class EncoderLayerConformer(nn.Module):
+    def __init__(
+        self,
+        model,
+    ):
+        """Construct an EncoderLayer object."""
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.feed_forward = model.feed_forward
+        self.feed_forward_macaron = model.feed_forward_macaron
+        self.conv_module = model.conv_module
+        self.norm_ff = model.norm_ff
+        self.norm_mha = model.norm_mha
+        self.norm_ff_macaron = model.norm_ff_macaron
+        self.norm_conv = model.norm_conv
+        self.norm_final = model.norm_final
+        self.size = model.size
 
+    def forward(self, x, mask):
+        if isinstance(x, tuple):
+            x, pos_emb = x[0], x[1]
+        else:
+            x, pos_emb = x, None
+
+        if self.feed_forward_macaron is not None:
+            residual = x
+            x = self.norm_ff_macaron(x)
+            x = residual + self.feed_forward_macaron(x)
+
+        residual = x
+        x = self.norm_mha(x)
+
+        x_q = x
+
+        if pos_emb is not None:
+            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+        else:
+            x_att = self.self_attn(x_q, x, x, mask)
+        x = residual + x_att
+
+        if self.conv_module is not None:
+            residual = x
+            x = self.norm_conv(x)
+            x = residual +  self.conv_module(x)
+
+        residual = x
+        x = self.norm_ff(x)
+        x = residual + self.feed_forward(x)
+
+        x = self.norm_final(x)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask
diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py
index 377b979..7d685f5 100644
--- a/funasr/export/models/modules/multihead_att.py
+++ b/funasr/export/models/modules/multihead_att.py
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 
+
 class MultiHeadedAttentionSANM(nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -32,7 +33,6 @@
         return x.permute(0, 2, 1, 3)
 
     def forward_qkv(self, x):
-
         q_k_v = self.linear_q_k_v(x)
         q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
         q_h = self.transpose_for_scores(q)
@@ -41,7 +41,6 @@
         return q_h, k_h, v_h, v
 
     def forward_fsmn(self, inputs, mask):
-
         # b, t, d = inputs.size()
         # mask = torch.reshape(mask, (b, -1, 1))
         inputs = inputs * mask
@@ -52,7 +51,6 @@
         x = x + inputs
         x = x * mask
         return x
-
 
     def forward_attention(self, value, scores, mask):
         scores = scores + mask
@@ -65,6 +63,7 @@
         context_layer = context_layer.view(new_context_layer_shape)
         return self.linear_out(context_layer)  # (batch, time1, d_model)
 
+
 class MultiHeadedAttentionSANMDecoder(nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -74,7 +73,6 @@
         self.attn = None
 
     def forward(self, inputs, mask, cache=None):
-
         # b, t, d = inputs.size()
         # mask = torch.reshape(mask, (b, -1, 1))
         inputs = inputs * mask
@@ -91,6 +89,7 @@
         x = x + inputs
         x = x * mask
         return x, cache
+
 
 class MultiHeadedAttentionCrossAtt(nn.Module):
     def __init__(self, model):
@@ -133,3 +132,104 @@
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
         return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
+class OnnxMultiHeadedAttention(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_q = model.linear_q
+        self.linear_k = model.linear_k
+        self.linear_v = model.linear_v
+        self.linear_out = model.linear_out
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+    
+    def forward(self, query, key, value, mask):
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, query, key, value):
+        q = self.linear_q(query)
+        k = self.linear_k(key)
+        v = self.linear_v(value)
+        q = self.transpose_for_scores(q)
+        k = self.transpose_for_scores(k)
+        v = self.transpose_for_scores(v)
+        return q, k, v
+    
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
+class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
+    def __init__(self, model):
+        super().__init__(model)
+        self.linear_pos = model.linear_pos
+        self.pos_bias_u = model.pos_bias_u
+        self.pos_bias_v = model.pos_bias_v
+    
+    def forward(self, query, key, value, pos_emb, mask):
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time1)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k
+        )  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
+
+    def rel_shift(self, x):
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)[
+            :, :, :, : x.size(-1) // 2 + 1
+        ]  # only keep the positions from 0 to time2
+        return x
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
+        
\ No newline at end of file
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 6f4601d..5ea4a34 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+
 import torch
 from torch import nn
-import logging
-import numpy as np
 
 
 def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
@@ -175,3 +174,115 @@
 			max_label_len = frame_len
 	frame_fires = frame_fires[:, :max_label_len, :]
 	return frame_fires, fires
+
+
+class CifPredictorV3(nn.Module):
+	def __init__(self, model):
+		super().__init__()
+		
+		self.pad = model.pad
+		self.cif_conv1d = model.cif_conv1d
+		self.cif_output = model.cif_output
+		self.threshold = model.threshold
+		self.smooth_factor = model.smooth_factor
+		self.noise_threshold = model.noise_threshold
+		self.tail_threshold = model.tail_threshold
+
+		self.upsample_times = model.upsample_times
+		self.upsample_cnn = model.upsample_cnn
+		self.blstm = model.blstm
+		self.cif_output2 = model.cif_output2
+		self.smooth_factor2 = model.smooth_factor2
+		self.noise_threshold2 = model.noise_threshold2
+	
+	def forward(self, hidden: torch.Tensor,
+	            mask: torch.Tensor,
+	            ):
+		h = hidden
+		context = h.transpose(1, 2)
+		queries = self.pad(context)
+		output = torch.relu(self.cif_conv1d(queries))
+		output = output.transpose(1, 2)
+		
+		output = self.cif_output(output)
+		alphas = torch.sigmoid(output)
+		alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+		mask = mask.transpose(-1, -2).float()
+		alphas = alphas * mask
+		alphas = alphas.squeeze(-1)
+		token_num = alphas.sum(-1)
+		
+		mask = mask.squeeze(-1)
+		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+		
+		return acoustic_embeds, token_num, alphas, cif_peak
+	
+	def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
+		h = hidden
+		b = hidden.shape[0]
+		context = h.transpose(1, 2)
+
+		# generate alphas2
+		_output = context
+		output2 = self.upsample_cnn(_output)
+		output2 = output2.transpose(1, 2)
+		output2, (_, _) = self.blstm(output2)
+		alphas2 = torch.sigmoid(self.cif_output2(output2))
+		alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
+		
+		mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
+		mask = mask.unsqueeze(-1)
+		alphas2 = alphas2 * mask
+		alphas2 = alphas2.squeeze(-1)
+		_token_num = alphas2.sum(-1)
+		alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
+		# upsampled alphas and cif_peak
+		us_alphas = alphas2
+		us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
+		return us_alphas, us_cif_peak
+
+	def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
+		b, t, d = hidden.size()
+		tail_threshold = self.tail_threshold
+		
+		zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+		ones_t = torch.ones_like(zeros_t)
+
+		mask_1 = torch.cat([mask, zeros_t], dim=1)
+		mask_2 = torch.cat([ones_t, mask], dim=1)
+		mask = mask_2 - mask_1
+		tail_threshold = mask * tail_threshold
+		alphas = torch.cat([alphas, zeros_t], dim=1)
+		alphas = torch.add(alphas, tail_threshold)
+
+		zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+		hidden = torch.cat([hidden, zeros], dim=1)
+		token_num = alphas.sum(dim=-1)
+		token_num_floor = torch.floor(token_num)
+		
+		return hidden, alphas, token_num_floor
+
+
+@torch.jit.script
+def cif_wo_hidden(alphas, threshold: float):
+    batch_size, len_time = alphas.size()
+
+    # loop varss
+    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
+    # intermediate vars along time
+    list_fires = []
+
+    for t in range(len_time):
+        alpha = alphas[:, t]
+
+        integrate += alpha
+        list_fires.append(integrate)
+
+        fire_place = integrate >= threshold
+        integrate = torch.where(fire_place,
+                                integrate - torch.ones([batch_size], device=alphas.device),
+                                integrate)
+
+    fires = torch.stack(list_fires, 1)
+    return fires
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/demo.py b/funasr/runtime/python/onnxruntime/demo.py
index 9c7f2f4..b4a03f3 100644
--- a/funasr/runtime/python/onnxruntime/demo.py
+++ b/funasr/runtime/python/onnxruntime/demo.py
@@ -1,8 +1,10 @@
 
 from rapid_paraformer import Paraformer
+from rapid_paraformer import BiCifParaformer
 
-model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model = Paraformer(model_dir, batch_size=1)
+model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+# model = Paraformer(model_dir, batch_size=1)
+model = BiCifParaformer(model_dir, batch_size=1)
 
 wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav']
 
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py
index f1b5c29..64e0a16 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py
+++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py
@@ -2,3 +2,4 @@
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
 from .paraformer_onnx import Paraformer
+from .paraformer_onnx import BiCifParaformer
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
index a786ef0..d77bcf7 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
+++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from typing import List, Union, Tuple
 
+import copy
 import librosa
 import numpy as np
 
@@ -13,6 +14,7 @@
                           read_yaml)
 from .utils.postprocess_utils import sentence_postprocess
 from .utils.frontend import WavFrontend
+from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
 
 logging = get_logger()
 
@@ -134,8 +136,67 @@
 
         # Change integer-ids to tokens
         token = self.converter.ids2tokens(token_int)
-        token = token[:valid_token_num-1]
+        # token = token[:valid_token_num-1]
         texts = sentence_postprocess(token)
         text = texts[0]
         # text = self.tokenizer.tokens2text(token)
         return text
+
+
+class BiCifParaformer(Paraformer):
+    def infer(self, feats: np.ndarray,
+              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        am_scores, token_nums, us_alphas, us_cif_peak = self.ort_infer([feats, feats_len])
+        return am_scores, token_nums, us_alphas, us_cif_peak
+    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
+        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+        waveform_nums = len(waveform_list)
+
+        asr_res = []
+        for beg_idx in range(0, waveform_nums, self.batch_size):
+            res = {}
+            end_idx = min(waveform_nums, beg_idx + self.batch_size)
+            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+            am_scores, valid_token_lens, us_alphas, us_cif_peak = self.infer(feats, feats_len)
+
+            try:
+                am_scores, valid_token_lens, us_alphas, us_cif_peak = self.infer(feats, feats_len)
+            except ONNXRuntimeError:
+                #logging.warning(traceback.format_exc())
+                logging.warning("input wav is silence or noise")
+                preds = ['']
+            else:
+                token = self.decode(am_scores, valid_token_lens)
+                timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(token[0]), log=False)
+                texts = sentence_postprocess(token[0], timestamp)
+                # texts = sentence_postprocess(token[0])
+                text = texts[0]
+            res['text'] = text
+            res['timestamp'] = timestamp
+            asr_res.append(res)
+
+        return asr_res
+
+    def decode_one(self,
+                   am_score: np.ndarray,
+                   valid_token_num: int) -> List[str]:
+        yseq = am_score.argmax(axis=-1)
+        score = am_score.max(axis=-1)
+        score = np.sum(score, axis=-1)
+
+        # pad with mask tokens to ensure compatibility with sos/eos tokens
+        # asr_model.sos:1  asr_model.eos:2
+        yseq = np.array([1] + yseq.tolist() + [2])
+        hyp = Hypothesis(yseq=yseq, score=score)
+
+        # remove sos/eos and get results
+        last_pos = -1
+        token_int = hyp.yseq[1:last_pos].tolist()
+
+        # remove blank symbol id, which is assumed to be 0
+        token_int = list(filter(lambda x: x not in (0, 2), token_int))
+
+        # Change integer-ids to tokens
+        token = self.converter.ids2tokens(token_int)
+        # token = token[:valid_token_num-1]
+        return token
\ No newline at end of file
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index f6a6e98..b82c74a 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -4,6 +4,7 @@
 import numpy as np
 from typing import Any, List, Tuple, Union
 
+
 def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
     if not len(char_list):
         return []

--
Gitblit v1.9.1