shixian.shi
2023-03-03 a1447d12cc7b18a260a4d1cd8ff572f8e78eaba4
funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
@@ -1,29 +1,32 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from cgitb import text
import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
import copy
import librosa
import numpy as np
from utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                    OrtInferSession, TokenIDConverter, get_logger,
                    read_yaml)
from utils.postprocess_utils import sentence_postprocess
from utils.frontend import WavFrontend
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
logging = get_logger()
class Paraformer():
    def __init__(self, model_dir: Union[str, Path]=None,
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int]="-1",
                 device_id: Union[str, int] = "-1",
                 plot_timestamp: bool = False,
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
@@ -40,28 +43,61 @@
        )
        self.ort_infer = OrtInferSession(model_file, device_id)
        self.batch_size = batch_size
        self.plot = plot_timestamp
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            res = {}
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            try:
                am_scores, valid_token_lens = self.infer(feats, feats_len)
                outputs = self.infer(feats, feats_len)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                if len(outputs) == 4:
                    # for BiCifParaformer Inference
                    us_alphas, us_cif_peak = outputs[2], outputs[3]
                else:
                    us_alphas, us_cif_peak = None, None
            except ONNXRuntimeError:
                #logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
            asr_res.extend(preds)
                preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
                res['preds'] = preds
                if us_cif_peak is not None:
                    timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak, copy.copy(raw_token))
                    res['timestamp'] = timestamp
                    if self.plot:
                        self.plot_wave_timestamp(waveform_list[0], timestamp_total)
            asr_res.append(res)
        return asr_res
    def plot_wave_timestamp(self, wav, text_timestamp):
        # TODO: Plot the wav and timestamp results with matplotlib
        import matplotlib
        matplotlib.use('Agg')
        matplotlib.rc("font", family='Alibaba PuHuiTi')  # set it to a font that your system supports
        import matplotlib.pyplot as plt
        fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
        ax2 = ax1.twinx()
        ax2.set_ylim([0, 2.0])
        # plot waveform
        ax1.set_ylim([-0.3, 0.3])
        time = np.arange(wav.shape[0]) / 16000
        ax1.plot(time, wav/wav.max()*0.3, color='gray', alpha=0.4)
        # plot lines and text
        for (char, start, end) in text_timestamp:
            ax1.vlines(start, -0.3, 0.3, ls='--')
            ax1.vlines(end, -0.3, 0.3, ls='--')
            x_adj = 0.045 if char != '<sil>' else 0.12
            ax1.text((start + end) * 0.5 - x_adj, 0, char)
        # plt.legend()
        plotname = "funasr/runtime/python/onnxruntime/debug.png"
        plt.savefig(plotname, bbox_inches='tight')
    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
@@ -107,8 +143,8 @@
    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        am_scores, token_nums = self.ort_infer([feats, feats_len])
        return am_scores, token_nums
        outputs = self.ort_infer([feats, feats_len])
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        return [self.decode_one(am_score, token_num)
@@ -135,10 +171,9 @@
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-1]
        # token = token[:valid_token_num-1]
        texts = sentence_postprocess(token)
        text = texts[0]
        # text = self.tokenizer.tokens2text(token)
        return text
        return text, token