游雁
2023-02-13 865ae89f0a713f70dda16859638b25e7350275ec
export model
1个文件已修改
6个文件已添加
13 文件已重命名
480 ■■■■ 已修改文件
fbank.py 123 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/wav_frontend.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/.gitignore 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/README.md 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/LICENSE 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/feature.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/ivector.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py 48 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py 240 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py 49 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/requirements.txt 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/resources/models/am.mvn 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/paraformer/resources/models/token_list.pkl 补丁 | 查看 | 原始文档 | blame | 历史
fbank.py
New file
@@ -0,0 +1,123 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
import kaldi_native_fbank as knf
class WavFrontend(AbsFrontend):
    """Conventional Kaldi-style fbank frontend for ASR.

    Each utterance in the batch is turned into log-mel filterbank features
    with ``torchaudio.compliance.kaldi.fbank``; the per-utterance feature
    matrices are then zero-padded into one batch tensor.

    Note: ``cmvn_file``, ``lfr_m``/``lfr_n`` and ``filter_length_min``/
    ``filter_length_max`` are stored on the instance but are not applied by
    ``forward``.
    """
    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
    ):
        """Store the feature-extraction configuration.

        Args:
            cmvn_file: Path to a CMVN statistics file (stored, unused here).
            fs: Sample frequency in Hz passed to ``kaldi.fbank``.
            window: Kaldi window type name.
            n_mels: Number of mel bins.
            frame_length: Frame length in milliseconds.
            frame_shift: Frame shift in milliseconds.
            filter_length_min: Stored, unused here.
            filter_length_max: Stored, unused here.
            lfr_m: Low-frame-rate stacking factor (stored, unused here).
            lfr_n: Low-frame-rate skip factor (stored, unused here).
            dither: Dithering constant passed to ``kaldi.fbank``.
            snip_edges: Kaldi ``snip_edges`` option passed to ``kaldi.fbank``.
            upsacle_samples: If True, samples are multiplied by ``2**15``
                before feature extraction. (The misspelled name is kept for
                compatibility with existing configs.)
        """
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples

    def output_size(self) -> int:
        """Return the feature dimension after low-frame-rate stacking."""
        return self.n_mels * self.lfr_m

    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute zero-padded fbank features for a batch of waveforms.

        Args:
            input: Batch of waveforms indexed as ``input[i][:length]``
                (presumably ``(batch, samples)`` — confirm with callers).
            input_lengths: Number of valid samples for each batch item.

        Returns:
            ``(feats_pad, feats_lens)``: features padded with 0.0 along the
            frame axis (batch first) and the frame count of each item.
        """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                # Rescale to the 16-bit PCM range; presumably the input is
                # float audio in [-1, 1] — confirm with the data loader.
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs,
                              snip_edges=self.snip_edges)
            feats.append(mat)
            feats_lens.append(mat.size(0))
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
import kaldi_native_fbank as knf
def fbank_knf(waveform,
              sample_rate: int = 16000,
              num_bins: int = 80,
              frame_length_ms: float = 25.0,
              frame_shift_ms: float = 10.0,
              window_type: str = "hamming",
              dither: float = 0.0,
              snip_edges: bool = True):
    """Compute fbank features with ``kaldi_native_fbank``.

    All keyword arguments default to the values previously hard-coded here,
    so ``fbank_knf(waveform)`` behaves as before.

    Args:
        waveform: 1-D sequence of samples; it is converted with
            ``np.asarray`` and multiplied by ``2**15`` before extraction.
        sample_rate: Sampling rate in Hz.
        num_bins: Number of mel bins (the output's second dimension).
        frame_length_ms: Frame length in milliseconds.
        frame_shift_ms: Frame shift in milliseconds.
        window_type: Kaldi window type name.
        dither: Dithering constant (0.0 disables dithering).
        snip_edges: Kaldi ``snip_edges`` option.

    Returns:
        ``np.ndarray`` of shape ``(num_frames, num_bins)``.
    """
    opts = knf.FbankOptions()
    opts.frame_opts.samp_freq = sample_rate
    opts.frame_opts.dither = dither
    opts.frame_opts.window_type = window_type
    opts.frame_opts.frame_shift_ms = frame_shift_ms
    opts.frame_opts.frame_length_ms = frame_length_ms
    opts.mel_opts.num_bins = num_bins
    opts.energy_floor = 1
    opts.frame_opts.snip_edges = snip_edges
    opts.mel_opts.debug_mel = False
    fbank = knf.OnlineFbank(opts)
    # np.asarray guards against list input, where `* (1 << 15)` would
    # repeat the list instead of scaling it. The 2**15 scaling is the same
    # one WavFrontend.forward applies by default.
    waveform = np.asarray(waveform) * (1 << 15)
    fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist())
    frames = fbank.num_frames_ready
    mat = np.empty([frames, opts.mel_opts.num_bins])
    for i in range(frames):
        mat[i, :] = fbank.get_frame(i)
    return mat
if __name__ == '__main__':
    # Debug script: compare kaldi_native_fbank features (fbank_knf) with the
    # torchaudio-based WavFrontend on one wav file and report the mismatch.
    # Usage: python fbank.py [path/to/file.wav]
    import sys
    import librosa
    default_path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    path = sys.argv[1] if len(sys.argv) > 1 else default_path
    # Both extractors are configured for 16 kHz, so resample to that rate
    # instead of keeping the file's native rate.
    waveform, fs = librosa.load(path, sr=16000)
    fbank = fbank_knf(waveform)
    frontend = WavFrontend(dither=0.0)
    waveform_tensor = torch.from_numpy(waveform)[None, :]
    fbank_torch, _ = frontend.forward(waveform_tensor, [waveform_tensor.size(1)])
    fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
    if fbank.shape[0] != fbank_torch.shape[0]:
        print(f"frame count differs: knf={fbank.shape[0]}, "
              f"torch={fbank_torch.shape[0]}")
    # Compare only the overlapping frames so a length mismatch cannot crash
    # the subtraction.
    num_frames = min(fbank.shape[0], fbank_torch.shape[0])
    diff = fbank[:num_frames] - fbank_torch[:num_frames]
    # ndarray has no .abs() method; use np.abs.
    abs_diff = np.abs(diff)
    diff_max = abs_diff.max() if abs_diff.size else 0.0
    diff_sum = abs_diff.sum()
    print(f"max |diff| = {diff_max}, sum |diff| = {diff_sum}")
funasr/models/frontend/wav_frontend.py
@@ -171,10 +171,7 @@
                              window_type=self.window,
                              sample_frequency=self.fs)
            # if self.lfr_m != 1 or self.lfr_n != 1:
            #     mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            # if self.cmvn_file is not None:
            #     mat = apply_cmvn(mat, self.cmvn_file)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
funasr/runtime/__init__.py
funasr/runtime/python/__init__.py
funasr/runtime/python/onnxruntime/__init__.py
funasr/runtime/python/onnxruntime/paraformer/.gitignore
funasr/runtime/python/onnxruntime/paraformer/README.md
File was renamed from funasr/runtime/python/onnxruntime/README.md
@@ -29,12 +29,6 @@
        │   └── utils.py
        ├── README.md
        ├── requirements.txt
        ├── resources
        │   ├── config.yaml
        │   └── models
        │       ├── am.mvn
        │       ├── model.onnx  # Put it here.
        │       └── token_list.pkl
        ├── test_onnx.py
        ├── tests
        │   ├── __pycache__
@@ -48,15 +42,15 @@
   - Output: `List[str]`: recognition result.
   - Example:
        ```python
        from rapid_paraformer import RapidParaformer
        from paraformer_onnx import Paraformer
        config_path = 'resources/config.yaml'
        paraformer = RapidParaformer(config_path)
        model = Paraformer(config_path)
        wav_path = ['test_wavs/0478_00017.wav']
        wav_path = ['example/asr_example.wav']
        result = paraformer(wav_path)
        result = model(wav_path)
        print(result)
        ```
funasr/runtime/python/onnxruntime/paraformer/__init__.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/LICENSE
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/__init__.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/feature.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/ivector.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
File was renamed from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
@@ -11,25 +12,33 @@
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                    OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
                    read_yaml)
from .postprocess_utils import sentence_postprocess
logging = get_logger()
class RapidParaformer():
    def __init__(self, config_path: Union[str, Path]) -> None:
        if not Path(config_path).exists():
            raise FileNotFoundError(f'{config_path} does not exist.')
class Paraformer():
    def __init__(self, model_dir: Union[str, Path]=None,
                 batch_size: int = 1,
                 device_id: Union[str, int]="-1",
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
        config = read_yaml(config_path)
        model_file = os.path.join(model_dir, 'model.onnx')
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        self.converter = TokenIDConverter(**config['TokenIDConverter'])
        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=config['WavFrontend']['cmvn_file'],
            **config['WavFrontend']['frontend_conf']
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(config['Model'])
        self.batch_size = config['Model']['batch_size']
        self.ort_infer = OrtInferSession(model_file, device_id)
        self.batch_size = batch_size
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
        waveform_list = self.load_data(wav_content)
@@ -124,16 +133,19 @@
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        text = self.tokenizer.tokens2text(token)
        token = token[:valid_token_num-1]
        texts = sentence_postprocess(token)
        text = texts[0]
        # text = self.tokenizer.tokens2text(token)
        return text[:valid_token_num-1]
if __name__ == '__main__':
    project_dir = Path(__file__).resolve().parent.parent
    cfg_path = project_dir / 'resources' / 'config.yaml'
    paraformer = RapidParaformer(cfg_path)
    model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
    model = Paraformer(model_dir)
    wav_file = '0478_00017.wav'
    for i in range(1000):
        result = paraformer(wav_file)
        print(result)
    wav_file = os.path.join(model_dir, 'example/asr_example.wav')
    result = model(wav_file)
    print(result)
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py
New file
@@ -0,0 +1,240 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
    """Check whether ``ch`` falls in the CJK ideograph or ASCII digit range.

    The bounds are U+4E00..U+9FFF and '0'..'9'. The test is a plain string
    comparison, so for multi-character input it is lexicographic (e.g.
    ``'90'`` is not in the digit range because ``'90' > '9'``).
    """
    in_cjk_range = '\u4e00' <= ch <= '\u9fff'
    in_digit_range = '\u0030' <= ch <= '\u0039'
    return in_cjk_range or in_digit_range
def isAllChinese(word: Union[List[Any], str]):
    """Return True if every element of ``word`` passes ``isChinese``.

    Each element has spaces, ``</s>`` and ``<s>`` removed (in that order)
    before the check. A ``str`` argument is checked character by character.
    An empty ``word`` yields False.
    """
    cleaned = [
        piece.replace(' ', '').replace('</s>', '').replace('<s>', '')
        for piece in word
    ]
    if not cleaned:
        return False
    return all(isChinese(piece) for piece in cleaned)
def isAllAlpha(word: Union[List[Any], str]):
    """Return True if every element of ``word`` is alphabetic and non-Chinese.

    Each element has spaces, ``</s>`` and ``<s>`` removed (in that order)
    before the check. A lone apostrophe ``"'"`` is accepted. Elements for
    which ``str.isalpha()`` is True but ``isChinese`` also holds are
    rejected. An empty ``word`` yields False.
    """
    cleaned = [
        piece.replace(' ', '').replace('</s>', '').replace('<s>', '')
        for piece in word
    ]
    if not cleaned:
        return False
    for piece in cleaned:
        if not piece.isalpha():
            if piece != "'":
                return False
        elif isChinese(piece):
            return False
    return True
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
    """Collapse spelled-out abbreviations such as ``a b c`` into ``A``, ``B``, ``C``.

    An abbreviation is a run of at least two single ASCII letters, each pair
    separated by exactly one ``' '`` element (e.g. ``['a', ' ', 'b']``).
    Its letters are upper-cased and emitted as consecutive elements, with
    the separating spaces dropped. Every other element (including spaces
    outside abbreviations) is passed through unchanged.

    Args:
        words: Token list as built by ``sentence_postprocess``; ``' '``
            elements act as word separators.
        time_stamp: Optional list of ``[begin, end]`` pairs. It is indexed
            via ``ts_nums`` below, i.e. one entry per non-space element of
            ``words``.

    Returns:
        ``word_lists`` when ``time_stamp`` is None. Otherwise the tuple
        ``(word_lists, ts_lists)``, where an abbreviation contributes a single
        ``[begin, end]`` entry spanning its first to last letter. (The
        ``List[Any]`` annotation does not reflect the tuple case.)

    NOTE(review): in the timestamp case an abbreviation adds several entries
    to ``word_lists`` but only one to ``ts_lists``, so the two lists can
    differ in length — confirm callers expect this.
    """
    words_size = len(words)
    word_lists = []
    # Start / last-letter indices (into ``words``) of each abbreviation found.
    abbr_begin = []
    abbr_end = []
    # Highest index already consumed by an abbreviation scan; rebinding ``num``
    # inside the for-loop does not skip iterations, so this does instead.
    last_num = -1
    ts_lists = []
    # ts_nums[k] is the time_stamp index to use for words[k].
    ts_nums = []
    ts_index = 0
    # Pass 1: find abbreviation spans. bytes.isalpha() is ASCII-only, so only
    # single ASCII letters qualify.
    for num in range(words_size):
        if num <= last_num:
            continue
        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
            if num + 1 < words_size and words[
                    num + 1] == ' ' and num + 2 < words_size and len(
                        words[num +
                              2]) == 1 and words[num +
                                                 2].encode('utf-8').isalpha():
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
                abbr_end.append(num)
                # to find the end of abbr: keep extending while the pattern
                # ' ' + single ASCII letter continues
                while True:
                    num += 1
                    if num < words_size and words[num] == ' ':
                        num += 1
                        if num < words_size and len(
                                words[num]) == 1 and words[num].encode(
                                    'utf-8').isalpha():
                            abbr_end.pop()
                            abbr_end.append(num)
                            last_num = num
                        else:
                            break
                    else:
                        break
    # Pass 2: every non-space element consumes one time_stamp entry; a ' '
    # shares the index of the next non-space element.
    for num in range(words_size):
        if words[num] == ' ':
            ts_nums.append(ts_index)
        else:
            ts_nums.append(ts_index)
            ts_index += 1
    last_num = -1
    # Pass 3: build the output, upper-casing abbreviation letters and
    # dropping the spaces between them.
    for num in range(words_size):
        if num <= last_num:
            continue
        if num in abbr_begin:
            if time_stamp is not None:
                begin = time_stamp[ts_nums[num]][0]
            word_lists.append(words[num].upper())
            num += 1
            while num < words_size:
                if num in abbr_end:
                    word_lists.append(words[num].upper())
                    last_num = num
                    break
                else:
                    # inner separators are ' ' (not alpha) and are skipped
                    if words[num].encode('utf-8').isalpha():
                        word_lists.append(words[num].upper())
                num += 1
            if time_stamp is not None:
                # one timestamp span for the whole abbreviation
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
        else:
            word_lists.append(words[num])
            if time_stamp is not None and words[num] != ' ':
                begin = time_stamp[ts_nums[num]][0]
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
                # has no effect: begin is reassigned before its next use
                begin = end
    if time_stamp is not None:
        return word_lists, ts_lists
    else:
        return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
    """Turn a list of recognized tokens into a display sentence.

    The special tokens ``<s>``, ``</s>`` and ``<unk>`` are dropped first.
    One of three strategies is then applied to the remaining tokens:

    * Every token (after stripping spaces) passes ``isChinese``: each token,
      with spaces removed, is kept as its own word.
    * Every token passes ``isAllAlpha``: tokens containing ``@@`` are word
      pieces. They are glued (with ``@@`` removed) to the following tokens
      until a token without ``@@`` completes the word. A ``' '`` separator is
      appended after each completed word.
    * Otherwise (mixed): tokens are classified one by one. Chinese tokens
      are checked character by character via ``isAllChinese(str)``, and the
      pending ``' '`` after an English word is removed before a Chinese
      token.

    Spelled-out abbreviations are finally merged by ``abbr_dispose``.

    Args:
        words: Tokens as ``str``; other tokens are decoded as UTF-8
            (presumably ``bytes`` — confirm with callers).
        time_stamp: Optional ``[begin, end]`` pairs. They are indexed by each
            token's position after special tokens are removed (this follows
            from the ``time_stamp[i]`` indexing below; confirm with the
            caller that the lists are aligned that way).

    Returns:
        Without ``time_stamp``: ``(sentence, real_word_lists)``, where
        ``sentence`` is the words joined together with their ``' '``
        separators, then stripped.
        With ``time_stamp``: ``(sentence, ts_lists, real_word_lists)``, where
        ``sentence`` joins the non-space words with single spaces.

    Raises:
        ValueError: in the mixed case, for a token that is not Chinese, not a
            ``@@`` word piece and not alphabetic.

    NOTE(review): if the last token is a ``@@`` word piece, the text gathered
    in ``word_item`` is never appended to the output — confirm whether it
    should be flushed.
    """
    middle_lists = []
    word_lists = []
    word_item = ''
    ts_lists = []
    # wash words lists: decode non-str tokens and drop special tokens
    for i in words:
        word = ''
        if isinstance(i, str):
            word = i
        else:
            word = i.decode('utf-8')
        if word in ['<s>', '</s>', '<unk>']:
            continue
        else:
            middle_lists.append(word)
    # all chinese characters
    if isAllChinese(middle_lists):
        for i, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(' ', ''))
        if time_stamp is not None:
            ts_lists = time_stamp
    # all alpha characters
    elif isAllAlpha(middle_lists):
        # ts_flag is True when a new word starts, i.e. its begin time must be
        # taken from the current token.
        ts_flag = True
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if '@@' in ch:
                # word piece: accumulate and keep the word's begin time
                word = ch.replace('@@', '')
                word_item += word
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            else:
                # final piece: emit the completed word plus a separator
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
    # mix characters
    else:
        # alpha_blank: the last element of word_lists is the ' ' appended
        # after an English word.
        alpha_blank = False
        ts_flag = True
        begin = -1
        end = -1
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if isAllChinese(ch):
                # no space between an English word and following Chinese
                if alpha_blank is True:
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = True
                    ts_lists.append([begin, end])
                    begin = end
            elif '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            elif isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
            else:
                raise ValueError('invalid character: {}'.format(ch))
    if time_stamp is not None:
        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ' '.join(real_word_lists).strip()
        return sentence, ts_lists, real_word_lists
    else:
        word_lists = abbr_dispose(word_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ''.join(word_lists).strip()
        return sentence, real_word_lists
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py
File was renamed from funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py
@@ -14,6 +14,7 @@
from typeguard import check_argument_types
from .kaldifeat import compute_fbank_feats
import warnings
root_dir = Path(__file__).resolve().parent
@@ -21,24 +22,25 @@
class TokenIDConverter():
    def __init__(self, token_path: Union[Path, str],
    def __init__(self, token_list: Union[Path, str],
                 unk_symbol: str = "<unk>",):
        check_argument_types()
        self.token_list = self.load_token(token_path)
        self.unk_symbol = unk_symbol
        # self.token_list = self.load_token(token_path)
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
    @staticmethod
    def load_token(file_path: Union[Path, str]) -> List:
        if not Path(file_path).exists():
            raise TokenIDConverterError(f'The {file_path} does not exist.')
        with open(str(file_path), 'rb') as f:
            token_list = pickle.load(f)
        if len(token_list) != len(set(token_list)):
            raise TokenIDConverterError('The Token exists duplicated symbol.')
        return token_list
    # @staticmethod
    # def load_token(file_path: Union[Path, str]) -> List:
    #     if not Path(file_path).exists():
    #         raise TokenIDConverterError(f'The {file_path} does not exist.')
    #
    #     with open(str(file_path), 'rb') as f:
    #         token_list = pickle.load(f)
    #
    #     if len(token_list) != len(set(token_list)):
    #         raise TokenIDConverterError('The Token exists duplicated symbol.')
    #     return token_list
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
@@ -268,31 +270,36 @@
class OrtInferSession():
    def __init__(self, config):
    def __init__(self, model_file, device_id=-1):
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        cuda_ep = 'CUDAExecutionProvider'
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = 'CPUExecutionProvider'
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = []
        if config['use_cuda'] and get_device() == 'GPU' \
        if device_id != -1 and get_device() == 'GPU' \
                and cuda_ep in get_available_providers():
            EP_list = [(cuda_ep, config[cuda_ep])]
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))
        config['model_path'] = config['model_path']
        self._verify_model(config['model_path'])
        self.session = InferenceSession(config['model_path'],
        self._verify_model(model_file)
        self.session = InferenceSession(model_file,
                                        sess_options=sess_opt,
                                        providers=EP_list)
        if config['use_cuda'] and cuda_ep not in self.session.get_providers():
        if device_id != -1 and cuda_ep not in self.session.get_providers():
            warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
                          'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
                          'you can check their relations from the offical web site: '
funasr/runtime/python/onnxruntime/paraformer/requirements.txt
funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml
File was renamed from funasr/runtime/python/onnxruntime/resources/config.yaml
@@ -18,6 +18,7 @@
    lfr_m: 7
    lfr_n: 6
    filter_length_max: -.inf
    dither: 0.0
Model:
  model_path: resources/models/model.onnx
funasr/runtime/python/onnxruntime/paraformer/resources/models/am.mvn
funasr/runtime/python/onnxruntime/paraformer/resources/models/token_list.pkl
Binary files differ