From ee9569ceef0c9707c8877d6b65733621dfbd3aeb Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期二, 15 八月 2023 17:31:27 +0800
Subject: [PATCH] Contextual Paraformer onnx export

---
 funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py    |   47 ++++
 funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py |  144 +++++++++++++
 funasr/export/models/decoder/contextual_decoder.py              |  191 +++++++++++++++++
 funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py       |    2 
 funasr/export/models/e2e_asr_contextual_paraformer.py           |  174 +++++++++++++++
 funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py |   11 +
 funasr/export/export_model.py                                   |   52 ++--
 funasr/export/models/__init__.py                                |   10 
 8 files changed, 604 insertions(+), 27 deletions(-)

diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 8c3108b..e0a9313 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -1,14 +1,11 @@
-import json
-from typing import Union, Dict
-from pathlib import Path
-
 import os
-import logging
 import torch
-
-from funasr.export.models import get_model
-import numpy as np
 import random
+import logging
+import numpy as np
+from pathlib import Path
+from typing import Union, Dict, List
+from funasr.export.models import get_model
 from funasr.utils.types import str2bool, str2triple_str
 # torch_version = float(".".join(torch.__version__.split(".")[:2]))
 # assert torch_version > 1.9
@@ -55,20 +52,25 @@
 
         # export encoder1
         self.export_config["model_name"] = "model"
-        models = get_model(
+        model = get_model(
             model,
             self.export_config,
         )
-        if not isinstance(models, tuple):
-            models = (models,)
-            
-        for i, model in enumerate(models):
+        if isinstance(model, List):
+            for m in model:
+                m.eval()
+                if self.onnx:
+                    self._export_onnx(m, verbose, export_dir)
+                else:
+                    self._export_torchscripts(m, verbose, export_dir)
+                print("output dir: {}".format(export_dir))
+        else:
             model.eval()
+            # self._export_onnx(model, verbose, export_dir)
             if self.onnx:
                 self._export_onnx(model, verbose, export_dir)
             else:
                 self._export_torchscripts(model, verbose, export_dir)
-    
             print("output dir: {}".format(export_dir))
 
 
@@ -233,17 +235,17 @@
         # model_script = torch.jit.script(model)
         model_script = model #torch.jit.trace(model)
         model_path = os.path.join(path, f'{model.model_name}.onnx')
-        if not os.path.exists(model_path):
-            torch.onnx.export(
-                model_script,
-                dummy_input,
-                model_path,
-                verbose=verbose,
-                opset_version=14,
-                input_names=model.get_input_names(),
-                output_names=model.get_output_names(),
-                dynamic_axes=model.get_dynamic_axes()
-            )
+        # if not os.path.exists(model_path):
+        torch.onnx.export(
+            model_script,
+            dummy_input,
+            model_path,
+            verbose=verbose,
+            opset_version=14,
+            input_names=model.get_input_names(),
+            output_names=model.get_output_names(),
+            dynamic_axes=model.get_dynamic_axes()
+        )
 
         if self.quant:
             from onnxruntime.quantization import QuantType, quantize_dynamic
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index fd0a15c..cba92a8 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -12,9 +12,17 @@
 from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
 from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
 from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
+from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export
+from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+
 
 def get_model(model, export_config=None):
-    if isinstance(model, BiCifParaformer):
+    if isinstance(model, NeatContextualParaformer):
+        backbone = ContextualParaformer_backbone_export(model, **export_config)
+        embedder = ContextualParaformer_embedder_export(model, **export_config)
+        return [embedder, backbone]
+    elif isinstance(model, BiCifParaformer):
         return BiCifParaformer_export(model, **export_config)
     elif isinstance(model, ParaformerOnline):
         return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
diff --git a/funasr/export/models/decoder/contextual_decoder.py b/funasr/export/models/decoder/contextual_decoder.py
new file mode 100644
index 0000000..4e11b5d
--- /dev/null
+++ b/funasr/export/models/decoder/contextual_decoder.py
@@ -0,0 +1,191 @@
+import os
+import torch
+import torch.nn as nn
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
+from funasr.modules.attention import MultiHeadedAttentionCrossAtt
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
+from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
+
+
+class ContextualSANMDecoder(nn.Module):
+    def __init__(self, model,
+                 max_seq_len=512,
+                 model_name='decoder',
+                 onnx: bool = True,):
+        super().__init__()
+        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+        self.model = model
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
+            self.model.decoders[i] = DecoderLayerSANM_export(d)
+
+        if self.model.decoders2 is not None:
+            for i, d in enumerate(self.model.decoders2):
+                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+                self.model.decoders2[i] = DecoderLayerSANM_export(d)
+
+        for i, d in enumerate(self.model.decoders3):
+            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+            self.model.decoders3[i] = DecoderLayerSANM_export(d)
+        
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+
+        # bias decoder
+        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
+        self.bias_decoder = self.model.bias_decoder
+        # last decoder
+        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
+        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
+            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
+        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
+            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
+        self.last_decoder = self.model.last_decoder
+        self.bias_output = self.model.bias_output
+        self.dropout = self.model.dropout
+        
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+    
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        bias_embed: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+
+        _, _, x_self_attn, x_src_attn = self.last_decoder(
+            x, tgt_mask, memory, memory_mask
+        )
+
+        # contextual paraformer related
+        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
+        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
+        contextual_mask = self.make_pad_mask(contextual_length)
+        contextual_mask, _ = self.prepare_mask(contextual_mask)
+        # import pdb; pdb.set_trace()
+        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
+        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
+
+        if self.bias_output is not None:
+            x = torch.cat([x_src_attn, cx], dim=2)
+            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
+            x = x_self_attn + self.dropout(x)
+
+        if self.model.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        x = self.after_norm(x)
+        x = self.output_layer(x)
+
+        return x, ys_in_lens
+
+
+    def get_dummy_inputs(self, enc_size):
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        memory = torch.randn(1, 100, enc_size)
+        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        cache = [
+            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+            for _ in range(cache_num)
+        ]
+        return (tgt, memory, pre_acoustic_embeds, cache)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+               + ['cache_%d' % i for i in range(cache_num)]
+
+    def get_output_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['y'] \
+               + ['out_cache_%d' % i for i in range(cache_num)]
+
+    def get_dynamic_axes(self):
+        ret = {
+            'tgt': {
+                0: 'tgt_batch',
+                1: 'tgt_length'
+            },
+            'memory': {
+                0: 'memory_batch',
+                1: 'memory_length'
+            },
+            'pre_acoustic_embeds': {
+                0: 'acoustic_embeds_batch',
+                1: 'acoustic_embeds_length',
+            }
+        }
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        ret.update({
+            'cache_%d' % d: {
+                0: 'cache_%d_batch' % d,
+                2: 'cache_%d_length' % d
+            }
+            for d in range(cache_num)
+        })
+        return ret
+
+    def get_model_config(self, path):
+        return {
+            "dec_type": "XformerDecoder",
+            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
+            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
+            "odim": self.model.decoders[0].size
+        }
diff --git a/funasr/export/models/e2e_asr_contextual_paraformer.py b/funasr/export/models/e2e_asr_contextual_paraformer.py
new file mode 100644
index 0000000..61806c9
--- /dev/null
+++ b/funasr/export/models/e2e_asr_contextual_paraformer.py
@@ -0,0 +1,174 @@
+from audioop import bias
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
+from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
+from funasr.models.predictor.cif import CifPredictorV2
+from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
+from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
+from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+
+
+class ContextualParaformer_backbone(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+            self,
+            model,
+            max_seq_len=512,
+            feats_dim=560,
+            model_name='model',
+            **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+        if isinstance(model.encoder, SANMEncoder):
+            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+        elif isinstance(model.encoder, ConformerEncoder):
+            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+        if isinstance(model.predictor, CifPredictorV2):
+            self.predictor = CifPredictorV2_export(model.predictor)
+        
+        # decoder
+        if isinstance(model.decoder, ContextualParaformerDecoder):
+            self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
+        elif isinstance(model.decoder, ParaformerSANMDecoder):
+            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
+        elif isinstance(model.decoder, ParaformerDecoderSAN):
+            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
+        
+        self.feats_dim = feats_dim
+        self.model_name = model_name + '_bb'
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+        
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            bias_embed: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        # batch = to_device(batch, device=self.device)
+    
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+        pre_token_length = pre_token_length.floor().type(torch.int32)
+
+        # bias_embed = bias_embed. squeeze(0).repeat([enc.shape[0], 1, 1])
+
+        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        # sample_ids = decoder_out.argmax(dim=-1)
+        return decoder_out, pre_token_length
+
+    def get_dummy_inputs(self):
+        speech = torch.randn(2, 30, self.feats_dim)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        bias_embed = torch.randn(2, 1, 512)
+        return (speech, speech_lengths, bias_embed)
+
+    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
+        import numpy as np
+        fbank = np.loadtxt(txt_file)
+        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
+        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
+        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
+        return (speech, speech_lengths)
+
+    def get_input_names(self):
+        return ['speech', 'speech_lengths', 'bias_embed']
+
+    def get_output_names(self):
+        return ['logits', 'token_num']
+
+    def get_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'bias_embed': {
+                0: 'batch_size',
+                1: 'num_hotwords'
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+
+
+class ContextualParaformer_embedder(nn.Module):
+    def __init__(self,
+                 model,
+                 max_seq_len=512,
+                 feats_dim=560,
+                 model_name='model',
+                 **kwargs,):
+        super().__init__()
+        self.embedding = model.bias_embed
+        model.bias_encoder.batch_first = False
+        self.bias_encoder = model.bias_encoder
+        # self.bias_encoder.batch_first = False
+        self.feats_dim = feats_dim
+        self.model_name = "{}_eb".format(model_name)
+    
+    def forward(self, hotword):
+        hotword = self.embedding(hotword).transpose(0, 1) # batch second
+        hw_embed, (_, _) = self.bias_encoder(hotword)
+        return hw_embed
+    
+    def get_dummy_inputs(self):
+        hotword = torch.tensor([
+                                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14], 
+                                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+                                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14], 
+                                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+                                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                               ], 
+                                dtype=torch.int32)
+        # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
+        return (hotword)
+
+    def get_input_names(self):
+        return ['hotword']
+
+    def get_output_names(self):
+        return ['hw_embed']
+
+    def get_dynamic_axes(self):
+        return {
+            'hotword': {
+                0: 'num_hotwords',
+            },
+            'hw_embed': {
+                0: 'num_hotwords',
+            },
+        }
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
new file mode 100644
index 0000000..984c0d6
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -0,0 +1,11 @@
+from funasr_onnx import ContextualParaformer
+from pathlib import Path
+
+model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+model = ContextualParaformer(model_dir, batch_size=1)
+
+wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
+hotwords = '闅忔満鐑瘝 鍚勭鐑瘝 榄旀惌 闃块噷宸村反'
+
+result = model(wav_path, hotwords)
+print(result)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 7d8d662..c03d7e5 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,5 @@
 # -*- encoding: utf-8 -*-
-from .paraformer_bin import Paraformer
+from .paraformer_bin import Paraformer, ContextualParaformer
 from .vad_bin import Fsmn_vad
 from .vad_bin import Fsmn_vad_online
 from .punc_bin import CT_Transformer
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index f3e0f3d..5f866b8 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -7,6 +7,7 @@
 from typing import List, Union, Tuple
 
 import copy
+import torch
 import librosa
 import numpy as np
 
@@ -16,6 +17,7 @@
 from .utils.postprocess_utils import sentence_postprocess
 from .utils.frontend import WavFrontend
 from .utils.timestamp_utils import time_stamp_lfr6_onnx
+from .utils.utils import pad_list, make_pad_mask
 
 logging = get_logger()
 
@@ -210,3 +212,145 @@
         # texts = sentence_postprocess(token)
         return token
 
+
+class ContextualParaformer(Paraformer):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+    def __init__(self, model_dir: Union[str, Path] = None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int] = "-1",
+                 plot_timestamp_to: str = "",
+                 quantize: bool = False,
+                 intra_op_num_threads: int = 4,
+                 cache_dir: str = None
+                 ):
+
+        if not Path(model_dir).exists():
+            from modelscope.hub.snapshot_download import snapshot_download
+            try:
+                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+            except:
+                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
+        
+        model_bb_file = os.path.join(model_dir, 'model_bb.onnx')
+        model_eb_file = os.path.join(model_dir, 'model_eb.onnx')
+
+        token_list_file = os.path.join(model_dir, 'tokens.txt')
+        self.vocab = {}
+        with open(Path(token_list_file), 'r') as fin:
+            for i, line in enumerate(fin.readlines()):
+                self.vocab[line.strip()] = i
+
+        #if quantize:
+        #    model_file = os.path.join(model_dir, 'model_quant.onnx')
+        #if not os.path.exists(model_file):
+        #    logging.error(".onnx model not exist, please export first.")
+            
+        config_file = os.path.join(model_dir, 'config.yaml')
+        cmvn_file = os.path.join(model_dir, 'am.mvn')
+        config = read_yaml(config_file)
+
+        self.converter = TokenIDConverter(config['token_list'])
+        self.tokenizer = CharTokenizer()
+        self.frontend = WavFrontend(
+            cmvn_file=cmvn_file,
+            **config['frontend_conf']
+        )
+        self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
+        self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)
+
+        self.batch_size = batch_size
+        self.plot_timestamp_to = plot_timestamp_to
+        if "predictor_bias" in config['model_conf'].keys():
+            self.pred_bias = config['model_conf']['predictor_bias']
+        else:
+            self.pred_bias = 0
+
+    def __call__(self, 
+                 wav_content: Union[str, np.ndarray, List[str]], 
+                 hotwords: str,
+                 **kwargs) -> List:
+        # make hotword list
+        hotwords, hotwords_length = self.proc_hotword(hotwords)
+        # import pdb; pdb.set_trace()
+        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
+        # index from bias_embed
+        bias_embed = bias_embed.transpose(1, 0, 2)
+        _ind = np.arange(0, len(hotwords)).tolist()
+        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
+        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+        waveform_nums = len(waveform_list)
+        asr_res = []
+        for beg_idx in range(0, waveform_nums, self.batch_size):
+            end_idx = min(waveform_nums, beg_idx + self.batch_size)
+            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+            bias_embed = np.expand_dims(bias_embed, axis=0)
+            bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
+            try:
+                outputs = self.bb_infer(feats, feats_len, bias_embed)
+                am_scores, valid_token_lens = outputs[0], outputs[1]
+            except ONNXRuntimeError:
+                #logging.warning(traceback.format_exc())
+                logging.warning("input wav is silence or noise")
+                preds = ['']
+            else:
+                preds = self.decode(am_scores, valid_token_lens)
+                for pred in preds:
+                    pred = sentence_postprocess(pred)
+                    asr_res.append({'preds': pred})
+        return asr_res
+
+    def proc_hotword(self, hotwords):
+        hotwords = hotwords.split(" ")
+        hotwords_length = [len(i) - 1 for i in hotwords]
+        hotwords_length.append(0)
+        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
+        # hotwords.append('<s>')
+        def word_map(word):
+            return torch.tensor([self.vocab[i] for i in word])
+        hotword_int = [word_map(i) for i in hotwords]
+        # import pdb; pdb.set_trace()
+        hotword_int.append(torch.tensor([1]))
+        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
+        return hotwords, hotwords_length
+
+    def bb_infer(self, feats: np.ndarray,
+              feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
+        return outputs
+
+    def eb_infer(self, hotwords, hotwords_length):
+        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
+        return outputs
+
+    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
+        return [self.decode_one(am_score, token_num)
+                for am_score, token_num in zip(am_scores, token_nums)]
+
+    def decode_one(self,
+                   am_score: np.ndarray,
+                   valid_token_num: int) -> List[str]:
+        yseq = am_score.argmax(axis=-1)
+        score = am_score.max(axis=-1)
+        score = np.sum(score, axis=-1)
+
+        # pad with mask tokens to ensure compatibility with sos/eos tokens
+        # asr_model.sos:1  asr_model.eos:2
+        yseq = np.array([1] + yseq.tolist() + [2])
+        hyp = Hypothesis(yseq=yseq, score=score)
+
+        # remove sos/eos and get results
+        last_pos = -1
+        token_int = hyp.yseq[1:last_pos].tolist()
+
+        # remove blank symbol id, which is assumed to be 0
+        token_int = list(filter(lambda x: x not in (0, 2), token_int))
+
+        # Change integer-ids to tokens
+        token = self.converter.ids2tokens(token_int)
+        token = token[:valid_token_num-self.pred_bias]
+        # texts = sentence_postprocess(token)
+        return token
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index f1fc9a0..cf74200 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -7,6 +7,7 @@
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
 import re
+import torch
 import numpy as np
 import yaml
 try:
@@ -22,6 +23,52 @@
 logger_initialized = {}
 
 
+def pad_list(xs, pad_value, max_len=None):
+    n_batch = len(xs)
+    if max_len is None:
+        max_len = max(x.size(0) for x in xs)
+    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+    for i in range(n_batch):
+        pad[i, : xs[i].size(0)] = xs[i]
+
+    return pad
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+    if length_dim == 0:
+        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+    if not isinstance(lengths, list):
+        lengths = lengths.tolist()
+    bs = int(len(lengths))
+    if maxlen is None:
+        if xs is None:
+            maxlen = int(max(lengths))
+        else:
+            maxlen = xs.size(length_dim)
+    else:
+        assert xs is None
+        assert maxlen >= int(max(lengths))
+
+    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    if xs is not None:
+        assert xs.size(0) == bs, (xs.size(0), bs)
+
+        if length_dim < 0:
+            length_dim = xs.dim() + length_dim
+        # ind = (:, None, ..., None, :, , None, ..., None)
+        ind = tuple(
+            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+        )
+        mask = mask[ind].expand_as(xs).to(xs.device)
+    return mask
+
+
 class TokenIDConverter():
     def __init__(self, token_list: Union[List, str],
                  ):

--
Gitblit v1.9.1