雾聪
2024-03-14 0cf5dfec2c8313fc2ed2aab8d10bf3dc4b9c283f
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -5,19 +5,20 @@
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import torch
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.postprocess_utils import (sentence_postprocess,
                                      sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
from .utils.utils import pad_list
logging = get_logger()
@@ -36,7 +37,6 @@
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
@@ -56,25 +56,24 @@
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize)
            
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
@@ -87,6 +86,10 @@
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config['lang']
        else:
            self.language = None
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -112,7 +115,10 @@
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
@@ -242,6 +248,13 @@
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except:
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
                      "\npip3 install -U modelscope\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
@@ -295,7 +308,7 @@
        # index from bias_embed
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
@@ -322,7 +335,7 @@
        hotwords = hotwords.split(" ")
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
        hotwords_length = np.array(hotwords_length)
        # hotwords.append('<s>')
        def word_map(word):
            hotwords = []
@@ -332,11 +345,12 @@
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    hotwords.append(self.vocab[c])
            return torch.tensor(hotwords)
            return np.array(hotwords)
        hotword_int = [word_map(i) for i in hotwords]
        # import pdb; pdb.set_trace()
        hotword_int.append(torch.tensor([1]))
        hotword_int.append(np.array([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        # import pdb; pdb.set_trace()
        return hotwords, hotwords_length
    def bb_infer(self, feats: np.ndarray,
@@ -345,7 +359,7 @@
        return outputs
    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: