shixian.shi
2023-08-15 ee9569ceef0c9707c8877d6b65733621dfbd3aeb
Contextual Paraformer onnx export
5个文件已修改
3个文件已添加
631 ■■■■■ 已修改文件
funasr/export/export_model.py 52 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/__init__.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/decoder/contextual_decoder.py 191 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/e2e_asr_contextual_paraformer.py 174 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 144 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py 47 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/export_model.py
@@ -1,14 +1,11 @@
import json
from typing import Union, Dict
from pathlib import Path
import os
import logging
import torch
from funasr.export.models import get_model
import numpy as np
import random
import logging
import numpy as np
from pathlib import Path
from typing import Union, Dict, List
from funasr.export.models import get_model
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
@@ -55,20 +52,25 @@
        # export encoder1
        self.export_config["model_name"] = "model"
        models = get_model(
        model = get_model(
            model,
            self.export_config,
        )
        if not isinstance(models, tuple):
            models = (models,)
        for i, model in enumerate(models):
        if isinstance(model, List):
            for m in model:
                m.eval()
                if self.onnx:
                    self._export_onnx(m, verbose, export_dir)
                else:
                    self._export_torchscripts(m, verbose, export_dir)
                print("output dir: {}".format(export_dir))
        else:
            model.eval()
            # self._export_onnx(model, verbose, export_dir)
            if self.onnx:
                self._export_onnx(model, verbose, export_dir)
            else:
                self._export_torchscripts(model, verbose, export_dir)
            print("output dir: {}".format(export_dir))
@@ -233,17 +235,17 @@
        # model_script = torch.jit.script(model)
        model_script = model #torch.jit.trace(model)
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        if not os.path.exists(model_path):
            torch.onnx.export(
                model_script,
                dummy_input,
                model_path,
                verbose=verbose,
                opset_version=14,
                input_names=model.get_input_names(),
                output_names=model.get_output_names(),
                dynamic_axes=model.get_dynamic_axes()
            )
        # if not os.path.exists(model_path):
        torch.onnx.export(
            model_script,
            dummy_input,
            model_path,
            verbose=verbose,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes()
        )
        if self.quant:
            from onnxruntime.quantization import QuantType, quantize_dynamic
funasr/export/models/__init__.py
@@ -12,9 +12,17 @@
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export
from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
    if isinstance(model, NeatContextualParaformer):
        backbone = ContextualParaformer_backbone_export(model, **export_config)
        embedder = ContextualParaformer_embedder_export(model, **export_config)
        return [embedder, backbone]
    elif isinstance(model, BiCifParaformer):
        return BiCifParaformer_export(model, **export_config)
    elif isinstance(model, ParaformerOnline):
        return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
funasr/export/models/decoder/contextual_decoder.py
New file
@@ -0,0 +1,191 @@
import os
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
class ContextualSANMDecoder(nn.Module):
    """Export wrapper around a contextual (hotword-biased) SANM decoder.

    The constructor replaces the attention and feed-forward sub-modules of the
    wrapped decoder with their ``*_export`` counterparts *in place*, so the
    ``model`` object passed in is modified.
    """

    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        """Wrap ``model`` and swap its layers for export-friendly versions.

        Args:
            model: decoder exposing ``decoders``, ``decoders2`` (may be
                ``None``), ``decoders3``, ``bias_decoder``, ``last_decoder``,
                ``bias_output``, ``dropout``, ``output_layer`` and
                ``after_norm``.
            max_seq_len: forwarded to the padding-mask helper.
            model_name: base name used by ``get_model_config`` for the file.
            onnx: pick ``MakePadMask`` when true, ``sequence_mask`` otherwise.
        """
        super().__init__()
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        # The second stack is optional on the wrapped decoder.
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

        # bias decoder: only its cross-attention is swapped
        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
        self.bias_decoder = self.model.bias_decoder

        # last decoder: sub-modules are swapped, the layer itself is not re-wrapped
        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
        self.last_decoder = self.model.last_decoder
        self.bias_output = self.model.bias_output
        self.dropout = self.model.dropout

    def _cache_num(self):
        """Return ``len(decoders) + len(decoders2)``, counting a missing ``decoders2`` as 0."""
        decoders2 = self.model.decoders2
        extra = len(decoders2) if decoders2 is not None else 0
        return len(self.model.decoders) + extra

    def prepare_mask(self, mask):
        """Derive two mask views from ``mask``.

        Returns:
            ``(mask_3d_btd, mask_4d_bhlt)``: ``mask`` with a trailing unit
            axis, and ``(1 - mask) * -10000.0`` with broadcast axes inserted
            after the first dimension (entries equal to 1 map to 0, entries
            equal to 0 map to -10000).
        """
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        bias_embed: torch.Tensor,
    ):
        """Run the decoder stacks with hotword biasing.

        Args:
            hs_pad: encoder output used as attention memory.
            hlens: lengths for ``hs_pad``.
            ys_in_pad: decoder input sequence.
            ys_in_lens: lengths for ``ys_in_pad``.
            bias_embed: hotword embeddings; ``shape[1]`` is taken as the
                number of hotwords (assumes batch-first layout; confirm
                against the caller).

        Returns:
            ``(x, ys_in_lens)`` where ``x`` is the output of ``output_layer``.
        """
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )

        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related: every batch item attends over all hotwords
        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_mask = self.make_pad_mask(contextual_length)
        contextual_mask, _ = self.prepare_mask(contextual_mask)
        # swap the last two axes and insert an axis at position 1 before cross-attention
        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)

        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    # NOTE(review): the helpers below describe a (tgt, memory,
    # pre_acoustic_embeds, cache_*) interface, which differs from the
    # arguments forward() takes above; confirm whether they are still used.

    def get_dummy_inputs(self, enc_size):
        """Build dummy ``(tgt, memory, pre_acoustic_embeds, cache)`` inputs."""
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = self._cache_num()
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        """Always report this module as optimizable."""
        return True

    def get_input_names(self):
        """Return input names: three fixed names plus one ``cache_<i>`` per cached layer."""
        cache_num = self._cache_num()
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        """Return output names: ``y`` plus one ``out_cache_<i>`` per cached layer."""
        cache_num = self._cache_num()
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping for the inputs named by ``get_input_names``."""
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = self._cache_num()
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret

    def get_model_config(self, path):
        """Return a config dict pointing at ``<path>/<model_name>.onnx``."""
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": self._cache_num(),
            "odim": self.model.decoders[0].size
        }
funasr/export/models/e2e_asr_contextual_paraformer.py
New file
@@ -0,0 +1,174 @@
# NOTE(review): `bias` is not referenced anywhere in the visible part of this
# module, and `audioop` was removed from the standard library in Python 3.13
# (PEP 594). This looks like an accidental IDE auto-import — confirm and drop.
from audioop import bias
import logging
import torch
import torch.nn as nn
import numpy as np
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr.models.predictor.cif import CifPredictorV2
from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
class ContextualParaformer_backbone(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    Export wrapper that chains encoder -> predictor -> decoder and takes the
    hotword embeddings (``bias_embed``) as an extra input. The exported model
    name is ``<model_name>_bb``.
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            model_name='model',
            **kwargs,
    ):
        """Wrap the sub-modules of ``model`` in their export counterparts.

        ``kwargs['onnx']`` (default ``False``) is forwarded to the wrapped
        encoder/decoder and selects the padding-mask helper.
        """
        super().__init__()
        onnx = kwargs.get("onnx", False)

        # encoder
        source_encoder = model.encoder
        if isinstance(source_encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(source_encoder, onnx=onnx)
        elif isinstance(source_encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(source_encoder, onnx=onnx)

        # predictor
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)

        # decoder: the first matching type wins, in the order listed
        source_decoder = model.decoder
        decoder_wrappers = (
            (ContextualParaformerDecoder, ContextualSANMDecoder_export),
            (ParaformerSANMDecoder, ParaformerSANMDecoder_export),
            (ParaformerDecoderSAN, ParaformerDecoderSAN_export),
        )
        for source_cls, export_cls in decoder_wrappers:
            if isinstance(source_decoder, source_cls):
                self.decoder = export_cls(source_decoder, onnx=onnx)
                break

        self.feats_dim = feats_dim
        self.model_name = model_name + '_bb'

        mask_factory = MakePadMask if onnx else sequence_mask
        self.make_pad_mask = mask_factory(max_seq_len, flip=False)

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            bias_embed: torch.Tensor,
    ):
        """Return ``(log-probabilities, predicted token counts)``.

        The token counts are the predictor's lengths floored and cast to
        ``int32``; the log-probabilities are ``log_softmax`` over the last
        axis of the decoder output.
        """
        enc, enc_len = self.encoder(speech=speech, speech_lengths=speech_lengths)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        predicted = self.predictor(enc, mask)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predicted
        pre_token_length = pre_token_length.floor().type(torch.int32)

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
        return torch.log_softmax(decoder_out, dim=-1), pre_token_length

    def get_dummy_inputs(self):
        """Return random ``(speech, speech_lengths, bias_embed)`` for a batch of two."""
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        bias_embed = torch.randn(2, 1, 512)
        return (speech, speech_lengths, bias_embed)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        """Load one feature matrix from a text file; returns ``(speech, speech_lengths)`` only."""
        import numpy as np
        fbank = np.loadtxt(txt_file)
        num_frames = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(num_frames.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        """Names of the exported graph inputs, in ``forward`` argument order."""
        return ['speech', 'speech_lengths', 'bias_embed']

    def get_output_names(self):
        """Names of the exported graph outputs."""
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        """Axes of inputs/outputs that may vary in size at run time."""
        axes = {}
        axes['speech'] = {0: 'batch_size', 1: 'feats_length'}
        axes['speech_lengths'] = {0: 'batch_size'}
        axes['bias_embed'] = {0: 'batch_size', 1: 'num_hotwords'}
        axes['logits'] = {0: 'batch_size', 1: 'logits_length'}
        return axes
class ContextualParaformer_embedder(nn.Module):
    """Export wrapper for the hotword embedder (embedding followed by bias encoder).

    The exported model name is ``<model_name>_eb``.
    """

    def __init__(self,
                 model,
                 max_seq_len=512,
                 feats_dim=560,
                 model_name='model',
                 **kwargs,):
        """Reuse ``model.bias_embed`` and ``model.bias_encoder``.

        Note: ``batch_first`` is set to ``False`` on the bias encoder of the
        model passed in, i.e. the original object is modified.
        """
        super().__init__()
        self.embedding = model.bias_embed
        bias_encoder = model.bias_encoder
        bias_encoder.batch_first = False
        self.bias_encoder = bias_encoder
        self.feats_dim = feats_dim
        self.model_name = f"{model_name}_eb"

    def forward(self, hotword):
        """Embed ``hotword`` ids and return the bias encoder's output sequence."""
        # swap the first two axes so the batch axis comes second
        embedded = self.embedding(hotword).transpose(0, 1)
        hw_embed, (_hidden, _cell) = self.bias_encoder(embedded)
        return hw_embed

    def get_dummy_inputs(self):
        """Return a 6 x 10 ``int32`` tensor of zero-padded example hotword ids."""
        sample_rows = [
            [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
            [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
        # the same three rows appear twice
        hotword = torch.tensor(sample_rows * 2, dtype=torch.int32)
        return hotword

    def get_input_names(self):
        """Names of the exported graph inputs."""
        return ['hotword']

    def get_output_names(self):
        """Names of the exported graph outputs."""
        return ['hw_embed']

    def get_dynamic_axes(self):
        """Only axis 0 of the input and output is marked dynamic."""
        return {
            'hotword': {0: 'num_hotwords'},
            'hw_embed': {0: 'num_hotwords'},
        }
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
New file
@@ -0,0 +1,11 @@
# Demo: run the ONNX ContextualParaformer on one example wav with hotwords.
from funasr_onnx import ContextualParaformer
from pathlib import Path
# Directory holding the exported model files (relative to the working directory).
model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model = ContextualParaformer(model_dir, batch_size=1)
# Example audio inside the modelscope cache under the current user's home directory.
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
# All hotwords are passed together as one string, separated by spaces.
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
result = model(wav_path, hotwords)
print(result)
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
from .paraformer_bin import Paraformer, ContextualParaformer
from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -7,6 +7,7 @@
from typing import List, Union, Tuple
import copy
import torch
import librosa
import numpy as np
@@ -16,6 +17,7 @@
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
logging = get_logger()
@@ -210,3 +212,145 @@
        # texts = sentence_postprocess(token)
        return token
class ContextualParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    ONNX runtime wrapper for the contextual (hotword) Paraformer. It drives two
    sessions: ``model_eb.onnx`` (hotword embedder) and ``model_bb.onnx``
    (backbone).
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        """Load vocabulary, config, frontend and both ONNX sessions.

        Args:
            model_dir: local directory with the exported files, or a modelscope
                model name that is downloaded when the path does not exist.
            batch_size: number of waveforms decoded per backbone call.
            device_id: forwarded to ``OrtInferSession``.
            plot_timestamp_to: stored on the instance; not used in this class.
            quantize: accepted for signature compatibility; currently unused.
            intra_op_num_threads: forwarded to ``OrtInferSession``.
            cache_dir: forwarded to ``snapshot_download``.

        Raises:
            ValueError: if ``model_dir`` is not a local path and the download
                fails.
        """
        # The parent initialiser is not invoked; every attribute is set here.
        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception as err:
                # Raising a plain string is itself a TypeError in Python 3.
                raise ValueError(
                    "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
                ) from err

        model_bb_file = os.path.join(model_dir, 'model_bb.onnx')
        model_eb_file = os.path.join(model_dir, 'model_eb.onnx')

        # One token per line; the line index is the token id.
        token_list_file = os.path.join(model_dir, 'tokens.txt')
        self.vocab = {}
        with open(Path(token_list_file), 'r', encoding='utf-8') as fin:
            for i, line in enumerate(fin):
                self.vocab[line.strip()] = i

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)

        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        self.pred_bias = config['model_conf'].get('predictor_bias', 0)

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 hotwords: str,
                 **kwargs) -> List:
        """Recognise ``wav_content`` with hotword biasing.

        Args:
            wav_content: passed to ``self.load_data``.
            hotwords: whitespace-separated hotwords.

        Returns:
            A list with one ``{'preds': ...}`` dict per decoded waveform.
            Batches whose backbone call raises ``ONNXRuntimeError`` are logged
            and contribute no entries.
        """
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
        # For each hotword keep the embedding at the position of its last
        # token (hotwords_length holds len - 1).
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]

        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            # Broadcast into a per-batch variable. Reassigning bias_embed here
            # would add one more leading axis on every iteration.
            batch_bias_embed = np.repeat(
                np.expand_dims(bias_embed, axis=0), feats.shape[0], axis=0
            )
            try:
                outputs = self.bb_infer(feats, feats_len, batch_bias_embed)
                am_scores, valid_token_lens = outputs[0], outputs[1]
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
            else:
                preds = self.decode(am_scores, valid_token_lens)
                for pred in preds:
                    pred = sentence_postprocess(pred)
                    asr_res.append({'preds': pred})
        return asr_res

    def proc_hotword(self, hotwords):
        """Turn a whitespace-separated hotword string into padded id tensors.

        One extra entry (token id 1, length index 0) is appended after the
        user hotwords.

        Returns:
            ``(ids, last_index)``: ``ids`` is padded with 0 to length 10 and
            ``last_index`` holds ``len(word) - 1`` per hotword as ``int32``.

        Raises:
            KeyError: if a character of a hotword is missing from the vocabulary.
        """
        # split() without an argument drops empty items caused by repeated or
        # leading/trailing whitespace, which would yield length -1 entries.
        hotwords = hotwords.split()
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)

        def word_map(word):
            return torch.tensor([self.vocab[i] for i in word])

        hotword_int = [word_map(i) for i in hotwords]
        hotword_int.append(torch.tensor([1]))
        # NOTE(review): max_len is fixed at 10; hotwords longer than that are
        # not handled here — confirm the intended limit.
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        return hotwords, hotwords_length

    def bb_infer(self, feats: np.ndarray,
              feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
        """Run the backbone session on features, lengths and hotword embeddings."""
        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
        return outputs

    def eb_infer(self, hotwords, hotwords_length):
        """Run the embedder session on ``int32`` hotword ids and lengths."""
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        """Decode every item of the batch with ``decode_one``."""
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        """Greedy-decode one score matrix into a list of tokens.

        The result is truncated to ``valid_token_num - self.pred_bias`` tokens.
        """
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-self.pred_bias]
        return token
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
import torch
import numpy as np
import yaml
try:
@@ -22,6 +23,52 @@
logger_initialized = {}
def pad_list(xs, pad_value, max_len=None):
    """Stack variable-length tensors into one tensor padded with ``pad_value``.

    Args:
        xs: list of tensors that may differ in size along axis 0.
        pad_value: fill value for the positions past each tensor's length.
        max_len: padded length; defaults to the longest tensor in ``xs``.

    Returns:
        Tensor of shape ``(len(xs), max_len, *xs[0].size()[1:])`` with the
        type of ``xs[0]``.
    """
    batch = len(xs)
    target_len = max(x.size(0) for x in xs) if max_len is None else max_len
    template = xs[0]
    trailing = template.size()[1:]
    padded = template.new(batch, target_len, *trailing).fill_(pad_value)
    for idx, seq in enumerate(xs):
        padded[idx, : seq.size(0)] = seq
    return padded
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Build a boolean mask that is ``True`` at padded positions.

    Args:
        lengths: list of lengths, or any object with ``tolist()``.
        xs: optional reference tensor; when given, the mask is expanded to its
            shape and moved to its device.
        length_dim: axis of ``xs`` that holds the sequence length; must not be 0.
        maxlen: explicit mask length; only allowed when ``xs`` is ``None``.

    Returns:
        Bool tensor of shape ``(batch, maxlen)``, or of ``xs``'s shape when
        ``xs`` is given.

    Raises:
        ValueError: if ``length_dim`` is 0.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    length_list = lengths if isinstance(lengths, list) else lengths.tolist()
    batch = int(len(length_list))

    if maxlen is not None:
        assert xs is None
        assert maxlen >= int(max(length_list))
    elif xs is not None:
        maxlen = xs.size(length_dim)
    else:
        maxlen = int(max(length_list))

    # positions[b, t] = t; a position is padding once t >= length of item b
    positions = torch.arange(0, maxlen, dtype=torch.int64).unsqueeze(0).expand(batch, maxlen)
    limits = positions.new(length_list).unsqueeze(-1)
    mask = positions >= limits

    if xs is None:
        return mask

    assert xs.size(0) == batch, (xs.size(0), batch)
    axis = xs.dim() + length_dim if length_dim < 0 else length_dim
    # keep the batch axis and the length axis, insert broadcast axes elsewhere
    index = tuple(
        slice(None) if d in (0, axis) else None for d in range(xs.dim())
    )
    return mask[index].expand_as(xs).to(xs.device)
class TokenIDConverter():
    def __init__(self, token_list: Union[List, str],
                 ):