zhifu gao
2023-03-10 7e996b16f1cc109e6f2d68af893b2bfa3f73a073
Merge pull request #196 from alibaba-damo-academy/dev_lhn

support mfcca inference
3个文件已修改
24 ■■■■ 已修改文件
funasr/bin/asr_inference_mfcca.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/iterable_dataset.py 20 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_mfcca.py
@@ -534,6 +534,8 @@
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
funasr/datasets/iterable_dataset.py
@@ -66,7 +66,7 @@
    return load_bytes(bytes)
DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
    "sound": lambda x: torchaudio.load(x)[0].numpy(),
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
@@ -106,6 +106,7 @@
            ] = None,
            float_dtype: str = "float32",
            fs: dict = None,
            mc: bool = False,
            int_dtype: str = "long",
            key_file: str = None,
    ):
@@ -122,6 +123,7 @@
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.fs = fs
        self.mc = mc
        self.debug_info = {}
        non_iterable_list = []
@@ -192,6 +194,7 @@
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                       new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
                data[name] = array
                if self.preprocess is not None:
@@ -238,11 +241,12 @@
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = array.unsqueeze(0)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                               new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
                data[name] = array
                if self.mc:
                    data[name] = array.transpose(0, 1).numpy()
                else:
                    data[name] = array[0].numpy()
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
@@ -340,10 +344,14 @@
                        model_fs = self.fs["model_fs"]
                        if audio_fs is not None and model_fs is not None:
                            array = torch.from_numpy(array)
                            array = array.unsqueeze(0)
                            array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                                   new_freq=model_fs)(array)
                            array = array.squeeze(0).numpy()
                    if _type == "sound":
                        if self.mc:
                            data[name] = array.transpose(0, 1).numpy()
                        else:
                            data[name] = array[0].numpy()
                    else:
                    data[name] = array
                if self.non_iterable_dataset is not None:
                    # 2.b. Load data from non-iterable dataset
funasr/tasks/abs_task.py
@@ -1847,6 +1847,7 @@
            key_file: str = None,
            batch_size: int = 1,
            fs: dict = None,
            mc: bool = False,
            dtype: str = np.float32,
            num_workers: int = 1,
            allow_variable_data_keys: bool = False,
@@ -1865,6 +1866,7 @@
            data_path_and_name_and_type,
            float_dtype=dtype,
            fs=fs,
            mc=mc,
            preprocess=preprocess_fn,
            key_file=key_file,
        )