zhifu gao
2023-03-02 bfb7c22728329dc4f6b7f528d1bb3464624cc5e7
funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
@@ -2,28 +2,29 @@
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
import copy
import librosa
import numpy as np
from utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                    OrtInferSession, TokenIDConverter, get_logger,
                    read_yaml)
from utils.postprocess_utils import sentence_postprocess
from utils.frontend import WavFrontend
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
logging = get_logger()
class Paraformer():
    def __init__(self, model_dir: Union[str, Path]=None,
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int]="-1",
                 device_id: Union[str, int] = "-1",
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
@@ -135,10 +136,67 @@
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-1]
        # token = token[:valid_token_num-1]
        texts = sentence_postprocess(token)
        text = texts[0]
        # text = self.tokenizer.tokens2text(token)
        return text
class BiCifParaformer(Paraformer):
    """Paraformer variant that also returns a timestamp for each result.

    The ONNX session is expected to produce four outputs; the two extra
    ones (``us_alphas`` and ``us_cif_peak``) are forwarded to
    ``time_stamp_lfr6_pl`` to derive the timestamp.
    """

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Run the ONNX session on one batch of features.

        Args:
            feats: Feature batch as produced by ``extract_feat``.
            feats_len: Lengths matching ``feats``, as produced by
                ``extract_feat``.

        Returns:
            ``(am_scores, token_nums, us_alphas, us_cif_peak)`` exactly as
            unpacked from ``self.ort_infer``.
        """
        # Unpacking (rather than returning the raw result) fails loudly if the
        # session does not yield exactly four outputs.
        am_scores, token_nums, us_alphas, us_cif_peak = self.ort_infer([feats, feats_len])
        return am_scores, token_nums, us_alphas, us_cif_peak

    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        """Recognise speech and return text plus timestamp per batch.

        Args:
            wav_content: Input handed to ``self.load_data`` (a path, a
                waveform array, or a list of paths per the annotation).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A list with one ``{'text': ..., 'timestamp': ...}`` dict per
            processed batch. When inference raises ``ONNXRuntimeError`` the
            entry is ``{'text': '', 'timestamp': []}``.
        """
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)

        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])

            # Inference runs exactly once and only inside the try, so a
            # runtime failure reaches the handler instead of propagating.
            try:
                am_scores, valid_token_lens, us_alphas, us_cif_peak = self.infer(feats, feats_len)
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                # Emit an explicit empty result so a failed batch never
                # raises NameError or reuses the previous batch's values.
                # NOTE(review): an empty list is assumed to be an acceptable
                # "no timestamp" value for callers - confirm.
                text, timestamp = '', []
            else:
                token = self.decode(am_scores, valid_token_lens)
                # NOTE(review): only the first decoded sequence of the batch
                # is post-processed; with batch_size > 1 the remaining items
                # are dropped - confirm whether that is intended.
                # A copy of the token list is passed, presumably because the
                # timestamp helper may modify its argument - verify.
                timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(token[0]), log=False)
                texts = sentence_postprocess(token[0], timestamp)
                text = texts[0]

            asr_res.append({'text': text, 'timestamp': timestamp})
        return asr_res

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        """Greedy-decode one utterance's scores into tokens.

        Args:
            am_score: Score array for one utterance; the arg-max over the
                last axis is taken as the token id at each position.
            valid_token_num: Unused; the result is not truncated to this
                length.

        Returns:
            The tokens from ``self.converter.ids2tokens`` after ids 0 and 2
            have been removed.
        """
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # Strip the sos/eos padding added above, then drop the blank symbol
        # id (assumed to be 0) and any eos id (2) inside the sequence.
        token_int = [idx for idx in hyp.yseq[1:-1].tolist() if idx not in (0, 2)]

        # Change integer-ids to tokens
        return self.converter.ids2tokens(token_int)