VirtuosoQ
2024-04-26 e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -5,19 +5,20 @@
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import torch
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.postprocess_utils import (sentence_postprocess,
                                      sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
from .utils.utils import pad_list
logging = get_logger()
@@ -34,7 +35,8 @@
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 cache_dir: str = None,
                 **kwargs
                 ):
        if not Path(model_dir).exists():
            try:
@@ -55,25 +57,24 @@
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
            
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
@@ -86,6 +87,10 @@
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config['lang']
        else:
            self.language = None
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -111,7 +116,10 @@
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
@@ -236,7 +244,8 @@
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 cache_dir: str = None,
                 **kwargs
                 ):
        if not Path(model_dir).exists():
@@ -259,22 +268,32 @@
            model_bb_file = os.path.join(model_dir, 'model.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb.onnx')
        token_list_file = os.path.join(model_dir, 'tokens.txt')
        self.vocab = {}
        with open(Path(token_list_file), 'r') as fin:
            for i, line in enumerate(fin.readlines()):
                self.vocab[line.strip()] = i
        if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
            print(".onnx is not exist, begin to export onnx")
            try:
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
        #if quantize:
        #    model_file = os.path.join(model_dir, 'model_quant.onnx')
        #if not os.path.exists(model_file):
        #    logging.error(".onnx model not exist, please export first.")
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
            
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        # revert token_list into vocab dict
        self.vocab = {}
        for i, token in enumerate(token_list):
                self.vocab[token] = i
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
@@ -301,7 +320,7 @@
        # index from bias_embed
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
@@ -328,7 +347,7 @@
        hotwords = hotwords.split(" ")
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
        hotwords_length = np.array(hotwords_length)
        # hotwords.append('<s>')
        def word_map(word):
            hotwords = []
@@ -338,11 +357,12 @@
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    hotwords.append(self.vocab[c])
            return torch.tensor(hotwords)
            return np.array(hotwords)
        hotword_int = [word_map(i) for i in hotwords]
        # import pdb; pdb.set_trace()
        hotword_int.append(torch.tensor([1]))
        hotword_int.append(np.array([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        # import pdb; pdb.set_trace()
        return hotwords, hotwords_length
    def bb_infer(self, feats: np.ndarray,
@@ -351,7 +371,7 @@
        return outputs
    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
@@ -381,4 +401,10 @@
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
        return token
class SeacoParaformer(ContextualParaformer):
    """Paraformer variant that reuses ``ContextualParaformer`` unchanged.

    The ONNX models are called the same way as for
    ``ContextualParaformer``. This subclass therefore adds no behaviour of its
    own and only provides a distinct class name.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Forward every positional and keyword argument verbatim to the parent.
        super(SeacoParaformer, self).__init__(*args, **kwargs)