From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages
---
funasr/datasets/iterable_dataset.py | 35 +++++++++++++++++++----------------
1 files changed, 19 insertions(+), 16 deletions(-)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 49c7068..6398e0c 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -8,13 +8,14 @@
from typing import Iterator
from typing import Tuple
from typing import Union
+from typing import List
import kaldiio
import numpy as np
import torch
import torchaudio
+import soundfile
from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
@@ -65,8 +66,17 @@
bytes = f.read()
return load_bytes(bytes)
+def load_wav(input):
+ try:
+ return torchaudio.load(input)[0].numpy()
+ except:
+ waveform, _ = soundfile.read(input, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ return np.expand_dims(waveform, axis=0)
+
DATA_TYPES = {
- "sound": lambda x: torchaudio.load(x)[0].numpy(),
+ "sound": load_wav,
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
@@ -110,7 +120,6 @@
int_dtype: str = "long",
key_file: str = None,
):
- assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
@@ -129,7 +138,7 @@
non_iterable_list = []
self.path_name_type_list = []
- if not isinstance(path_name_type_list[0], Tuple):
+ if not isinstance(path_name_type_list[0], (Tuple, List)):
path = path_name_type_list[0]
name = path_name_type_list[1]
_type = path_name_type_list[2]
@@ -227,13 +236,9 @@
name = self.path_name_type_list[i][1]
_type = self.path_name_type_list[i][2]
if _type == "sound":
- audio_type = os.path.basename(value).split(".")[-1].lower()
- if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
- raise NotImplementedError(
- f'Not supported audio type: {audio_type}')
- if audio_type == "pcm":
- _type = "pcm"
-
+ audio_type = os.path.basename(value).lower()
+ if audio_type.rfind(".pcm") >= 0:
+ _type = "pcm"
func = DATA_TYPES[_type]
array = func(value)
if self.fs is not None and (name == "speech" or name == "ref_speech"):
@@ -335,11 +340,8 @@
# 2.a. Load data streamingly
for value, (path, name, _type) in zip(values, self.path_name_type_list):
if _type == "sound":
- audio_type = os.path.basename(value).split(".")[-1].lower()
- if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
- raise NotImplementedError(
- f'Not supported audio type: {audio_type}')
- if audio_type == "pcm":
+ audio_type = os.path.basename(value).lower()
+ if audio_type.rfind(".pcm") >= 0:
_type = "pcm"
func = DATA_TYPES[_type]
# Load entry
@@ -391,3 +393,4 @@
if count == 0:
raise RuntimeError("No iteration")
+
--
Gitblit v1.9.1