zhifu gao
2024-04-26 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa
runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -17,24 +17,23 @@
logger_initialized = {}
class TokenIDConverter():
    def __init__(self, token_list: Union[List, str],
                 ):
class TokenIDConverter:
    def __init__(
        self,
        token_list: Union[List, str],
    ):
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
    def ids2tokens(self,
                   integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(
                f"Must be 1 dim ndarray, but got {integers.ndim}")
            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
@@ -42,7 +41,7 @@
        return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer():
class CharTokenizer:
    def __init__(
        self,
        symbol_value: Union[Path, str, Iterable[str]] = None,
@@ -77,7 +76,7 @@
                if line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w):]
                    line = line[len(w) :]
                    break
            else:
                t = line[0]
@@ -100,7 +99,6 @@
        )
class Hypothesis(NamedTuple):
    """Hypothesis data type."""
@@ -120,15 +118,15 @@
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
        raise FileExistsError(f'The {yaml_path} does not exist.')
        raise FileExistsError(f"The {yaml_path} does not exist.")
    with open(str(yaml_path), 'rb') as f:
    with open(str(yaml_path), "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
@functools.lru_cache()
def get_logger(name='funasr_torch'):
def get_logger(name="funasr_torch"):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
@@ -148,8 +146,8 @@
            return logger
    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
    )
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)