游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -7,17 +7,17 @@
from typing import List, Union, Tuple
import copy
import torch
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.postprocess_utils import (sentence_postprocess,
                                      sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
from .utils.utils import pad_list
logging = get_logger()
@@ -36,7 +36,6 @@
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
@@ -87,6 +86,10 @@
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config['lang']
        else:
            self.language = None
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -112,7 +115,10 @@
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
@@ -242,6 +248,13 @@
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except:
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
                      "\npip3 install -U modelscope\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
@@ -295,7 +308,7 @@
        # index from bias_embed
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
@@ -322,7 +335,7 @@
        hotwords = hotwords.split(" ")
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
        hotwords_length = np.array(hotwords_length)
        # hotwords.append('<s>')
        def word_map(word):
            hotwords = []
@@ -332,11 +345,12 @@
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    hotwords.append(self.vocab[c])
            return torch.tensor(hotwords)
            return np.array(hotwords)
        hotword_int = [word_map(i) for i in hotwords]
        # import pdb; pdb.set_trace()
        hotword_int.append(torch.tensor([1]))
        hotword_int.append(np.array([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        # import pdb; pdb.set_trace()
        return hotwords, hotwords_length
    def bb_infer(self, feats: np.ndarray,
@@ -345,7 +359,7 @@
        return outputs
    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: