From a19c9e5c805d727d2d0a81053997d70e3326e3cb Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Tue, 27 Jun 2023 13:18:15 +0800
Subject: [PATCH] Merge pull request #671 from alibaba-damo-academy/dev_lhn
---
funasr/datasets/large_datasets/dataset.py | 4 +++-
funasr/utils/wav_utils.py | 4 +++-
funasr/datasets/iterable_dataset.py | 5 ++++-
funasr/bin/asr_inference_launch.py | 5 ++++-
4 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index ce1f984..5d1b804 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -867,7 +867,10 @@
try:
raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
except:
- raw_inputs = torch.tensor(soundfile.read(data_path_and_name_and_type[0])[0])
+ raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+ if raw_inputs.ndim == 2:
+ raw_inputs = raw_inputs[:, 0]
+ raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index fa0f0c7..d240d93 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -71,7 +71,10 @@
try:
return torchaudio.load(input)[0].numpy()
except:
- return np.expand_dims(soundfile.read(input)[0], axis=0)
+ waveform, _ = soundfile.read(input, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ return np.expand_dims(waveform, axis=0)
DATA_TYPES = {
"sound": load_wav,
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 844dde7..5f2c2c6 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -128,7 +128,9 @@
try:
waveform, sampling_rate = torchaudio.load(path)
except:
- waveform, sampling_rate = soundfile.read(path)
+ waveform, sampling_rate = soundfile.read(path, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
waveform = np.expand_dims(waveform, axis=0)
waveform = torch.tensor(waveform)
if self.frontend_conf is not None:
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
index a6e394f..e523495 100644
--- a/funasr/utils/wav_utils.py
+++ b/funasr/utils/wav_utils.py
@@ -166,7 +166,9 @@
try:
waveform, audio_sr = torchaudio.load(wav_file)
except:
- waveform, audio_sr = soundfile.read(wav_file)
+ waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
waveform = waveform * (1 << 15)
waveform = torch_resample(waveform, audio_sr, model_sr)
--
Gitblit v1.9.1