From 7e996b16f1cc109e6f2d68af893b2bfa3f73a073 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 10 Mar 2023 10:44:16 +0800
Subject: [PATCH] Merge pull request #196 from alibaba-damo-academy/dev_lhn
---
funasr/bin/asr_inference_mfcca.py | 2 ++
funasr/tasks/abs_task.py | 2 ++
funasr/datasets/iterable_dataset.py | 22 +++++++++++++++-------
3 files changed, 19 insertions(+), 7 deletions(-)
diff --git a/funasr/bin/asr_inference_mfcca.py b/funasr/bin/asr_inference_mfcca.py
index e25b2a9..9f5cb19 100644
--- a/funasr/bin/asr_inference_mfcca.py
+++ b/funasr/bin/asr_inference_mfcca.py
@@ -534,6 +534,8 @@
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
+ fs=fs,
+ mc=True,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 2f97e78..00697fd 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -66,7 +66,7 @@
return load_bytes(bytes)
DATA_TYPES = {
- "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
+ "sound": lambda x: torchaudio.load(x)[0].numpy(),
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
@@ -106,6 +106,7 @@
] = None,
float_dtype: str = "float32",
fs: dict = None,
+ mc: bool = False,
int_dtype: str = "long",
key_file: str = None,
):
@@ -122,6 +123,7 @@
self.int_dtype = int_dtype
self.key_file = key_file
self.fs = fs
+ self.mc = mc
self.debug_info = {}
non_iterable_list = []
@@ -192,6 +194,7 @@
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
array = array.squeeze(0).numpy()
+
data[name] = array
if self.preprocess is not None:
@@ -238,11 +241,12 @@
model_fs = self.fs["model_fs"]
if audio_fs is not None and model_fs is not None:
array = torch.from_numpy(array)
- array = array.unsqueeze(0)
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
- array = array.squeeze(0).numpy()
- data[name] = array
+ if self.mc:
+ data[name] = array.transpose(0, 1).numpy()
+ else:
+ data[name] = array[0].numpy()
if self.preprocess is not None:
data = self.preprocess(uid, data)
@@ -340,11 +344,15 @@
model_fs = self.fs["model_fs"]
if audio_fs is not None and model_fs is not None:
array = torch.from_numpy(array)
- array = array.unsqueeze(0)
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
- array = array.squeeze(0).numpy()
- data[name] = array
+ if _type == "sound":
+ if self.mc:
+ data[name] = array.transpose(0, 1).numpy()
+ else:
+ data[name] = array[0].numpy()
+ else:
+ data[name] = array
if self.non_iterable_dataset is not None:
# 2.b. Load data from non-iterable dataset
_, from_non_iterable = self.non_iterable_dataset[uid]
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index a643acb..723a67c 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -1847,6 +1847,7 @@
key_file: str = None,
batch_size: int = 1,
fs: dict = None,
+ mc: bool = False,
dtype: str = np.float32,
num_workers: int = 1,
allow_variable_data_keys: bool = False,
@@ -1865,6 +1866,7 @@
data_path_and_name_and_type,
float_dtype=dtype,
fs=fs,
+ mc=mc,
preprocess=preprocess_fn,
key_file=key_file,
)
--
Gitblit v1.9.1