Lizerui9926
2023-11-09 ee1eefff68e25f2e7674616be34518b07d8135c3
Merge pull request #1075 from alibaba-damo-academy/dev_lzr_en

fix paraformer-en model python onnx postprocess
2 files changed
63 ■■■■■ changed files
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 12 ●●●● patch | view | raw | blame | history
runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py 51 ●●●●● patch | view | raw | blame | history
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -14,7 +14,8 @@
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.postprocess_utils import (sentence_postprocess,
                                      sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
@@ -86,6 +87,10 @@
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config['lang']
        else:
            self.language = None
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -111,7 +116,10 @@
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        pred = sentence_postprocess(pred)
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py
@@ -240,3 +240,54 @@
                real_word_lists.append(ch)
        sentence = ''.join(word_lists).strip()
        return sentence, real_word_lists
def sentence_postprocess_sentencepiece(words):
    """Post-process sentencepiece token output into a readable sentence.

    Args:
        words: iterable of sentencepiece pieces, each a ``str`` or
            UTF-8-encoded ``bytes``. Pieces beginning a new word carry the
            sentencepiece word-boundary marker ``'\u2581'`` (LOWER ONE
            EIGHTH BLOCK).

    Returns:
        tuple ``(sentence, real_word_lists)`` where ``sentence`` is the
        detokenized text (words joined with single spaces, original
        casing) and ``real_word_lists`` is the list of words with the
        pronoun "i" and its common contractions capitalized.
    """
    # Model bookkeeping tokens that carry no lexical content.
    special_tokens = {'<s>', '</s>', '<unk>', '<OOV>'}

    # Normalize to str and drop special tokens.
    middle_lists = []
    for token in words:
        piece = token if isinstance(token, str) else token.decode('utf-8')
        if piece not in special_tokens:
            middle_lists.append(piece)

    # Rebuild words: a piece containing '\u2581' starts a new word, so flush
    # the word accumulated so far (except before the very first piece).
    word_lists = []
    word_item = ''
    for i, piece in enumerate(middle_lists):
        if '\u2581' in piece:
            if i != 0:
                word_lists.append(word_item)
                word_lists.append(' ')
            word_item = piece.replace('\u2581', '')
        else:
            word_item += piece
    # BUG FIX: the original tested `word_item is not None`, which is always
    # true for a string, so an empty trailing word_item (empty input, or a
    # bare '\u2581' marker piece) was appended and leaked an empty string
    # into real_word_lists below. Only flush a non-empty word.
    if word_item:
        word_lists.append(word_item)
    # NOTE(review): abbreviation handling was disabled upstream:
    # word_lists = abbr_dispose(word_lists)

    # Capitalize the English pronoun "I" and its common contractions in the
    # word list only; `sentence` keeps the model's original casing.
    capitalization = {"i": "I", "i'm": "I'm", "i've": "I've", "i'll": "I'll"}
    real_word_lists = []
    for word in word_lists:
        if word != ' ':
            real_word_lists.append(capitalization.get(word, word))
    sentence = ''.join(word_lists)
    return sentence, real_word_lists