| | |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import functools |
| | | import yaml |
| | | import logging |
| | | import pickle |
| | | import functools |
| | | import numpy as np |
| | | from pathlib import Path |
| | | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union |
| | | |
| | | import numpy as np |
| | | import yaml |
| | | |
| | | |
| | | import warnings |
| | | |
| | | root_dir = Path(__file__).resolve().parent |
| | | |
| | | logger_initialized = {} |
| | | |
def pad_list(xs, pad_value, max_len=None):
    """Stack variable-length 1-D sequences into one padded int32 NumPy array.

    Args:
        xs: Sequence of array-likes exposing ``shape`` (e.g. NumPy arrays).
            Each item fills the leading part of one row of a 2-D output,
            so items are expected to be 1-D.
        pad_value: Value stored in the positions past the end of each item.
        max_len: Width of the output. When ``None``, the length of the
            longest item is used (0 when ``xs`` is empty).

    Returns:
        ``np.ndarray`` of shape ``(len(xs), max_len)`` with dtype ``int32``.
    """
    n_batch = len(xs)
    if max_len is None:
        # ``shape[0]`` instead of ``size(0)``: ``ndarray.size`` is a plain int,
        # so calling it fails on NumPy input, while ``shape`` is available on
        # NumPy arrays and torch tensors alike. ``default=0`` keeps an empty
        # batch from raising inside ``max``.
        max_len = max((x.shape[0] for x in xs), default=0)

    padded = np.full((n_batch, max_len), pad_value, dtype=np.int32)
    # Rows of ``padded`` are views, so writing into ``row`` fills the output.
    for row, seq in zip(padded, xs):
        row[: seq.shape[0]] = seq
    return padded
| | | |
class TokenIDConverter:
    """Two-way lookup between vocabulary tokens and their integer ids.

    A token's id is its position in ``token_list``. The last entry of the
    list is used as the unknown symbol, and its id is returned for any
    token that is not in the vocabulary.
    """

    def __init__(
        self,
        token_list: Union[List, str],
    ):
        self.token_list = token_list
        # The final vocabulary entry doubles as the unknown-token symbol.
        self.unk_symbol = token_list[-1]
        self.token2id = {token: index for index, token in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        """Return the number of entries in the vocabulary."""
        vocabulary_size = len(self.token_list)
        return vocabulary_size

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Map ids to tokens; an ndarray argument must be one-dimensional."""
        is_array = isinstance(integers, np.ndarray)
        if is_array and integers.ndim != 1:
            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")

        vocabulary = self.token_list
        return [vocabulary[token_id] for token_id in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Map tokens to ids, falling back to ``unk_id`` for unknown tokens."""
        lookup = self.token2id.get
        fallback = self.unk_id
        return [lookup(token, fallback) for token in tokens]
| | | |
| | | |
| | | class CharTokenizer(): |
| | | class CharTokenizer: |
| | | def __init__( |
| | | self, |
| | | symbol_value: Union[Path, str, Iterable[str]] = None, |
| | |
| | | if line.startswith(w): |
| | | if not self.remove_non_linguistic_symbols: |
| | | tokens.append(line[: len(w)]) |
| | | line = line[len(w):] |
| | | line = line[len(w) :] |
| | | break |
| | | else: |
| | | t = line[0] |
| | |
| | | ) |
| | | |
| | | |
| | | |
| | | class Hypothesis(NamedTuple): |
| | | """Hypothesis data type.""" |
| | | |
| | |
| | | |
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load a YAML file and return its parsed content.

    Args:
        yaml_path: Location of the YAML file.

    Returns:
        Whatever ``yaml.load`` produces for the file.

    Raises:
        FileExistsError: If ``yaml_path`` does not exist.
    """
    # NOTE(review): FileNotFoundError would describe this condition better;
    # FileExistsError is kept because callers may already catch it.
    if Path(yaml_path).exists():
        with open(str(yaml_path), "rb") as stream:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # only use this on trusted configuration files.
            return yaml.load(stream, Loader=yaml.Loader)

    raise FileExistsError(f"The {yaml_path} does not exist.")
| | | |
| | | |
| | | @functools.lru_cache() |
| | | def get_logger(name='funasr_torch'): |
| | | def get_logger(name="funasr_torch"): |
| | | """Initialize and get a logger by name. |
| | | If the logger has not been initialized, this method will initialize the |
| | | logger by adding one or two handlers, otherwise the initialized logger will |
| | |
| | | return logger |
| | | |
| | | formatter = logging.Formatter( |
| | | '[%(asctime)s] %(name)s %(levelname)s: %(message)s', |
| | | datefmt="%Y/%m/%d %H:%M:%S") |
| | | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S" |
| | | ) |
| | | |
| | | sh = logging.StreamHandler() |
| | | sh.setFormatter(formatter) |