kongdeqiang
9 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -1,40 +1,43 @@
# -*- encoding: utf-8 -*-
import functools
import yaml
import logging
import pickle
import functools
import numpy as np
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
import yaml
import warnings
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
def pad_list(xs, pad_value, max_len=None):
    n_batch = len(xs)
    if max_len is None:
        max_len = max(x.size(0) for x in xs)
    # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    # numpy format
    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
    for i in range(n_batch):
        pad[i, : xs[i].shape[0]] = xs[i]
class TokenIDConverter():
    def __init__(self, token_list: Union[List, str],
                 ):
    return pad
class TokenIDConverter:
    def __init__(
        self,
        token_list: Union[List, str],
    ):
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
    def ids2tokens(self,
                   integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(
                f"Must be 1 dim ndarray, but got {integers.ndim}")
            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
@@ -42,7 +45,7 @@
        return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer():
class CharTokenizer:
    def __init__(
        self,
        symbol_value: Union[Path, str, Iterable[str]] = None,
@@ -77,7 +80,7 @@
                if line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w):]
                    line = line[len(w) :]
                    break
            else:
                t = line[0]
@@ -100,7 +103,6 @@
        )
class Hypothesis(NamedTuple):
    """Hypothesis data type."""
@@ -120,15 +122,15 @@
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
        raise FileExistsError(f'The {yaml_path} does not exist.')
        raise FileExistsError(f"The {yaml_path} does not exist.")
    with open(str(yaml_path), 'rb') as f:
    with open(str(yaml_path), "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
@functools.lru_cache()
def get_logger(name='funasr_torch'):
def get_logger(name="funasr_torch"):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
@@ -148,8 +150,8 @@
            return logger
    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
    )
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)