语帆
2024-02-23 6a8c943435edf25f252d9d4db0095d4a01c7a3cd
test
2个文件已修改
17 ■■■■■ 已修改文件
funasr/models/lcbnet/model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/load_utils.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/lcbnet/model.py
@@ -425,6 +425,7 @@
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            pdb.set_trace()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
funasr/utils/load_utils.py
@@ -13,30 +13,34 @@
    from funasr.download.file import download_from_url
except:
    print("urllib is not installed, if you infer from url, please install it first.")
import pdb
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
    pdb.set_trace()
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            pdb.set_trace()
            data_types = [data_type] * len(data_or_path_or_list)
            data_or_path_or_list_ret = [[] for d in data_type]
            pdb.set_trace()
            for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
                
                for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
                    pdb.set_trace()
                    data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                    pdb.set_trace()
                    data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
            return data_or_path_or_list_ret
        else:
            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
    pdb.set_trace()
    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)
    pdb.set_trace()
    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
        pdb.set_trace()
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
            if kwargs.get("reduce_channels", True):
@@ -59,7 +63,7 @@
    else:
        pass
        # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
    pdb.set_trace()
    if audio_fs != fs and data_type != "text":
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]