haoneng.lhn
2023-06-26 e677eb4b13b5388f4351a164a991cea950773a72
fix torchaudio mp3 loading bug (fall back to soundfile when torchaudio.load fails)
6个文件已修改
40 ■■■■■ 已修改文件
funasr/bin/asr_inference_launch.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/iterable_dataset.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/dataset.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/asr_utils.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/prepare_data.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/wav_utils.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
import yaml
from typeguard import check_argument_types
@@ -863,7 +864,10 @@
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            try:
            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            except:
                raw_inputs = torch.tensor(soundfile.read(data_path_and_name_and_type[0])[0])
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)
funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -66,8 +67,14 @@
        bytes = f.read()
    return load_bytes(bytes)
def load_wav(input):
    """Load an audio file into a numpy array of shape (channels, samples).

    Tries ``torchaudio.load`` first; torchaudio's backend can fail on some
    formats (e.g. certain mp3 files), in which case we fall back to
    ``soundfile.read``.

    Args:
        input: path to the audio file (anything torchaudio/soundfile accept).

    Returns:
        numpy array of shape (channels, samples) with the decoded waveform.
    """
    try:
        # torchaudio returns (waveform_tensor, sample_rate); keep waveform only.
        return torchaudio.load(input)[0].numpy()
    except Exception:
        # was a bare `except:` — narrow it so SystemExit/KeyboardInterrupt
        # still propagate.
        # NOTE(review): soundfile returns shape (samples,) for mono but
        # (samples, channels) for multi-channel audio; expand_dims only
        # reproduces torchaudio's (channels, samples) layout for mono input
        # — confirm callers only feed mono files here.
        return np.expand_dims(soundfile.read(input)[0], axis=0)
DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0].numpy(),
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,12 @@
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        try:
                        waveform, sampling_rate = torchaudio.load(path)
                        except:
                            waveform, sampling_rate = soundfile.read(path)
                            waveform = np.expand_dims(waveform, axis=0)
                            waveform = torch.tensor(waveform)
                        if self.frontend_conf is not None:
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
import soundfile
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
                if support_audio_type == "pcm":
                    fs = None
                else:
                    try:
                    audio, fs = torchaudio.load(fname)
                    except:
                        audio, fs = soundfile.read(fname)
                break
        if audio_type.rfind(".scp") >= 0:
            with open(fname, encoding="utf-8") as f:
funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
import numpy as np
import torch.distributed as dist
import torchaudio
import soundfile
def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
def wav2num_frame(wav_path, frontend_conf):
    """Estimate the number of LFR output frames and feature dim for a wav file.

    Args:
        wav_path: path to the audio file.
        frontend_conf: frontend config dict; reads "frame_shift" (ms),
            "lfr_n", "lfr_m" and "n_mels".

    Returns:
        Tuple ``(n_frames, feature_dim)`` where ``n_frames`` is the (float)
        estimated number of low-frame-rate frames and
        ``feature_dim = n_mels * lfr_m``.
    """
    try:
        # torchaudio returns a (channels, samples) tensor plus sample rate.
        # (This line was mis-indented in the original — a syntax error.)
        waveform, sampling_rate = torchaudio.load(wav_path)
    except Exception:
        # was a bare `except:` — narrow it so SystemExit/KeyboardInterrupt
        # still propagate.  Fallback for formats torchaudio cannot decode
        # (e.g. some mp3 files): soundfile returns (samples,) data, so add a
        # leading channel axis to match torchaudio's (channels, samples)
        # layout before shape[1] is read below.
        waveform, sampling_rate = soundfile.read(wav_path)
        waveform = np.expand_dims(waveform, axis=0)
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
    return n_frames, feature_dim
funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
import torchaudio.compliance.kaldi as kaldi
@@ -162,7 +163,11 @@
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load pcm from wav, and resample
        try:
        waveform, audio_sr = torchaudio.load(wav_file)
        except:
            waveform, audio_sr = soundfile.read(wav_file)
            waveform = torch.tensor(np.expand_dims(waveform, axis=0))
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)
@@ -181,7 +186,11 @@
def wav2num_frame(wav_path, frontend_conf):
    waveform, sampling_rate = torchaudio.load(wav_path)
    try:
        waveform, audio_sr = torchaudio.load(wav_file)
    except:
        waveform, audio_sr = soundfile.read(wav_file)
        waveform = torch.tensor(np.expand_dims(waveform, axis=0))
    speech_length = (waveform.shape[1] / sampling_rate) * 1000.
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]