From 7e996b16f1cc109e6f2d68af893b2bfa3f73a073 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 10 Mar 2023 10:44:16 +0800
Subject: [PATCH] Merge pull request #196 from alibaba-damo-academy/dev_lhn
---
funasr/datasets/iterable_dataset.py | 22 +++++++++++++++-------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 2f97e78..00697fd 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -66,7 +66,7 @@
return load_bytes(bytes)
DATA_TYPES = {
- "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
+ "sound": lambda x: torchaudio.load(x)[0].numpy(),
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
@@ -106,6 +106,7 @@
] = None,
float_dtype: str = "float32",
fs: dict = None,
+ mc: bool = False,
int_dtype: str = "long",
key_file: str = None,
):
@@ -122,6 +123,7 @@
self.int_dtype = int_dtype
self.key_file = key_file
self.fs = fs
+ self.mc = mc
self.debug_info = {}
non_iterable_list = []
@@ -192,6 +194,7 @@
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
array = array.squeeze(0).numpy()
+
data[name] = array
if self.preprocess is not None:
@@ -238,11 +241,12 @@
model_fs = self.fs["model_fs"]
if audio_fs is not None and model_fs is not None:
array = torch.from_numpy(array)
- array = array.unsqueeze(0)
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
- array = array.squeeze(0).numpy()
- data[name] = array
+ if self.mc:
+ data[name] = array.transpose(0, 1).numpy()
+ else:
+ data[name] = array[0].numpy()
if self.preprocess is not None:
data = self.preprocess(uid, data)
@@ -340,11 +344,15 @@
model_fs = self.fs["model_fs"]
if audio_fs is not None and model_fs is not None:
array = torch.from_numpy(array)
- array = array.unsqueeze(0)
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
- array = array.squeeze(0).numpy()
- data[name] = array
+ if _type == "sound":
+ if self.mc:
+ data[name] = array.transpose(0, 1).numpy()
+ else:
+ data[name] = array[0].numpy()
+ else:
+ data[name] = array
if self.non_iterable_dataset is not None:
# 2.b. Load data from non-iterable dataset
_, from_non_iterable = self.non_iterable_dataset[uid]
--
Gitblit v1.9.1