8个文件已添加
1 文件已重命名
1 文件已复制
1个文件已删除
| New file |
| | |
| | | ## Using paraformer with libtorch |
| | | |
| | | |
| | | ### Introduction |
| | | - Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary). |
| | | |
| | | ### Steps: |
| | | 1. Export the model. |
| | | - Command: (`Tips`: torch >= 1.11.0 is required.) |
| | | |
| | | ```shell |
| | | python -m funasr.export.export_model [model_name] [export_dir] [onnx] |
| | | ``` |
| | | `model_name`: the name (or local path) of the model to export. |
| | | |
| | | `export_dir`: the directory where the exported model is written. |
| | | |
| | | `onnx`: `true` exports an ONNX model; `false` exports a TorchScript model (`model.torchscripts`), which is the format this libtorch runtime loads. |
| | | |
| | | For more details, refer to the [export docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export). |
| | | |
| | | - `e.g.`, Export model from modelscope |
| | | ```shell |
| | | python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false |
| | | ``` |
| | | - `e.g.`, Export model from a local path (the model file must be named `model.pb`). |
| | | ```shell |
| | | python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false |
| | | ``` |
| | | |
| | | |
| | | 2. Install the `torch_paraformer`. |
| | | - Build the torch_paraformer `whl` |
| | | ```shell |
| | | git clone https://github.com/alibaba/FunASR.git && cd FunASR |
| | | cd funasr/runtime/python/libtorch |
| | | python setup.py bdist_wheel |
| | | ``` |
| | | - Install the built `whl` |
| | | ```bash |
| | | pip install dist/torch_paraformer-0.0.1-py3-none-any.whl |
| | | ``` |
| | | |
| | | 3. Run the demo. |
| | | - Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`. |
| | | - Input: wav-format audio; supported input types: `str, np.ndarray, List[str]` |
| | | - Output: `List[str]`: recognition result. |
| | | - Example: |
| | | ```python |
| | | from torch_paraformer import Paraformer |
| | | |
| | | model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" |
| | | model = Paraformer(model_dir, batch_size=1) |
| | | |
| | | wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'] |
| | | |
| | | result = model(wav_path) |
| | | print(result) |
| | | ``` |
| | | |
| | | ## Speed |
| | | |
| | | Environment: Intel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz |
| | | |
| | | Test [wav, 5.53s, 100 times avg.](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav) |
| | | |
| | | | Backend | RTF | |
| | | |:-------:|:-----------------:| |
| | | | Pytorch | 0.110 | |
| | | | Onnx | 0.038 | |
| | | |
| | | ## Acknowledge |
| New file |
| | |
| | | |
import sys

from torch_paraformer import Paraformer

# Export directory produced by `python -m funasr.export.export_model`;
# it must contain model.torchscripts, config.yaml and am.mvn.
model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
# Audio to recognize; a list of paths is decoded in batches of `batch_size`.
wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav']

# Optional overrides: python demo.py [model_dir] [wav ...]
if len(sys.argv) > 1:
    model_dir = sys.argv[1]
if len(sys.argv) > 2:
    wav_path = sys.argv[2:]

model = Paraformer(model_dir, batch_size=1)

result = model(wav_path)
print(result)
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | from pathlib import Path |
| | | import setuptools |
| | | |
| | | |
def get_readme():
    """Return the text of README.md located next to this setup.py.

    Used as the package's long_description; raises if the file is missing.
    """
    readme_path = Path(__file__).resolve().parent / 'README.md'
    return readme_path.read_text(encoding='utf-8')
| | | |
| | | |
| | | |
setuptools.setup(
    name='torch_paraformer',
    version='0.0.1',
    platforms="Any",
    url="https://github.com/alibaba-damo-academy/FunASR.git",
    author="Speech Lab, Alibaba Group, China",
    author_email="funasr@list.alibaba-inc.com",
    description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
    license="The MIT License",
    long_description=get_readme(),
    long_description_content_type='text/markdown',
    include_package_data=True,
    # Inference runs through torch.jit.load, so torch (not onnxruntime) is the
    # runtime dependency. typeguard is capped below 3 because the package calls
    # typeguard.check_argument_types, which typeguard 3.x removed.
    # NOTE(review): torch_paraformer.paraformer_bin also imports
    # funasr.utils.timestamp_tools, which is not declared here — confirm
    # whether funasr should be a requirement.
    install_requires=["librosa", "torch",
                      "scipy", "numpy>=1.19.3",
                      "typeguard<3", "kaldi-native-fbank",
                      "PyYAML>=5.1.2"],
    # `utils` is a sub-package and must be listed explicitly; otherwise the
    # wheel ships without it and `import torch_paraformer` fails.
    packages=['torch_paraformer', 'torch_paraformer.utils'],
    keywords=['funasr', 'paraformer'],
    classifiers=[
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
    ],
)
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | from .paraformer_bin import Paraformer |
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | import os.path |
| | | from pathlib import Path |
| | | from typing import List, Union, Tuple |
| | | |
| | | import copy |
| | | import librosa |
| | | import numpy as np |
| | | |
| | | from .utils.utils import (CharTokenizer, Hypothesis, |
| | | TokenIDConverter, get_logger, |
| | | read_yaml) |
| | | from .utils.postprocess_utils import sentence_postprocess |
| | | from .utils.frontend import WavFrontend |
| | | from funasr.utils.timestamp_tools import time_stamp_lfr6_pl |
| | | logging = get_logger() |
| | | |
| | | import torch |
| | | |
| | | |
class Paraformer():
    """Paraformer speech recognizer running an exported TorchScript model.

    ``model_dir`` must contain ``model.torchscripts`` (the scripted model),
    ``config.yaml`` (providing ``token_list`` and ``frontend_conf``) and
    ``am.mvn`` (CMVN statistics for the feature frontend).
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 ):
        """Load config, tokenizer, feature frontend and the TorchScript model.

        Args:
            model_dir: export directory (see class docstring).
            batch_size: number of waveforms per forward pass; must be >= 1.
            device_id: CUDA device index. A negative value (default "-1")
                runs on CPU; if CUDA is unavailable a warning is logged and
                the CPU is used instead.

        Raises:
            FileNotFoundError: ``model_dir`` is None/missing, or it has no
                ``model.torchscripts``.
            ValueError: ``batch_size`` < 1.
        """
        if model_dir is None or not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')

        # torch.jit.load needs the TorchScript export; 'model.onnx' is the
        # ONNX export and cannot be loaded by libtorch.
        model_file = os.path.join(model_dir, 'model.torchscripts')
        if not os.path.isfile(model_file):
            raise FileNotFoundError(
                f'{model_file} does not exist; export the model in TorchScript format first.')
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.device = self._select_device(device_id)
        # Attribute name kept for compatibility with existing callers.
        self.ort_infer = torch.jit.load(model_file, map_location=self.device)
        self.ort_infer.eval()
        self.batch_size = batch_size

    @staticmethod
    def _select_device(device_id: Union[str, int]) -> torch.device:
        """Map ``device_id`` to a torch device; negative ids mean CPU."""
        index = int(device_id)
        if index < 0:
            return torch.device('cpu')
        if not torch.cuda.is_available():
            logging.warning(f"CUDA is not available; running on CPU instead of cuda:{index}")
            return torch.device('cpu')
        return torch.device(f'cuda:{index}')

    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        """Recognize one or more waveforms.

        Returns:
            One dict per input waveform, in input order, with key 'preds'
            (recognized text) and, when the model returns 4 outputs, key
            'timestamp'. Utterances of a batch whose inference failed get
            ``{'preds': ''}``.
        """
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)

        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])

            try:
                outputs = self.infer(feats, feats_len)
            except Exception:
                # Keep one (empty) result per utterance so results stay
                # aligned with the inputs; details go to the debug log.
                logging.warning("input wav is silence or noise")
                logging.debug("inference failed", exc_info=True)
                asr_res.extend({'preds': ''} for _ in range(end_idx - beg_idx))
                continue

            asr_res.extend(self._postprocess_batch(outputs))
        return asr_res

    def _postprocess_batch(self, outputs) -> List[dict]:
        """Turn the raw model outputs of one batch into per-utterance results."""
        am_scores = outputs[0].cpu().numpy()
        valid_token_lens = outputs[1].cpu().numpy()
        if len(outputs) == 4:
            # for BiCifParaformer Inference: upsampled CIF alphas and peaks
            us_alphas = outputs[2].cpu().numpy()
            us_cif_peak = outputs[3].cpu().numpy()
        else:
            us_alphas, us_cif_peak = None, None

        results = []
        for idx, (preds, raw_token) in enumerate(self.decode(am_scores, valid_token_lens)):
            res = {'preds': preds}
            if us_cif_peak is not None:
                # Slice idx:idx+1 to keep the leading batch axis, matching what
                # the batch_size=1 path has always passed to the helper.
                timestamp = time_stamp_lfr6_pl(us_alphas[idx:idx + 1],
                                               us_cif_peak[idx:idx + 1],
                                               copy.copy(raw_token), log=False)
                res['timestamp'] = timestamp
            results.append(res)
        return results

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalize ``wav_content`` into a list of waveforms.

        Accepts a path, a waveform array, or a list mixing paths and arrays.
        Paths are loaded with librosa and resampled to ``fs``.

        Raises:
            TypeError: for any other input type.
        """
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, (str, Path)):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [item if isinstance(item, np.ndarray) else load_wav(item)
                    for item in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank + LFR/CMVN features and pad them into one batch.

        Returns:
            (feats, feats_len): float32 (batch, max_frames, dim) array and
            int32 per-utterance frame counts.
        """
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each (frames, dim) feature to ``max_feat_len`` frames and stack as float32."""
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[torch.Tensor, ...]:
        """Run the TorchScript model on a padded feature batch.

        Inputs are converted to float32 / int32 tensors on ``self.device``
        and passed as two positional arguments.
        NOTE(review): assumes the exported forward is
        ``forward(speech, speech_lengths)`` — confirm against funasr.export.
        """
        speech = torch.as_tensor(feats, dtype=torch.float32).to(self.device)
        speech_lengths = torch.as_tensor(feats_len, dtype=torch.int32).to(self.device)
        with torch.no_grad():
            outputs = self.ort_infer(speech, speech_lengths)
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[Tuple[str, List[str]]]:
        """Decode every utterance of a batch; returns (text, tokens) pairs."""
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> Tuple[str, List[str]]:
        """Greedy-decode one utterance's token posteriors.

        Takes the argmax id per position, drops ids 0 (blank) and 2 (eos),
        maps ids to tokens and post-processes them into text.
        NOTE: ``valid_token_num`` is currently not used to truncate tokens.

        Returns:
            (text, tokens).
        """
        yseq = am_score.argmax(axis=-1)
        token_int = [tok for tok in yseq.tolist() if tok not in (0, 2)]

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        text = sentence_postprocess(token)[0]
        return text, token
| | | |
copy from funasr/runtime/python/torchscripts/__init__.py
copy to funasr/runtime/python/libtorch/torch_paraformer/utils/__init__.py
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | from pathlib import Path |
| | | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union |
| | | |
| | | import numpy as np |
| | | from typeguard import check_argument_types |
| | | import kaldi_native_fbank as knf |
| | | |
| | | root_dir = Path(__file__).resolve().parent |
| | | |
| | | logger_initialized = {} |
| | | |
| | | |
class WavFrontend():
    """Conventional frontend structure for ASR.

    Computes fbank features with ``kaldi_native_fbank`` and optionally applies
    low-frame-rate (LFR) frame stacking and CMVN.
    """

    def __init__(
            self,
            cmvn_file: str = None,
            fs: int = 16000,
            window: str = 'hamming',
            n_mels: int = 80,
            frame_length: int = 25,
            frame_shift: int = 10,
            lfr_m: int = 1,
            lfr_n: int = 1,
            dither: float = 1.0,
            **kwargs,
    ) -> None:
        """Configure the fbank extractor.

        Args:
            cmvn_file: path to the CMVN statistics file; falsy disables CMVN.
            fs: sampling rate in Hz.
            window: window type name passed to kaldi_native_fbank.
            n_mels: number of mel bins.
            frame_length: frame length in milliseconds.
            frame_shift: frame shift in milliseconds.
            lfr_m: number of frames stacked into one LFR frame.
            lfr_n: LFR stride in frames; lfr_m == lfr_n == 1 disables LFR.
            dither: dithering constant.
            **kwargs: other config keys; accepted and ignored.
        """
        check_argument_types()

        opts = knf.FbankOptions()
        opts.frame_opts.samp_freq = fs
        opts.frame_opts.dither = dither
        opts.frame_opts.window_type = window
        opts.frame_opts.frame_shift_ms = float(frame_shift)
        opts.frame_opts.frame_length_ms = float(frame_length)
        opts.mel_opts.num_bins = n_mels
        opts.energy_floor = 0
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts

        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file

        if self.cmvn_file:
            self.cmvn = self.load_cmvn()
        # Streaming extractor state used only by fbank_online()/reset_status().
        self.fbank_fn = None
        self.fbank_beg_idx = 0
        self.reset_status()

    def fbank(self,
              waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank features for one complete utterance.

        A fresh extractor is used so the streaming state of fbank_online()
        is left untouched.

        Returns:
            (feat, feat_len): float32 (frames, n_mels) matrix and frame count.
        """
        # Scale to the int16 sample range (presumably the input is float
        # audio in [-1, 1], e.g. librosa output).
        waveform = waveform * (1 << 15)
        extractor = knf.OnlineFbank(self.opts)
        extractor.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = extractor.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(frames):
            mat[i, :] = extractor.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def fbank_online(self,
                     waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Feed a chunk to the persistent streaming extractor.

        ``fbank_beg_idx`` is never advanced, so every call returns all frames
        produced since the last reset_status() (the cumulative matrix).
        """
        waveform = waveform * (1 << 15)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        # zeros (not empty) so rows before fbank_beg_idx are never garbage.
        mat = np.zeros([frames, self.opts.mel_opts.num_bins])
        for i in range(self.fbank_beg_idx, frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def reset_status(self):
        """Discard streaming state: new extractor, frame index back to 0."""
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_beg_idx = 0

    def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply LFR (if configured) and CMVN (if a cmvn_file was given).

        Returns:
            (feat, feat_len) where feat_len is the resulting frame count.
        """
        if self.lfr_m != 1 or self.lfr_n != 1:
            feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)

        if self.cmvn_file:
            feat = self.apply_cmvn(feat)

        feat_len = np.array(feat.shape[0]).astype(np.int32)
        return feat, feat_len

    @staticmethod
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        """Stack ``lfr_m`` consecutive frames every ``lfr_n`` frames.

        The input is left-padded with (lfr_m - 1) // 2 copies of its first
        frame; windows running past the end are filled by repeating the last
        frame. Output is float32 with shape (ceil(T / lfr_n), lfr_m * dim).
        """
        LFR_inputs = []

        T = inputs.shape[0]
        T_lfr = int(np.ceil(T / lfr_n))
        left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
        inputs = np.vstack((left_padding, inputs))
        T = T + (lfr_m - 1) // 2
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append(
                    (inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
            else:
                # process last LFR frame: repeat the final frame to fill the window
                num_padding = lfr_m - (T - i * lfr_n)
                frame = inputs[i * lfr_n:].reshape(-1)
                frame = np.hstack([frame] + [inputs[-1]] * num_padding)
                LFR_inputs.append(frame)
        LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
        return LFR_outputs

    def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
        """
        Apply CMVN with mvn data: ``(inputs + shift) * scale`` per dimension,
        using the first ``dim`` entries of each statistics row.
        """
        frame, dim = inputs.shape
        shift = self.cmvn[0:1, :dim]
        scale = self.cmvn[1:2, :dim]
        # (1, dim) rows broadcast over all frames.
        return (inputs + shift) * scale

    def load_cmvn(self,) -> np.ndarray:
        """Read CMVN statistics from ``self.cmvn_file``.

        The vector for ``<AddShift>`` / ``<Rescale>`` is taken from the next
        line when it starts with ``<LearnRateCoef>``, dropping its first three
        fields and its last field (assumed layout:
        ``<LearnRateCoef> <coef> [ v1 ... vN ]``). Blank lines are skipped.

        Returns:
            float64 array of shape (2, N): row 0 = shifts, row 1 = scales.
        """
        with open(self.cmvn_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        def vector_after(idx):
            # Return the value fields of the line following `idx`, or None.
            if idx + 1 >= len(lines):
                return None
            fields = lines[idx + 1].split()
            if fields and fields[0] == '<LearnRateCoef>':
                return fields[3:-1]
            return None

        means_list = []
        scales_list = []
        for i, line in enumerate(lines):
            fields = line.split()
            if not fields:
                continue
            if fields[0] == '<AddShift>':
                values = vector_after(i)
                if values is not None:
                    means_list = list(values)
            elif fields[0] == '<Rescale>':
                values = vector_after(i)
                if values is not None:
                    scales_list = list(values)

        means = np.array(means_list).astype(np.float64)
        scales = np.array(scales_list).astype(np.float64)
        cmvn = np.array([means, scales])
        return cmvn
| | | |
def load_bytes(input):
    """Convert raw 16-bit signed PCM bytes into float32 samples.

    Args:
        input: bytes-like object holding native-endian int16 samples.

    Returns:
        A new, writable float32 array with ``len(input) // 2`` samples, each
        divided by 32768 (so values lie in [-1, 1)).

    Raises:
        ValueError: if the buffer length is not a multiple of 2 bytes.
    """
    samples = np.frombuffer(input, dtype=np.int16)
    # int16 spans [-2**15, 2**15 - 1], so no offset is needed before scaling.
    return samples.astype(np.float32) / np.float32(1 << 15)
| | | |
| | | |
def test(path: str = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav",
         cmvn_file: str = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn",
         config_file: str = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"):
    """Manual smoke test: compute LFR + CMVN features for one wav file.

    The defaults point at the author's export directory; pass explicit paths
    to run it elsewhere.

    Returns:
        (feat, feat_len) as produced by WavFrontend.lfr_cmvn.
    """
    import librosa
    # Use this package's own YAML reader rather than depending on the
    # separate onnxruntime runtime package (rapid_paraformer).
    from torch_paraformer.utils.utils import read_yaml

    config = read_yaml(config_file)
    waveform, _ = librosa.load(path, sr=None)
    frontend = WavFrontend(
        cmvn_file=cmvn_file,
        **config['frontend_conf'],
    )
    speech, _ = frontend.fbank_online(waveform)  # 1-D samples -> (frames, n_mels)
    feat, feat_len = frontend.lfr_cmvn(speech)  # LFR (if configured) + CMVN

    frontend.reset_status()  # clear the streaming extractor cache
    return feat, feat_len
| | | |
def _main():
    """Entry point when this module is run directly: execute the frontend smoke test."""
    test()


if __name__ == '__main__':
    _main()
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | |
| | | import string |
| | | import logging |
| | | from typing import Any, List, Union |
| | | |
| | | |
def isChinese(ch: str):
    """Return True if ``ch`` is in the CJK Unified Ideographs range or is an ASCII digit.

    Comparisons are plain string comparisons, so a multi-character string is
    judged lexicographically; the empty string is never "Chinese".
    """
    in_cjk_block = '\u4e00' <= ch <= '\u9fff'
    is_ascii_digit = '\u0030' <= ch <= '\u0039'
    return in_cjk_block or is_ascii_digit
| | | |
| | | |
def isAllChinese(word: Union[List[Any], str]):
    """Return True when ``word`` is non-empty and every element is "Chinese".

    Each element first has spaces, ``</s>`` and ``<s>`` removed, then is
    judged with isChinese (CJK ideograph or ASCII digit). An element that is
    empty after cleaning does not count as Chinese. A str argument is
    checked character by character.
    """
    cleaned = [piece.replace(' ', '').replace('</s>', '').replace('<s>', '')
               for piece in word]
    if not cleaned:
        return False
    return all(isChinese(piece) for piece in cleaned)
| | | |
| | | |
def isAllAlpha(word: Union[List[Any], str]):
    """Return True when ``word`` is non-empty and every element is alphabetic.

    Each element first has spaces, ``</s>`` and ``<s>`` removed. An element
    passes if it is ``str.isalpha()`` and not Chinese per isChinese, or if it
    is exactly an apostrophe ``'``. Empty elements fail.
    """
    cleaned = [piece.replace(' ', '').replace('</s>', '').replace('<s>', '')
               for piece in word]
    if not cleaned:
        return False

    for piece in cleaned:
        alphabetic = piece.isalpha()
        if not alphabetic and piece != "'":
            return False
        if alphabetic and isChinese(piece):
            # CJK characters are also isalpha(); exclude them here.
            return False
    return True
| | | |
| | | |
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
    """Merge spelled-out abbreviations into upper-case letters.

    A run of single ASCII letters separated by single-space tokens, e.g.
    ``['a', ' ', 'b', ' ', 'c']``, is treated as an abbreviation: its letters
    are upper-cased and the separating spaces are dropped
    (``['A', 'B', 'C']``). All other tokens, spaces included, pass through
    unchanged.

    Args:
        words: tokens, where ``' '`` entries are word separators.
        time_stamp: optional ``[begin, end]`` pairs, one per non-space token
            of ``words``, in order.

    Returns:
        ``word_lists`` when ``time_stamp`` is None, otherwise the tuple
        ``(word_lists, ts_lists)`` (the annotation above does not reflect
        this). Each abbreviation contributes a single ``[begin, end]`` span
        (first letter's begin to last letter's end) to ``ts_lists``.
        NOTE(review): its letters remain separate entries in ``word_lists``,
        so the two lists are not one-to-one for abbreviations — confirm
        callers expect this.
    """
    words_size = len(words)
    word_lists = []
    abbr_begin = []  # indices in `words` where an abbreviation starts
    abbr_end = []  # index of each abbreviation's last letter
    last_num = -1  # indices <= last_num were already consumed
    ts_lists = []
    ts_nums = []  # ts_nums[k]: row of time_stamp associated with words[k]
    ts_index = 0
    # Pass 1: locate abbreviations (letter, ' ', letter[, ' ', letter ...]).
    for num in range(words_size):
        if num <= last_num:
            continue

        # bytes.isalpha() is True only for ASCII letters.
        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
            if num + 1 < words_size and words[
                    num + 1] == ' ' and num + 2 < words_size and len(
                        words[num +
                              2]) == 1 and words[num +
                                                 2].encode('utf-8').isalpha():
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
                abbr_end.append(num)
                # to find the end of abbr: keep extending while another
                # ' ' + single ASCII letter follows
                while True:
                    num += 1
                    if num < words_size and words[num] == ' ':
                        num += 1
                        if num < words_size and len(
                                words[num]) == 1 and words[num].encode(
                                    'utf-8').isalpha():
                            abbr_end.pop()
                            abbr_end.append(num)
                            last_num = num
                        else:
                            break
                    else:
                        break

    # Map every token to a time_stamp row: a space gets the index of the next
    # non-space token and does not advance the counter.
    for num in range(words_size):
        if words[num] == ' ':
            ts_nums.append(ts_index)
        else:
            ts_nums.append(ts_index)
            ts_index += 1
    last_num = -1
    # Pass 2: emit tokens, collapsing each abbreviation.
    for num in range(words_size):
        if num <= last_num:
            continue

        if num in abbr_begin:
            if time_stamp is not None:
                begin = time_stamp[ts_nums[num]][0]
            word_lists.append(words[num].upper())
            num += 1
            while num < words_size:
                if num in abbr_end:
                    word_lists.append(words[num].upper())
                    last_num = num
                    break
                else:
                    # skip the ' ' separators, keep the letters
                    if words[num].encode('utf-8').isalpha():
                        word_lists.append(words[num].upper())
                num += 1
            if time_stamp is not None:
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
        else:
            word_lists.append(words[num])
            if time_stamp is not None and words[num] != ' ':
                begin = time_stamp[ts_nums[num]][0]
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
                begin = end

    if time_stamp is not None:
        return word_lists, ts_lists
    else:
        return word_lists
| | | |
| | | |
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
    """Turn decoded tokens into a sentence for Chinese, English or mixed output.

    ``<s>``, ``</s>`` and ``<unk>`` are dropped and non-str tokens are decoded
    as UTF-8. The remaining tokens are then handled in one of three modes:
      * all Chinese (per isAllChinese): tokens kept with spaces removed;
      * all alphabetic (per isAllAlpha): pieces containing ``@@``
        (presumably word-piece continuation markers) are glued onto the
        following piece, and each completed word is followed by ``' '``;
      * mixed: same per-token handling, and the ``' '`` after an English
        word is removed when a Chinese token follows it.
    Spelled-out abbreviations are then merged by abbr_dispose.

    Args:
        words: decoded tokens (str, or UTF-8 bytes).
        time_stamp: optional ``[begin, end]`` per token. It is indexed by the
            token's position after the filtering above — NOTE(review):
            confirm callers pass timestamps for the filtered tokens.

    Returns:
        Without time stamps: ``(sentence, real_word_lists)`` where sentence
        is the concatenation of all tokens (separators included) and
        real_word_lists holds the non-space tokens.
        With time stamps: ``(sentence, ts_lists, real_word_lists)`` where
        sentence joins the non-space tokens with single spaces.

    Raises:
        ValueError: in mixed mode, for a token that is neither Chinese,
            an ``@@`` piece, nor alphabetic.
    """
    middle_lists = []
    word_lists = []
    word_item = ''  # accumulates pieces of an English word split by '@@'
    ts_lists = []

    # wash words lists
    for i in words:
        word = ''
        if isinstance(i, str):
            word = i
        else:
            word = i.decode('utf-8')

        if word in ['<s>', '</s>', '<unk>']:
            continue
        else:
            middle_lists.append(word)

    # all chinese characters
    if isAllChinese(middle_lists):
        for i, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(' ', ''))
        if time_stamp is not None:
            ts_lists = time_stamp

    # all alpha characters
    elif isAllAlpha(middle_lists):
        # ts_flag is False while inside a multi-piece word, so `begin` keeps
        # the timestamp of the word's first piece.
        ts_flag = True
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            else:
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end

    # mix characters
    else:
        alpha_blank = False  # True when word_lists ends with ' ' after an English word
        ts_flag = True
        begin = -1
        end = -1
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if isAllChinese(ch):
                if alpha_blank is True:
                    # no space between an English word and a following Chinese token
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = True
                    ts_lists.append([begin, end])
                    begin = end
            elif '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            elif isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
            else:
                raise ValueError('invalid character: {}'.format(ch))

    if time_stamp is not None:
        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ' '.join(real_word_lists).strip()
        return sentence, ts_lists, real_word_lists
    else:
        word_lists = abbr_dispose(word_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ''.join(word_lists).strip()
        return sentence, real_word_lists
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import functools |
| | | import logging |
| | | import pickle |
| | | from pathlib import Path |
| | | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union |
| | | |
| | | import numpy as np |
| | | import yaml |
| | | |
| | | from typeguard import check_argument_types |
| | | |
| | | import warnings |
| | | |
| | | root_dir = Path(__file__).resolve().parent |
| | | |
| | | logger_initialized = {} |
| | | |
| | | |
class TokenIDConverterError(Exception):
    """Raised by TokenIDConverter for malformed ids or token lists."""


class TokenIDConverter():
    """Map between integer token ids and token strings.

    The last entry of ``token_list`` is used as the unknown symbol.
    """

    def __init__(self, token_list: Union[List, str],
                 ):
        check_argument_types()

        # NOTE(review): a str token_list is used as-is (no file loading), so
        # unk_symbol would be its last character; callers appear to pass a
        # list (e.g. config['token_list']) — confirm.
        self.token_list = token_list
        self.unk_symbol = token_list[-1]

    def get_num_vocabulary_size(self) -> int:
        """Return the number of tokens in the vocabulary."""
        return len(self.token_list)

    def ids2tokens(self,
                   integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Convert a 1-D sequence of ids to tokens.

        Raises:
            TokenIDConverterError: ``integers`` is a numpy array that is not 1-D.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(
                f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Convert tokens to ids; unknown tokens map to the unk_symbol id.

        Raises:
            TokenIDConverterError: unk_symbol is not in the token list.
        """
        token2id = {v: i for i, v in enumerate(self.token_list)}
        if self.unk_symbol not in token2id:
            raise TokenIDConverterError(
                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
            )
        unk_id = token2id[self.unk_symbol]
        return [token2id.get(i, unk_id) for i in tokens]
| | | |
| | | |
class CharTokenizer():
    """Character tokenizer that keeps non-linguistic symbols as single tokens."""

    def __init__(
            self,
            symbol_value: Union[Path, str, Iterable[str]] = None,
            space_symbol: str = "<space>",
            remove_non_linguistic_symbols: bool = False,
    ):
        """
        Args:
            symbol_value: non-linguistic symbols, either an iterable of
                strings or a path to a file with one symbol per line.
            space_symbol: token used to represent ' '.
            remove_non_linguistic_symbols: drop matched symbols instead of
                emitting them as tokens.
        """
        check_argument_types()

        self.space_symbol = space_symbol
        self.non_linguistic_symbols = self.load_symbols(symbol_value)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols

    @staticmethod
    def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
        """Return the symbol set from ``value`` (None -> empty set).

        A str/Path is treated as a file name (missing file -> warning and an
        empty set); any other value is treated as an iterable of symbols.
        Empty symbols from the file are ignored.
        """
        if value is None:
            return set()

        # Check str/Path first: a str is itself an iterable of str, and a
        # subscripted generic like Iterable[str] cannot be used with isinstance.
        if not isinstance(value, (str, Path)):
            return set(value)

        file_path = Path(value)
        if not file_path.exists():
            logging.warning("%s doesn't exist.", file_path)
            return set()

        with file_path.open("r", encoding="utf-8") as f:
            return {symbol for symbol in (line.rstrip() for line in f) if symbol}

    def text2tokens(self, line: Union[str, list]) -> List[str]:
        """Split ``line`` into characters, matching non-linguistic symbols first.

        Spaces become ``self.space_symbol``.
        """
        tokens = []
        while len(line) != 0:
            for w in self.non_linguistic_symbols:
                # An empty symbol would match forever without consuming input.
                if w and line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w):]
                    break
            else:
                t = line[0]
                if t == " ":
                    t = self.space_symbol
                tokens.append(t)
                line = line[1:]
        return tokens

    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join tokens back into text, turning ``space_symbol`` into ' '."""
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}", '
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )
| | | |
| | | |
| | | |
class Hypothesis(NamedTuple):
    """Hypothesis data type: a token id sequence with its scores and states."""

    yseq: np.ndarray
    score: Union[float, np.ndarray] = 0
    # NOTE(review): NamedTuple defaults are created once, so instances built
    # without explicit scores/states share these dict objects — do not
    # mutate them in place.
    scores: Dict[str, Union[float, np.ndarray]] = {}
    states: Dict[str, Any] = {}

    def asdict(self) -> dict:
        """Return a JSON-friendly dict (yseq/score/scores converted; states as-is)."""
        plain_scores = {key: float(val) for key, val in self.scores.items()}
        converted = self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores=plain_scores,
        )
        return converted._asdict()
| | | |
| | | |
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load and return the parsed content of a YAML file.

    Raises:
        FileNotFoundError: ``yaml_path`` does not exist.
    """
    if not Path(yaml_path).exists():
        raise FileNotFoundError(f'The {yaml_path} does not exist.')

    # SECURITY NOTE: yaml.Loader can construct arbitrary Python objects from
    # tagged YAML. Only load trusted config files with it; yaml.SafeLoader
    # would suffice if the configs contain plain data only.
    with open(str(yaml_path), 'rb') as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
| | | |
| | | |
@functools.lru_cache()
def get_logger(name='torch_paraformer'):
    """Return the logger ``name``, attaching a formatted StreamHandler once.

    If ``name`` or a prefix of it was already configured by this function,
    the logger is returned unchanged. Otherwise a StreamHandler with a
    timestamped format is attached, the name is recorded, and propagation to
    ancestor loggers is disabled (presumably to avoid duplicate output).

    Args:
        name (str): Logger name.
    Returns:
        logging.Logger: The requested logger.
    """
    logger = logging.getLogger(name)
    # An exact match is also a prefix match, so one check covers both.
    if any(name.startswith(prefix) for prefix in logger_initialized):
        return logger

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S"))
    logger.addHandler(handler)
    logger_initialized[name] = True
    logger.propagate = False
    return logger