| | |
| | | from pathlib import Path |
| | | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union |
| | | |
| | | import re |
| | | import torch |
| | | import numpy as np |
| | | import yaml |
| | | from onnxruntime import (GraphOptimizationLevel, InferenceSession, |
| | | SessionOptions, get_available_providers, get_device) |
| | | from typeguard import check_argument_types |
| | | |
| | | try: |
| | | from onnxruntime import (GraphOptimizationLevel, InferenceSession, |
| | | SessionOptions, get_available_providers, get_device) |
| | | except: |
| | | print("please pip3 install onnxruntime") |
| | | import jieba |
| | | import warnings |
| | | |
# Directory containing this file; useful for resolving bundled resources.
root_dir = Path(__file__).resolve().parent

# Names of loggers already configured by get_logger(); prevents duplicate handlers.
logger_initialized = {}
| | | |
| | | |
def pad_list(xs, pad_value, max_len=None):
    """Right-pad a list of tensors into a single batched tensor.

    Args:
        xs: list of tensors whose first dimension (length) may differ.
        pad_value: fill value used for the padded positions.
        max_len: target length; defaults to the longest tensor in ``xs``.

    Returns:
        Tensor of shape ``(len(xs), max_len, *xs[0].shape[1:])``.
    """
    batch = len(xs)
    if max_len is None:
        max_len = max(item.size(0) for item in xs)

    # Allocate on the same device/dtype as the inputs, pre-filled with pad_value.
    padded = xs[0].new(batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for idx, item in enumerate(xs):
        padded[idx, : item.size(0)] = item
    return padded
| | | |
| | | |
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Build a boolean mask that is True at padded (out-of-length) positions.

    Args:
        lengths: per-sequence valid lengths (list, or tensor convertible via
            ``.tolist()``).
        xs: optional reference tensor; when given, the mask is expanded to its
            shape and moved to its device.
        length_dim: which dimension of ``xs`` is the sequence-length axis
            (must not be 0, the batch axis).
        maxlen: explicit mask width; only allowed when ``xs`` is None.

    Returns:
        Bool tensor where True marks positions at or beyond each length.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    batch = int(len(lengths))

    if maxlen is None:
        maxlen = int(max(lengths)) if xs is None else xs.size(length_dim)
    else:
        # An explicit maxlen and a reference tensor are mutually exclusive.
        assert xs is None
        assert maxlen >= int(max(lengths))

    positions = torch.arange(0, maxlen, dtype=torch.int64)
    positions = positions.unsqueeze(0).expand(batch, maxlen)
    limits = positions.new(lengths).unsqueeze(-1)
    mask = positions >= limits

    if xs is not None:
        assert xs.size(0) == batch, (xs.size(0), batch)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # Insert singleton dims everywhere except the batch and length axes so
        # the 2-D mask broadcasts cleanly over xs's shape.
        index = tuple(
            slice(None) if d in (0, length_dim) else None for d in range(xs.dim())
        )
        mask = mask[index].expand_as(xs).to(xs.device)
    return mask
| | | |
| | | |
class TokenIDConverter():
    """Bidirectional mapping between token strings and integer ids.

    By this file's convention the LAST entry of ``token_list`` is the unknown
    symbol; tokens missing from the vocabulary map to its id.
    """

    def __init__(self, token_list: Union[List, str],
                 ):
        check_argument_types()

        self.token_list = token_list
        # Last token is treated as <unk> (see class docstring).
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        """Return the vocabulary size."""
        return len(self.token_list)

    def ids2tokens(self, integers: Iterable[int]) -> List[str]:
        """Map integer ids back to their token strings.

        NOTE(review): the original ``def`` line for this method was lost in
        extraction (only its return statement survived); restored with the
        obvious signature.
        """
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Map tokens to ids, falling back to the <unk> id for unknown tokens.

        The previous body rebuilt the token2id dict on every call and was
        followed by an unreachable duplicate return; this uses the mapping
        cached in ``__init__`` (same result, no per-call rebuild).
        """
        return [self.token2id.get(t, self.unk_id) for t in tokens]
| | | |
| | | |
class CharTokenizer():
    # NOTE(review): this class is garbled — the `def __init__(self, ...)` header
    # and at least one parameter line (presumably `symbol_value`) were lost in
    # extraction; the dangling parameters below belong to that constructor.
    # TODO: restore from the upstream source before this module can import.

    space_symbol: str = "<space>",
    remove_non_linguistic_symbols: bool = False,
    ):
    check_argument_types()

    self.space_symbol = space_symbol
    self.non_linguistic_symbols = self.load_symbols(symbol_value)

    # NOTE(review): `model_path` is not defined anywhere in this fragment —
    # these two lines likely belong to a different (lost) class or method,
    # e.g. a model-file loader. TODO: confirm against the upstream source.
    if not model_path.is_file():
        raise FileExistsError(f'{model_path} is not a file.')
| | | |
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Chop *words* into consecutive chunks of at most *word_limit* items.

    Args:
        words: the list to split.
        word_limit: maximum chunk size (must be > 1).

    Returns:
        List of chunks; the last chunk holds any remainder.
    """
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    # Stepped slicing yields the full chunks plus the remainder in one pass.
    return [words[start:start + word_limit]
            for start in range(0, len(words), word_limit)]
| | | |
def code_mix_split_words(text: str):
    """Split whitespace-separated text into ASCII words and single non-ASCII chars.

    Runs of ASCII characters stay together as one word; every multi-byte
    (e.g. Chinese) character becomes its own token.
    """
    tokens = []
    for segment in text.split():
        buffer = ""  # accumulates a run of ASCII characters
        for ch in segment:
            if len(ch.encode()) == 1:
                # Single-byte UTF-8 => ASCII: extend the current word.
                buffer += ch
            else:
                # Non-ASCII: flush any pending ASCII word, emit the char alone.
                if buffer:
                    tokens.append(buffer)
                    buffer = ""
                tokens.append(ch)
        if buffer:
            tokens.append(buffer)
    return tokens
| | | |
def isEnglish(text: str):
    """Return True iff *text* consists only of ASCII letters and apostrophes.

    Empty strings are not English. Raw-string pattern and a direct boolean
    expression replace the non-raw pattern and the verbose if/else that
    returned literal True/False.
    """
    return re.search(r"^[a-zA-Z']+$", text) is not None
| | | |
def join_chinese_and_english(input_list):
    """Join tokens into one line: English tokens are space-separated, other
    (e.g. Chinese) tokens are glued directly; leading space is stripped."""
    pieces = []
    for token in input_list:
        if isEnglish(token):
            pieces.append(' ' + token)
        else:
            pieces.append(token)
    return ''.join(pieces).strip()
| | | |
def code_mix_split_words_jieba(seg_dict_file: str):
    """Build a tokenizer that jieba-segments Chinese runs in mixed text.

    Loads *seg_dict_file* as a jieba user dictionary, then returns a function
    that splits its input on whitespace, groups consecutive tokens by language
    (English vs. not-English), keeps English tokens as-is, and runs jieba
    (``HMM=False``) over each rejoined Chinese run.
    """
    jieba.load_userdict(seg_dict_file)

    def _fn(text: str):
        runs = []           # consecutive same-language token groups
        run_languages = []  # language tag for each group
        current_run = []
        prev_lang = None

        for token in text.split():
            token_is_english = isEnglish(token)
            # Flush the current run whenever the language switches.
            if token_is_english and prev_lang == 'Chinese':
                runs.append(current_run)
                run_languages.append('Chinese')
                current_run = []
            elif not token_is_english and prev_lang == 'English':
                runs.append(current_run)
                run_languages.append('English')
                current_run = []

            current_run.append(token)
            prev_lang = 'English' if token_is_english else 'Chinese'

        if current_run:
            runs.append(current_run)
            run_languages.append(prev_lang)

        result = []
        for run, lang in zip(runs, run_languages):
            if lang == 'English':
                result.extend(run)
            else:
                result.extend(
                    jieba.cut(join_chinese_and_english(run), HMM=False))
        return result

    return _fn
| | | |
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load a YAML file and return its contents as a dict.

    NOTE(review): the original body was truncated after the existence check;
    restored with the conventional raise-then-load implementation.

    Args:
        yaml_path: path to the YAML file.

    Raises:
        FileExistsError: if the path does not exist (matches this file's
            existing — if unusual — choice of exception type, see
            CharTokenizer's model_path check).
    """
    if not Path(yaml_path).exists():
        raise FileExistsError(f'The {yaml_path} does not exist.')

    # NOTE: yaml.Loader executes arbitrary tags — only use on trusted config.
    with open(str(yaml_path), 'rb') as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
| | | |
| | | |
@functools.lru_cache()
def get_logger(name='funasr_onnx'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding a stream handler; otherwise the already-configured logger
    is returned directly (also memoized via ``lru_cache``).

    NOTE(review): this block was garbled — two conflicting ``def`` lines
    ('rapdi_paraformer' vs 'funasr_onnx') and a missing body chunk; restored
    around the surviving tail (addHandler / logger_initialized / propagate).
    The 'funasr_onnx' name from the inner def was kept.

    Args:
        name: logger name.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    sh = logging.StreamHandler()
    logger.addHandler(sh)
    logger_initialized[name] = True
    # Don't double-emit through ancestor loggers' handlers.
    logger.propagate = False
    logging.basicConfig(level=logging.ERROR)
    return logger