From e677eb4b13b5388f4351a164a991cea950773a72 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: Mon, 26 Jun 2023 17:13:22 +0800
Subject: [PATCH] fix torchaudio load mp3 bug

---
 funasr/datasets/large_datasets/dataset.py |    9 ++++++++-
 funasr/utils/prepare_data.py              |    7 ++++++-
 funasr/utils/wav_utils.py                 |   13 +++++++++++--
 funasr/datasets/iterable_dataset.py       |    9 ++++++++-
 funasr/bin/asr_inference_launch.py        |    6 +++++-
 funasr/utils/asr_utils.py                 |    6 +++++-
 6 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 656a965..ce1f984 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -19,6 +19,7 @@
 import numpy as np
 import torch
 import torchaudio
+import soundfile
 import yaml
 from typeguard import check_argument_types
 
@@ -863,7 +864,10 @@
             raw_inputs = _load_bytes(data_path_and_name_and_type[0])
             raw_inputs = torch.tensor(raw_inputs)
         if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
-            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            try:
+                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            except:
+                raw_inputs = torch.tensor(soundfile.read(data_path_and_name_and_type[0])[0])
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, np.ndarray):
                 raw_inputs = torch.tensor(raw_inputs)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 4b2fb1a..fa0f0c7 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
 import torchaudio
+import soundfile
 from torch.utils.data.dataset import IterableDataset
 from typeguard import check_argument_types
 import os.path
@@ -66,8 +67,14 @@
         bytes = f.read()
     return load_bytes(bytes)
 
+def load_wav(input):
+    try:
+        return torchaudio.load(input)[0].numpy()
+    except:
+        return np.expand_dims(soundfile.read(input)[0], axis=0)
+
 DATA_TYPES = {
-    "sound": lambda x: torchaudio.load(x)[0].numpy(),
+    "sound": load_wav,
     "pcm": load_pcm,
     "kaldi_ark": load_kaldi,
     "bytes": load_bytes,
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 68b63e1..844dde7 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
 import torch
 import torch.distributed as dist
 import torchaudio
+import numpy as np
+import soundfile
 from kaldiio import ReadHelper
 from torch.utils.data import IterableDataset
 
@@ -123,7 +125,12 @@
                             sample_dict["key"] = key
                     elif data_type == "sound":
                         key, path = item.strip().split()
-                        waveform, sampling_rate = torchaudio.load(path)
+                        try:
+                            waveform, sampling_rate = torchaudio.load(path)
+                        except:
+                            waveform, sampling_rate = soundfile.read(path)
+                            waveform = np.expand_dims(waveform, axis=0)
+                            waveform = torch.tensor(waveform)
                         if self.frontend_conf is not None:
                             if sampling_rate != self.frontend_conf["fs"]:
                                 waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
index 4067b04..5aa40ec 100644
--- a/funasr/utils/asr_utils.py
+++ b/funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
 from typing import Any, Dict, List, Union
 
 import torchaudio
+import soundfile
 import numpy as np
 import pkg_resources
 from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
                 if support_audio_type == "pcm":
                     fs = None
                 else:
-                    audio, fs = torchaudio.load(fname)
+                    try:
+                        audio, fs = torchaudio.load(fname)
+                    except:
+                        audio, fs = soundfile.read(fname)
                 break
         if audio_type.rfind(".scp") >= 0:
             with open(fname, encoding="utf-8") as f:
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
index 7602740..0e773bb 100644
--- a/funasr/utils/prepare_data.py
+++ b/funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch.distributed as dist
 import torchaudio
+import soundfile
 
 
 def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
 
 
 def wav2num_frame(wav_path, frontend_conf):
-    waveform, sampling_rate = torchaudio.load(wav_path)
+    try:
+        waveform, sampling_rate = torchaudio.load(wav_path)
+    except:
+        waveform, sampling_rate = soundfile.read(wav_path)
+        waveform = np.expand_dims(waveform, axis=0)
     n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
     feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
     return n_frames, feature_dim
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
index ebb80d2..a6e394f 100644
--- a/funasr/utils/wav_utils.py
+++ b/funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
 import numpy as np
 import torch
 import torchaudio
+import soundfile
 import torchaudio.compliance.kaldi as kaldi
 
 
@@ -162,7 +163,11 @@
         waveform = torch.from_numpy(waveform.reshape(1, -1))
     else:
         # load pcm from wav, and resample
-        waveform, audio_sr = torchaudio.load(wav_file)
+        try:
+            waveform, audio_sr = torchaudio.load(wav_file)
+        except:
+            waveform, audio_sr = soundfile.read(wav_file)
+            waveform = torch.tensor(np.expand_dims(waveform, axis=0))
         waveform = waveform * (1 << 15)
         waveform = torch_resample(waveform, audio_sr, model_sr)
 
@@ -181,7 +186,11 @@
 
 
 def wav2num_frame(wav_path, frontend_conf):
-    waveform, sampling_rate = torchaudio.load(wav_path)
+    try:
+        waveform, sampling_rate = torchaudio.load(wav_path)
+    except:
+        waveform, sampling_rate = soundfile.read(wav_path)
+        waveform = torch.tensor(np.expand_dims(waveform, axis=0))
     speech_length = (waveform.shape[1] / sampling_rate) * 1000.
     n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
     feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]

--
Gitblit v1.9.1