zhifu gao
2024-05-08 b1c186fd00fef54bcad3aa1d073a1a313642d641
funasr/auto/auto_model.py
@@ -26,6 +26,7 @@
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils import export_utils
from funasr.utils import misc
try:
    from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
@@ -35,14 +36,7 @@
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
    :param input:
    :param input_len:
    :param data_type:
    :param frontend:
    :return:
    """
    """ """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
@@ -73,7 +67,8 @@
                    key_list.append(key)
        else:
            if key is None:
                key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key = misc.extract_filename_without_extension(data_in)
            data_list = [data_in]
            key_list = [key]
    elif isinstance(data_in, (list, tuple)):
@@ -90,10 +85,14 @@
        else:
            # [audio sample point, fbank, text]
            data_list = data_in
            key_list = [
                "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                for _ in range(len(data_in))
            ]
            key_list = []
            for data_i in data_in:
                if isinstance(data_i, str) and os.path.exists(data_i):
                    key = misc.extract_filename_without_extension(data_i)
                else:
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key_list.append(key)
    else:  # raw text; audio sample point, fbank; bytes
        if isinstance(data_in, bytes):  # audio bytes
            data_in = load_bytes(data_in)
@@ -108,6 +107,10 @@
class AutoModel:
    def __init__(self, **kwargs):
        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
        logging.basicConfig(level=log_level)
        if not kwargs.get("disable_log", True):
            tables.print()