zhifu gao
2024-03-11 a7d7a0f3a2e7cd44a337ced34e3536b12ccb534e
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -5,19 +5,20 @@
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import torch
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.postprocess_utils import (sentence_postprocess,
                                      sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
from .utils.utils import pad_list
logging = get_logger()
@@ -36,7 +37,6 @@
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
@@ -56,25 +56,24 @@
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize)
            
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
@@ -87,6 +86,10 @@
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config['lang']
        else:
            self.language = None
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -112,7 +115,10 @@
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
@@ -242,6 +248,13 @@
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except:
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
                      "\npip3 install -U modelscope\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
@@ -295,7 +308,7 @@
        # index from bias_embed
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
@@ -322,7 +335,7 @@
        hotwords = hotwords.split(" ")
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
        hotwords_length = np.array(hotwords_length)
        # hotwords.append('<s>')
        def word_map(word):
            hotwords = []
@@ -332,11 +345,12 @@
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    hotwords.append(self.vocab[c])
            return torch.tensor(hotwords)
            return np.array(hotwords)
        hotword_int = [word_map(i) for i in hotwords]
        # import pdb; pdb.set_trace()
        hotword_int.append(torch.tensor([1]))
        hotword_int.append(np.array([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        # import pdb; pdb.set_trace()
        return hotwords, hotwords_length
    def bb_infer(self, feats: np.ndarray,
@@ -345,7 +359,7 @@
        return outputs
    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: