1个文件已修改
6个文件已添加
13个文件已重命名
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | # Part of the implementation is borrowed from espnet/espnet. |
| | | |
| | | from typing import Tuple |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torchaudio.compliance.kaldi as kaldi |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from typeguard import check_argument_types |
| | | from torch.nn.utils.rnn import pad_sequence |
| | | import kaldi_native_fbank as knf |
| | | |
class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Extracts Kaldi-compatible log-mel filterbank (fbank) features per
    utterance via ``torchaudio.compliance.kaldi.fbank`` and zero-pads
    them into a batch.

    Args:
        cmvn_file: path to CMVN statistics. NOTE(review): stored but
            never applied in this class — confirm whether CMVN should be
            applied in forward().
        fs: sampling frequency in Hz, forwarded to kaldi.fbank.
        window: window type (e.g. 'hamming'), forwarded to kaldi.fbank.
        n_mels: number of mel filterbank bins.
        frame_length: frame length in milliseconds.
        frame_shift: frame shift in milliseconds.
        filter_length_min: stored but unused in this class.
        filter_length_max: stored but unused in this class.
        lfr_m: low-frame-rate stacking factor; only used by output_size()
            here (LFR itself is not applied in forward()).
        lfr_n: low-frame-rate decimation factor; stored but unused here.
        dither: dithering constant forwarded to kaldi.fbank.
        snip_edges: stored but unused in this class.
        upsacle_samples: (sic — typo kept for interface compatibility)
            stored but unused in this class.
    """

    def __init__(
            self,
            cmvn_file: str = None,
            fs: int = 16000,
            window: str = 'hamming',
            n_mels: int = 80,
            frame_length: int = 25,
            frame_shift: int = 10,
            filter_length_min: int = -1,
            filter_length_max: int = -1,
            lfr_m: int = 1,
            lfr_n: int = 1,
            dither: float = 1.0,
            snip_edges: bool = True,
            upsacle_samples: bool = True,
    ):
        # Runtime type check of the arguments against the annotations.
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples

    def output_size(self) -> int:
        # Feature dimension after stacking lfr_m consecutive frames.
        return self.n_mels * self.lfr_m

    def forward(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute fbank features for a padded batch of waveforms.

        Args:
            input: (batch, num_samples) waveforms; assumed float in
                [-1, 1] given the int16 rescale below — TODO confirm
                against callers.
            input_lengths: valid sample count of each utterance.

        Returns:
            feats_pad: (batch, max_frames, n_mels) zero-padded features.
            feats_lens: frame count of each utterance.
        """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            # Drop the padding samples of this utterance.
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            # Scale up to int16 range — presumably to mimic 16-bit PCM
            # input as Kaldi expects; confirm input really is in [-1, 1].
            waveform = waveform * (1 << 15)
            # kaldi.fbank expects a (channel, samples) tensor.
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs)

            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        # Zero-pad all utterances to the longest one in the batch.
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
| | | |
| | | import kaldi_native_fbank as knf |
| | | |
def fbank_knf(waveform,
              fs: int = 16000,
              n_mels: int = 80,
              frame_length_ms: float = 25.0,
              frame_shift_ms: float = 10.0,
              dither: float = 0.0,
              window: str = 'hamming',
              snip_edges: bool = True):
    """Compute Kaldi-style fbank features with kaldi-native-fbank.

    Previously all settings were hard-coded; they are now keyword
    parameters whose defaults reproduce the original behavior exactly,
    so existing `fbank_knf(waveform)` callers are unaffected.

    Args:
        waveform: 1-D float samples, assumed in [-1, 1] given the int16
            rescale below — TODO confirm against callers.
        fs: sampling frequency in Hz.
        n_mels: number of mel filterbank bins.
        frame_length_ms: frame length in milliseconds.
        frame_shift_ms: frame shift in milliseconds.
        dither: dithering constant.
        window: window type (e.g. 'hamming').
        snip_edges: Kaldi snip-edges framing behavior.

    Returns:
        np.ndarray of shape (num_frames, n_mels).
    """
    opts = knf.FbankOptions()
    opts.frame_opts.samp_freq = fs
    opts.frame_opts.dither = dither
    opts.frame_opts.window_type = window
    opts.frame_opts.frame_shift_ms = frame_shift_ms
    opts.frame_opts.frame_length_ms = frame_length_ms
    opts.mel_opts.num_bins = n_mels
    # NOTE(review): energy_floor=1 here vs. 0.0 in WavFrontend.forward —
    # confirm the mismatch is intentional before comparing outputs.
    opts.energy_floor = 1
    opts.frame_opts.snip_edges = snip_edges
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    # Scale to int16 range, mirroring WavFrontend.forward.
    waveform = waveform * (1 << 15)
    fbank.accept_waveform(fs, waveform.tolist())
    num_frames = fbank.num_frames_ready
    mat = np.empty([num_frames, n_mels])
    for i in range(num_frames):
        mat[i, :] = fbank.get_frame(i)
    return mat
| | | |
if __name__ == '__main__':
    import librosa

    # Parity check: kaldi-native-fbank vs. torchaudio's Kaldi fbank on
    # the same utterance (dither disabled so outputs are deterministic).
    path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    waveform, fs = librosa.load(path, sr=None)
    fbank = fbank_knf(waveform)
    frontend = WavFrontend(dither=0.0)
    waveform_tensor = torch.from_numpy(waveform)[None, :]
    # Pass lengths as a tensor to match forward()'s annotated signature
    # (a bare list happened to work, but only by accident of indexing).
    fbank_torch, _ = frontend.forward(waveform_tensor,
                                      torch.tensor([waveform_tensor.size(1)]))
    fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
    diff = fbank - fbank_torch
    diff_max = diff.max()
    # BUG FIX: np.ndarray has no .abs() method — the original
    # `diff.abs().sum()` raised AttributeError. Use np.abs().
    diff_sum = np.abs(diff).sum()
    print(f'max diff: {diff_max}, abs-sum diff: {diff_sum}')
| | |
| | | window_type=self.window, |
| | | sample_frequency=self.fs) |
| | | |
| | | # if self.lfr_m != 1 or self.lfr_n != 1: |
| | | # mat = apply_lfr(mat, self.lfr_m, self.lfr_n) |
| | | # if self.cmvn_file is not None: |
| | | # mat = apply_cmvn(mat, self.cmvn_file) |
| | | |
| | | feat_length = mat.size(0) |
| | | feats.append(mat) |
| | | feats_lens.append(feat_length) |
| File was renamed from funasr/runtime/python/onnxruntime/README.md |
| | |
| | | │ └── utils.py |
| | | ├── README.md |
| | | ├── requirements.txt |
| | | ├── resources |
| | | │ ├── config.yaml |
| | | │ └── models |
| | | │ ├── am.mvn |
| | | │ ├── model.onnx # Put it here. |
| | | │ └── token_list.pkl |
| | | ├── test_onnx.py |
| | | ├── tests |
| | | │ ├── __pycache__ |
| | |
| | | - Output: `List[str]`: recognition result. |
| | | - Example: |
| | | ```python |
| | | from rapid_paraformer import RapidParaformer |
| | | from paraformer_onnx import Paraformer |
| | | |
| | | |
| | | config_path = 'resources/config.yaml' |
| | | paraformer = RapidParaformer(config_path) |
| | | model = Paraformer(config_path) |
| | | |
| | | wav_path = ['test_wavs/0478_00017.wav'] |
| | | wav_path = ['example/asr_example.wav'] |
| | | |
| | | result = paraformer(wav_path) |
| | | result = model(wav_path) |
| | | print(result) |
| | | ``` |
| | | |
| File was renamed from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | # @Author: SWHL |
| | | # @Contact: liekkaskono@163.com |
| | | import os.path |
| | | import traceback |
| | | from pathlib import Path |
| | | from typing import List, Union, Tuple |
| | |
| | | from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, |
| | | OrtInferSession, TokenIDConverter, WavFrontend, get_logger, |
| | | read_yaml) |
| | | from .postprocess_utils import sentence_postprocess |
| | | |
| | | logging = get_logger() |
| | | |
| | | |
| | | class RapidParaformer(): |
| | | def __init__(self, config_path: Union[str, Path]) -> None: |
| | | if not Path(config_path).exists(): |
| | | raise FileNotFoundError(f'{config_path} does not exist.') |
| | | class Paraformer(): |
| | | def __init__(self, model_dir: Union[str, Path]=None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int]="-1", |
| | | ): |
| | | |
| | | if not Path(model_dir).exists(): |
| | | raise FileNotFoundError(f'{model_dir} does not exist.') |
| | | |
| | | config = read_yaml(config_path) |
| | | model_file = os.path.join(model_dir, 'model.onnx') |
| | | config_file = os.path.join(model_dir, 'config.yaml') |
| | | cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | config = read_yaml(config_file) |
| | | |
| | | self.converter = TokenIDConverter(**config['TokenIDConverter']) |
| | | self.tokenizer = CharTokenizer(**config['CharTokenizer']) |
| | | self.converter = TokenIDConverter(config['token_list']) |
| | | self.tokenizer = CharTokenizer() |
| | | self.frontend = WavFrontend( |
| | | cmvn_file=config['WavFrontend']['cmvn_file'], |
| | | **config['WavFrontend']['frontend_conf'] |
| | | cmvn_file=cmvn_file, |
| | | **config['frontend_conf'] |
| | | ) |
| | | self.ort_infer = OrtInferSession(config['Model']) |
| | | self.batch_size = config['Model']['batch_size'] |
| | | self.ort_infer = OrtInferSession(model_file, device_id) |
| | | self.batch_size = batch_size |
| | | |
| | | def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List: |
| | | waveform_list = self.load_data(wav_content) |
| | |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | text = self.tokenizer.tokens2text(token) |
| | | token = token[:valid_token_num-1] |
| | | texts = sentence_postprocess(token) |
| | | text = texts[0] |
| | | # text = self.tokenizer.tokens2text(token) |
| | | return text[:valid_token_num-1] |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | project_dir = Path(__file__).resolve().parent.parent |
| | | cfg_path = project_dir / 'resources' / 'config.yaml' |
| | | paraformer = RapidParaformer(cfg_path) |
| | | model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" |
| | | model = Paraformer(model_dir) |
| | | |
| | | wav_file = '0478_00017.wav' |
| | | for i in range(1000): |
| | | result = paraformer(wav_file) |
| | | print(result) |
| | | wav_file = os.path.join(model_dir, 'example/asr_example.wav') |
| | | result = model(wav_file) |
| | | print(result) |
| | | |
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | |
| | | import string |
| | | import logging |
| | | from typing import Any, List, Union |
| | | |
| | | |
def isChinese(ch: str):
    """Return True if ``ch`` is a CJK Unified Ideograph or an ASCII
    digit ('0'-'9'), else False."""
    in_cjk_range = '\u4e00' <= ch <= '\u9fff'
    in_ascii_digit_range = '\u0030' <= ch <= '\u0039'
    return in_cjk_range or in_ascii_digit_range
| | | |
| | | |
def isAllChinese(word: Union[List[Any], str]):
    """Return True iff ``word`` is non-empty and every token, after
    stripping spaces and <s>/</s> markers, passes the CJK/digit test.

    An empty input (or one whose tokens all strip to nothing) is not
    "all Chinese": empty strings fail the range comparison below.
    """
    stripped = [
        token.replace(' ', '').replace('</s>', '').replace('<s>', '')
        for token in word
    ]

    if not stripped:
        return False

    # Same membership test as isChinese(), inlined: CJK ideograph range
    # or ASCII digit range (lexicographic for multi-char tokens, exactly
    # as the comparison chain behaves).
    return all(
        '\u4e00' <= token <= '\u9fff' or '\u0030' <= token <= '\u0039'
        for token in stripped)
| | | |
| | | |
def isAllAlpha(word: Union[List[Any], str]):
    """Return True iff ``word`` is non-empty and every token, after
    stripping spaces and <s>/</s> markers, is purely alphabetic (or a
    lone apostrophe) and not a CJK character.

    Tokens that strip to the empty string fail ``str.isalpha`` and so
    make the result False, matching the original behavior.
    """
    stripped = [
        token.replace(' ', '').replace('</s>', '').replace('<s>', '')
        for token in word
    ]

    if not stripped:
        return False

    for token in stripped:
        # Inlined isChinese() test: CJK ideograph or ASCII digit range.
        chinese_like = ('\u4e00' <= token <= '\u9fff'
                        or '\u0030' <= token <= '\u0039')
        if not token.isalpha() and token != "'":
            return False
        if token.isalpha() and chinese_like:
            # Unicode-alphabetic but CJK — not "alpha" for our purposes.
            return False

    return True
| | | |
| | | |
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
    """Uppercase spelled-out abbreviations in a token list.

    Detects runs of single alphabetic tokens separated by single ' '
    tokens (e.g. ['a', ' ', 'b', ' ', 'c']) and re-emits those letters
    uppercased; all other tokens pass through unchanged. When
    ``time_stamp`` is given, also merges the per-letter time stamps of
    each abbreviation into one [begin, end] span and returns
    ``(word_lists, ts_lists)``; otherwise returns ``word_lists`` alone.
    """
    words_size = len(words)
    word_lists = []
    abbr_begin = []
    abbr_end = []
    last_num = -1
    ts_lists = []
    ts_nums = []
    ts_index = 0
    # --- Pass 1: record [abbr_begin, abbr_end] index pairs. ---
    # ``num`` is mutated inside the body, but ``for`` rebinds it each
    # iteration; the ``last_num`` guard is what actually skips tokens
    # already consumed by a detected abbreviation.
    for num in range(words_size):
        if num <= last_num:
            continue

        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
            # Look for the pattern: letter, ' ', letter.
            if num + 1 < words_size and words[
                    num + 1] == ' ' and num + 2 < words_size and len(
                        words[num +
                              2]) == 1 and words[num +
                                                 2].encode('utf-8').isalpha():
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
                abbr_end.append(num)
                # to find the end of abbr: keep extending while the
                # "' ' + single letter" pattern continues.
                while True:
                    num += 1
                    if num < words_size and words[num] == ' ':
                        num += 1
                        if num < words_size and len(
                                words[num]) == 1 and words[num].encode(
                                    'utf-8').isalpha():
                            abbr_end.pop()
                            abbr_end.append(num)
                            last_num = num
                        else:
                            break
                    else:
                        break

    # --- Map word index -> time_stamp index. ---
    # Non-space words consume a time_stamp slot; a ' ' token shares the
    # slot of the *following* non-space word (appended before increment).
    for num in range(words_size):
        if words[num] == ' ':
            ts_nums.append(ts_index)
        else:
            ts_nums.append(ts_index)
            ts_index += 1
    last_num = -1
    # --- Pass 2: rebuild the output, uppercasing abbreviation spans. ---
    for num in range(words_size):
        if num <= last_num:
            continue

        if num in abbr_begin:
            if time_stamp is not None:
                begin = time_stamp[ts_nums[num]][0]
            word_lists.append(words[num].upper())
            num += 1
            while num < words_size:
                if num in abbr_end:
                    word_lists.append(words[num].upper())
                    last_num = num
                    break
                else:
                    if words[num].encode('utf-8').isalpha():
                        word_lists.append(words[num].upper())
                num += 1
            if time_stamp is not None:
                # NOTE(review): if the while-loop exits without break,
                # num == words_size and ts_nums[num] would raise
                # IndexError — confirm abbr_end always terminates spans.
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
        else:
            word_lists.append(words[num])
            if time_stamp is not None and words[num] != ' ':
                begin = time_stamp[ts_nums[num]][0]
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
                begin = end

    if time_stamp is not None:
        return word_lists, ts_lists
    else:
        return word_lists
| | | |
| | | |
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
    """Turn decoder output tokens into a readable sentence.

    Steps visible below: drop <s>/</s>/<unk>; merge BPE pieces marked
    with a trailing '@@' into whole words; join Chinese characters with
    no separator and alphabetic words with single spaces; uppercase
    spelled-out abbreviations via abbr_dispose(). With ``time_stamp``
    returns ``(sentence, ts_lists, real_word_lists)``; without it
    returns ``(sentence, real_word_lists)``.
    """
    middle_lists = []
    word_lists = []
    word_item = ''
    ts_lists = []

    # wash words lists: decode bytes and drop special symbols.
    for i in words:
        word = ''
        if isinstance(i, str):
            word = i
        else:
            word = i.decode('utf-8')

        if word in ['<s>', '</s>', '<unk>']:
            continue
        else:
            middle_lists.append(word)

    # all chinese characters: one output token per character, and the
    # incoming time stamps can be reused as-is.
    if isAllChinese(middle_lists):
        for i, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(' ', ''))
        if time_stamp is not None:
            ts_lists = time_stamp

    # all alpha characters: glue '@@' BPE pieces into words, separating
    # completed words with ' '. ts_flag=True means the next token starts
    # a new word, so ``begin`` is refreshed from its time stamp.
    elif isAllAlpha(middle_lists):
        ts_flag = True
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if '@@' in ch:
                # BPE continuation piece — accumulate, word not finished.
                word = ch.replace('@@', '')
                word_item += word
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            else:
                # Word-final piece — emit the whole word plus a space.
                # NOTE(review): a trailing '@@' piece at end of input
                # would leave word_item unflushed — confirm inputs always
                # end words properly.
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end

    # mix characters: per-token dispatch between Chinese, BPE piece,
    # and whole alpha word. alpha_blank=True records that the last
    # appended token is the ' ' after an alpha word, so it can be
    # popped when a Chinese character follows directly.
    else:
        alpha_blank = False
        ts_flag = True
        begin = -1
        end = -1
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if isAllChinese(ch):
                if alpha_blank is True:
                    # Drop the space separator before a Chinese char.
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = True
                    ts_lists.append([begin, end])
                    begin = end
            elif '@@' in ch:
                # BPE continuation piece — accumulate.
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            elif isAllAlpha(ch):
                # Word-final alpha piece — flush accumulated word + ' '.
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
            else:
                raise ValueError('invalid character: {}'.format(ch))

    if time_stamp is not None:
        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ' '.join(real_word_lists).strip()
        return sentence, ts_lists, real_word_lists
    else:
        word_lists = abbr_dispose(word_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        # Note: joined on '' (spaces are already explicit ' ' tokens),
        # unlike the ' '.join over real_word_lists above.
        sentence = ''.join(word_lists).strip()
        return sentence, real_word_lists
| File was renamed from funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py |
| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from .kaldifeat import compute_fbank_feats |
| | | import warnings |
| | | |
| | | root_dir = Path(__file__).resolve().parent |
| | | |
| | |
| | | |
| | | |
| | | class TokenIDConverter(): |
| | | def __init__(self, token_path: Union[Path, str], |
| | | def __init__(self, token_list: Union[Path, str], |
| | | unk_symbol: str = "<unk>",): |
| | | check_argument_types() |
| | | |
| | | self.token_list = self.load_token(token_path) |
| | | self.unk_symbol = unk_symbol |
| | | # self.token_list = self.load_token(token_path) |
| | | self.token_list = token_list |
| | | self.unk_symbol = token_list[-1] |
| | | |
| | | @staticmethod |
| | | def load_token(file_path: Union[Path, str]) -> List: |
| | | if not Path(file_path).exists(): |
| | | raise TokenIDConverterError(f'The {file_path} does not exist.') |
| | | |
| | | with open(str(file_path), 'rb') as f: |
| | | token_list = pickle.load(f) |
| | | |
| | | if len(token_list) != len(set(token_list)): |
| | | raise TokenIDConverterError('The Token exists duplicated symbol.') |
| | | return token_list |
| | | # @staticmethod |
| | | # def load_token(file_path: Union[Path, str]) -> List: |
| | | # if not Path(file_path).exists(): |
| | | # raise TokenIDConverterError(f'The {file_path} does not exist.') |
| | | # |
| | | # with open(str(file_path), 'rb') as f: |
| | | # token_list = pickle.load(f) |
| | | # |
| | | # if len(token_list) != len(set(token_list)): |
| | | # raise TokenIDConverterError('The Token exists duplicated symbol.') |
| | | # return token_list |
| | | |
| | | def get_num_vocabulary_size(self) -> int: |
| | | return len(self.token_list) |
| | |
| | | |
| | | |
| | | class OrtInferSession(): |
| | | def __init__(self, config): |
| | | def __init__(self, model_file, device_id=-1): |
| | | sess_opt = SessionOptions() |
| | | sess_opt.log_severity_level = 4 |
| | | sess_opt.enable_cpu_mem_arena = False |
| | | sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL |
| | | |
| | | cuda_ep = 'CUDAExecutionProvider' |
| | | cuda_provider_options = { |
| | | "device_id": device_id, |
| | | "arena_extend_strategy": "kNextPowerOfTwo", |
| | | "cudnn_conv_algo_search": "EXHAUSTIVE", |
| | | "do_copy_in_default_stream": "true", |
| | | } |
| | | cpu_ep = 'CPUExecutionProvider' |
| | | cpu_provider_options = { |
| | | "arena_extend_strategy": "kSameAsRequested", |
| | | } |
| | | |
| | | EP_list = [] |
| | | if config['use_cuda'] and get_device() == 'GPU' \ |
| | | if device_id != -1 and get_device() == 'GPU' \ |
| | | and cuda_ep in get_available_providers(): |
| | | EP_list = [(cuda_ep, config[cuda_ep])] |
| | | EP_list = [(cuda_ep, cuda_provider_options)] |
| | | EP_list.append((cpu_ep, cpu_provider_options)) |
| | | |
| | | config['model_path'] = config['model_path'] |
| | | self._verify_model(config['model_path']) |
| | | self.session = InferenceSession(config['model_path'], |
| | | self._verify_model(model_file) |
| | | self.session = InferenceSession(model_file, |
| | | sess_options=sess_opt, |
| | | providers=EP_list) |
| | | |
| | | if config['use_cuda'] and cuda_ep not in self.session.get_providers(): |
| | | if device_id != -1 and cuda_ep not in self.session.get_providers(): |
| | | warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' |
| | | 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' |
| | | 'you can check their relations from the offical web site: ' |
| File was renamed from funasr/runtime/python/onnxruntime/resources/config.yaml |
| | |
| | | lfr_m: 7 |
| | | lfr_n: 6 |
| | | filter_length_max: -.inf |
| | | dither: 0.0 |
| | | |
| | | Model: |
| | | model_path: resources/models/model.onnx |