From ee9569ceef0c9707c8877d6b65733621dfbd3aeb Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Tue, 15 Aug 2023 17:31:27 +0800
Subject: [PATCH] Contextual Paraformer onnx export
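
This patch adds ONNX export and onnxruntime inference support for the
contextual (hotword) Paraformer. The model is exported as two graphs: an
embedder (model_eb.onnx) that maps hotword token ids to bias embeddings
through the bias encoder, and a backbone (model_bb.onnx) that runs the
encoder, predictor and contextual decoder on the fbank features together
with the bias embeddings.

A minimal usage sketch (model directory and hotwords below are
illustrative; see demo_contextual_paraformer.py for the full demo):

    from funasr_onnx import ContextualParaformer

    model = ContextualParaformer(model_dir, batch_size=1)
    # hotwords are passed as one space-separated string
    result = model(["asr_example.wav"], "魔搭 阿里巴巴")
    print(result)  # one {'preds': ...} dict per input wav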
---
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py | 47 ++++
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 144 +++++++++++++
funasr/export/models/decoder/contextual_decoder.py | 191 +++++++++++++++++
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py | 2
funasr/export/models/e2e_asr_contextual_paraformer.py | 174 +++++++++++++++
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py | 11 +
funasr/export/export_model.py | 52 ++--
funasr/export/models/__init__.py | 10
8 files changed, 604 insertions(+), 27 deletions(-)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 8c3108b..e0a9313 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -1,14 +1,11 @@
-import json
-from typing import Union, Dict
-from pathlib import Path
-
import os
-import logging
import torch
-
-from funasr.export.models import get_model
-import numpy as np
import random
+import logging
+import numpy as np
+from pathlib import Path
+from typing import Union, Dict, List
+from funasr.export.models import get_model
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
@@ -55,20 +52,25 @@
# export encoder1
self.export_config["model_name"] = "model"
- models = get_model(
+ model = get_model(
model,
self.export_config,
)
- if not isinstance(models, tuple):
- models = (models,)
-
- for i, model in enumerate(models):
+ if isinstance(model, List):
+ for m in model:
+ m.eval()
+ if self.onnx:
+ self._export_onnx(m, verbose, export_dir)
+ else:
+ self._export_torchscripts(m, verbose, export_dir)
+ print("output dir: {}".format(export_dir))
+ else:
model.eval()
+ # self._export_onnx(model, verbose, export_dir)
if self.onnx:
self._export_onnx(model, verbose, export_dir)
else:
self._export_torchscripts(model, verbose, export_dir)
-
print("output dir: {}".format(export_dir))
@@ -233,17 +235,17 @@
# model_script = torch.jit.script(model)
model_script = model #torch.jit.trace(model)
model_path = os.path.join(path, f'{model.model_name}.onnx')
- if not os.path.exists(model_path):
- torch.onnx.export(
- model_script,
- dummy_input,
- model_path,
- verbose=verbose,
- opset_version=14,
- input_names=model.get_input_names(),
- output_names=model.get_output_names(),
- dynamic_axes=model.get_dynamic_axes()
- )
+ # if not os.path.exists(model_path):
+ torch.onnx.export(
+ model_script,
+ dummy_input,
+ model_path,
+ verbose=verbose,
+ opset_version=14,
+ input_names=model.get_input_names(),
+ output_names=model.get_output_names(),
+ dynamic_axes=model.get_dynamic_axes()
+ )
if self.quant:
from onnxruntime.quantization import QuantType, quantize_dynamic
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index fd0a15c..cba92a8 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -12,9 +12,17 @@
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
+from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export
+from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+
def get_model(model, export_config=None):
- if isinstance(model, BiCifParaformer):
+ if isinstance(model, NeatContextualParaformer):
+ backbone = ContextualParaformer_backbone_export(model, **export_config)
+ embedder = ContextualParaformer_embedder_export(model, **export_config)
+ return [embedder, backbone]
+ elif isinstance(model, BiCifParaformer):
return BiCifParaformer_export(model, **export_config)
elif isinstance(model, ParaformerOnline):
return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
diff --git a/funasr/export/models/decoder/contextual_decoder.py b/funasr/export/models/decoder/contextual_decoder.py
new file mode 100644
index 0000000..4e11b5d
--- /dev/null
+++ b/funasr/export/models/decoder/contextual_decoder.py
@@ -0,0 +1,191 @@
+import os
+import torch
+import torch.nn as nn
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
+from funasr.modules.attention import MultiHeadedAttentionCrossAtt
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
+from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
+
+
+class ContextualSANMDecoder(nn.Module):
+ def __init__(self, model,
+ max_seq_len=512,
+ model_name='decoder',
+ onnx: bool = True,):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+ self.model = model
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+ d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+ if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+ d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
+ self.model.decoders[i] = DecoderLayerSANM_export(d)
+
+ if self.model.decoders2 is not None:
+ for i, d in enumerate(self.model.decoders2):
+ if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+ d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+ self.model.decoders2[i] = DecoderLayerSANM_export(d)
+
+ for i, d in enumerate(self.model.decoders3):
+ if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+ d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+ self.model.decoders3[i] = DecoderLayerSANM_export(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ # bias decoder
+ if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+ self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
+ self.bias_decoder = self.model.bias_decoder
+ # last decoder
+ if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+ self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
+ if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
+ self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
+ if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
+ self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
+ self.last_decoder = self.model.last_decoder
+ self.bias_output = self.model.bias_output
+ self.dropout = self.model.dropout
+
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ bias_embed: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+
+ _, _, x_self_attn, x_src_attn = self.last_decoder(
+ x, tgt_mask, memory, memory_mask
+ )
+
+ # contextual paraformer related
+ contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
+ # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
+ contextual_mask = self.make_pad_mask(contextual_length)
+ contextual_mask, _ = self.prepare_mask(contextual_mask)
+        # reshape to (batch, 1, 1, num_hotwords) so the hotword mask broadcasts
+        # over attention heads and target positions in the bias decoder
+        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
+ cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
+
+ if self.bias_output is not None:
+            # fuse source attention and hotword (bias) attention: concatenate to 2*D
+            # along the feature dim, then project back to D
+            x = torch.cat([x_src_attn, cx], dim=2)
+            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2*D -> D
+ x = x_self_attn + self.dropout(x)
+
+ if self.model.decoders2 is not None:
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+ x, tgt_mask, memory, memory_mask
+ )
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, ys_in_lens
+
+
+ def get_dummy_inputs(self, enc_size):
+ tgt = torch.LongTensor([0]).unsqueeze(0)
+ memory = torch.randn(1, 100, enc_size)
+ pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ cache = [
+ torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+ for _ in range(cache_num)
+ ]
+ return (tgt, memory, pre_acoustic_embeds, cache)
+
+ def is_optimizable(self):
+ return True
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+ + ['cache_%d' % i for i in range(cache_num)]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['y'] \
+ + ['out_cache_%d' % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ 'tgt': {
+ 0: 'tgt_batch',
+ 1: 'tgt_length'
+ },
+ 'memory': {
+ 0: 'memory_batch',
+ 1: 'memory_length'
+ },
+ 'pre_acoustic_embeds': {
+ 0: 'acoustic_embeds_batch',
+ 1: 'acoustic_embeds_length',
+ }
+ }
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ ret.update({
+ 'cache_%d' % d: {
+ 0: 'cache_%d_batch' % d,
+ 2: 'cache_%d_length' % d
+ }
+ for d in range(cache_num)
+ })
+ return ret
+
+ def get_model_config(self, path):
+ return {
+ "dec_type": "XformerDecoder",
+ "model_path": os.path.join(path, f'{self.model_name}.onnx'),
+ "n_layers": len(self.model.decoders) + len(self.model.decoders2),
+ "odim": self.model.decoders[0].size
+ }
diff --git a/funasr/export/models/e2e_asr_contextual_paraformer.py b/funasr/export/models/e2e_asr_contextual_paraformer.py
new file mode 100644
index 0000000..61806c9
--- /dev/null
+++ b/funasr/export/models/e2e_asr_contextual_paraformer.py
@@ -0,0 +1,174 @@
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
+from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
+from funasr.models.predictor.cif import CifPredictorV2
+from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
+from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
+from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+
+
+class ContextualParaformer_backbone(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
+ if isinstance(model.encoder, SANMEncoder):
+ self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+ elif isinstance(model.encoder, ConformerEncoder):
+ self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+ if isinstance(model.predictor, CifPredictorV2):
+ self.predictor = CifPredictorV2_export(model.predictor)
+
+ # decoder
+ if isinstance(model.decoder, ContextualParaformerDecoder):
+ self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
+ elif isinstance(model.decoder, ParaformerSANMDecoder):
+ self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
+ elif isinstance(model.decoder, ParaformerDecoderSAN):
+ self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
+
+ self.feats_dim = feats_dim
+ self.model_name = model_name + '_bb'
+
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ bias_embed: torch.Tensor,
+ ):
+ # a. To device
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
+ # batch = to_device(batch, device=self.device)
+
+ enc, enc_len = self.encoder(**batch)
+ mask = self.make_pad_mask(enc_len)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+ pre_token_length = pre_token_length.floor().type(torch.int32)
+
+ # bias_embed = bias_embed. squeeze(0).repeat([enc.shape[0], 1, 1])
+
+ decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ # sample_ids = decoder_out.argmax(dim=-1)
+ return decoder_out, pre_token_length
+
+ def get_dummy_inputs(self):
+ speech = torch.randn(2, 30, self.feats_dim)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ bias_embed = torch.randn(2, 1, 512)
+ return (speech, speech_lengths, bias_embed)
+
+ def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
+ import numpy as np
+ fbank = np.loadtxt(txt_file)
+ fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
+ speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
+ speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
+ return (speech, speech_lengths)
+
+ def get_input_names(self):
+ return ['speech', 'speech_lengths', 'bias_embed']
+
+ def get_output_names(self):
+ return ['logits', 'token_num']
+
+ def get_dynamic_axes(self):
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'bias_embed': {
+ 0: 'batch_size',
+ 1: 'num_hotwords'
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ }
+
+
+class ContextualParaformer_embedder(nn.Module):
+ def __init__(self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='model',
+ **kwargs,):
+ super().__init__()
+ self.embedding = model.bias_embed
+ model.bias_encoder.batch_first = False
+ self.bias_encoder = model.bias_encoder
+ # self.bias_encoder.batch_first = False
+ self.feats_dim = feats_dim
+ self.model_name = "{}_eb".format(model_name)
+
+ def forward(self, hotword):
+ hotword = self.embedding(hotword).transpose(0, 1) # batch second
+ hw_embed, (_, _) = self.bias_encoder(hotword)
+ return hw_embed
+
+ def get_dummy_inputs(self):
+ hotword = torch.tensor([
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=torch.int32)
+ # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
+ return (hotword)
+
+ def get_input_names(self):
+ return ['hotword']
+
+ def get_output_names(self):
+ return ['hw_embed']
+
+ def get_dynamic_axes(self):
+ return {
+ 'hotword': {
+ 0: 'num_hotwords',
+ },
+ 'hw_embed': {
+ 0: 'num_hotwords',
+ },
+ }
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
new file mode 100644
index 0000000..984c0d6
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -0,0 +1,11 @@
+from funasr_onnx import ContextualParaformer
+from pathlib import Path
+
+model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+model = ContextualParaformer(model_dir, batch_size=1)
+
+wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
+hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
+
+result = model(wav_path, hotwords)
+print(result)
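+# 'result' is a list with one dict per input wav, e.g. [{'preds': ...}], where
+# 'preds' holds the post-processed recognition result
+# (see ContextualParaformer.__call__ in funasr_onnx/paraformer_bin.py)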
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 7d8d662..c03d7e5 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
-from .paraformer_bin import Paraformer
+from .paraformer_bin import Paraformer, ContextualParaformer
from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index f3e0f3d..5f866b8 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -7,6 +7,7 @@
from typing import List, Union, Tuple
import copy
+import torch
import librosa
import numpy as np
@@ -16,6 +17,7 @@
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
+from .utils.utils import pad_list, make_pad_mask
logging = get_logger()
@@ -210,3 +212,145 @@
# texts = sentence_postprocess(token)
return token
+
+class ContextualParaformer(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+ def __init__(self, model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ plot_timestamp_to: str = "",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4,
+ cache_dir: str = None
+ ):
+
+ if not Path(model_dir).exists():
+ from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+            except Exception as e:
+                raise ValueError(
+                    "model_dir must be a model name from modelscope or a local path "
+                    "downloaded from modelscope, but is {}".format(model_dir)) from e
+
+ model_bb_file = os.path.join(model_dir, 'model_bb.onnx')
+ model_eb_file = os.path.join(model_dir, 'model_eb.onnx')
+
+ token_list_file = os.path.join(model_dir, 'tokens.txt')
+ self.vocab = {}
+        with open(Path(token_list_file), 'r', encoding='utf-8') as fin:
+ for i, line in enumerate(fin.readlines()):
+ self.vocab[line.strip()] = i
+
+ #if quantize:
+ # model_file = os.path.join(model_dir, 'model_quant.onnx')
+ #if not os.path.exists(model_file):
+ # logging.error(".onnx model not exist, please export first.")
+
+ config_file = os.path.join(model_dir, 'config.yaml')
+ cmvn_file = os.path.join(model_dir, 'am.mvn')
+ config = read_yaml(config_file)
+
+ self.converter = TokenIDConverter(config['token_list'])
+ self.tokenizer = CharTokenizer()
+ self.frontend = WavFrontend(
+ cmvn_file=cmvn_file,
+ **config['frontend_conf']
+ )
+ self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
+ self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)
+
+ self.batch_size = batch_size
+ self.plot_timestamp_to = plot_timestamp_to
+ if "predictor_bias" in config['model_conf'].keys():
+ self.pred_bias = config['model_conf']['predictor_bias']
+ else:
+ self.pred_bias = 0
+
+ def __call__(self,
+ wav_content: Union[str, np.ndarray, List[str]],
+ hotwords: str,
+ **kwargs) -> List:
+ # make hotword list
+ hotwords, hotwords_length = self.proc_hotword(hotwords)
+        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
+        # for each hotword, keep the bias-encoder output at its last real character
+        bias_embed = bias_embed.transpose(1, 0, 2)
+        _ind = np.arange(0, len(hotwords)).tolist()
+        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
+ waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+ waveform_nums = len(waveform_list)
+ asr_res = []
+ for beg_idx in range(0, waveform_nums, self.batch_size):
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
+ feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+            # tile the (num_hotwords, dim) bias embedding over the current batch;
+            # keep the original 2-D array so later batches are not expanded again
+            batch_bias_embed = np.repeat(np.expand_dims(bias_embed, axis=0), feats.shape[0], axis=0)
+            try:
+                outputs = self.bb_infer(feats, feats_len, batch_bias_embed)
+ am_scores, valid_token_lens = outputs[0], outputs[1]
+ except ONNXRuntimeError:
+ #logging.warning(traceback.format_exc())
+ logging.warning("input wav is silence or noise")
+ preds = ['']
+ else:
+ preds = self.decode(am_scores, valid_token_lens)
+ for pred in preds:
+ pred = sentence_postprocess(pred)
+ asr_res.append({'preds': pred})
+ return asr_res
+
+    def proc_hotword(self, hotwords):
+        # hotwords is a single space-separated string, e.g. "魔搭 阿里巴巴"
+        hotwords = hotwords.split(" ")
+        # index of the last character of each hotword, used to gather the
+        # bias-encoder output at the final valid step
+        hotwords_length = [len(i) - 1 for i in hotwords]
+        hotwords_length.append(0)
+        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
+
+        def word_map(word):
+            return torch.tensor([self.vocab[i] for i in word])
+
+        # map each hotword to character ids and append a trailing '<s>' (id 1) entry
+        hotword_int = [word_map(i) for i in hotwords]
+        hotword_int.append(torch.tensor([1]))
+        # pad to length 10, matching the fixed hotword length of the exported embedder
+        # (only num_hotwords is a dynamic axis in model_eb.onnx)
+        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
+        return hotwords, hotwords_length
+
+ def bb_infer(self, feats: np.ndarray,
+ feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
+ outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
+ return outputs
+
+ def eb_infer(self, hotwords, hotwords_length):
+ outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
+ return outputs
+
+    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[str]:
+ return [self.decode_one(am_score, token_num)
+ for am_score, token_num in zip(am_scores, token_nums)]
+
+ def decode_one(self,
+ am_score: np.ndarray,
+ valid_token_num: int) -> List[str]:
+ yseq = am_score.argmax(axis=-1)
+ score = am_score.max(axis=-1)
+ score = np.sum(score, axis=-1)
+
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ # asr_model.sos:1 asr_model.eos:2
+ yseq = np.array([1] + yseq.tolist() + [2])
+ hyp = Hypothesis(yseq=yseq, score=score)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x not in (0, 2), token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+ token = token[:valid_token_num-self.pred_bias]
+ # texts = sentence_postprocess(token)
+ return token
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index f1fc9a0..cf74200 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
+import torch
import numpy as np
import yaml
try:
@@ -22,6 +23,52 @@
logger_initialized = {}
+def pad_list(xs, pad_value, max_len=None):
+    """Pad a list of variable-length tensors into one (batch, max_len, ...) tensor."""
+ n_batch = len(xs)
+ if max_len is None:
+ max_len = max(x.size(0) for x in xs)
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+ for i in range(n_batch):
+ pad[i, : xs[i].size(0)] = xs[i]
+
+ return pad
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+    """Return a boolean mask that is True at padded (beyond-length) positions."""
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if not isinstance(lengths, list):
+ lengths = lengths.tolist()
+ bs = int(len(lengths))
+ if maxlen is None:
+ if xs is None:
+ maxlen = int(max(lengths))
+ else:
+ maxlen = xs.size(length_dim)
+ else:
+ assert xs is None
+ assert maxlen >= int(max(lengths))
+
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ if xs is not None:
+ assert xs.size(0) == bs, (xs.size(0), bs)
+
+ if length_dim < 0:
+ length_dim = xs.dim() + length_dim
+ # ind = (:, None, ..., None, :, , None, ..., None)
+ ind = tuple(
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+ )
+ mask = mask[ind].expand_as(xs).to(xs.device)
+ return mask
+
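+# Example (illustrative): make_pad_mask(torch.tensor([2, 3])) gives a (2, 3) bool
+# tensor [[False, False, True], [False, False, False]] -- True marks padding.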
+
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
--
Gitblit v1.9.1