From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/fileio/sound_scp.py |   87 +++++++++++++++++++++++++++++++++++++++----
 1 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index cec7cd9..b912f1e 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -1,17 +1,84 @@
 import collections.abc
 from pathlib import Path
-from typing import Union
-from typing import Optional
+from typing import List, Tuple, Union
 
 import random
 import numpy as np
 import soundfile
 import librosa
-from typeguard import check_argument_types
 
+import torch
 import torchaudio
 
 from funasr.fileio.read_text import read_2column_text
+
+def soundfile_read(
+    wavs: Union[str, List[str]],
+    dtype=None,
+    always_2d: bool = False,
+    concat_axis: int = 1,
+    start: int = 0,
+    end: Union[int, None] = None,
+    return_subtype: bool = False,
+) -> Union[Tuple[np.ndarray, int], Tuple[np.ndarray, int, List[str]]]:
+    """Read one or more audio files and join them into a single array.
+
+    Args:
+        wavs: One path, or a list of paths joined along ``concat_axis``.
+        dtype: Sample dtype; "float16" is read as float32 and then cast.
+        always_2d: Forwarded to ``SoundFile.read``.
+        concat_axis: Axis along which several files are concatenated.
+        start: Frame to seek to before reading.
+        end: Stop frame (``end - start`` frames are read); None reads all.
+        return_subtype: If True, also return the list of file subtypes.
+
+    Returns:
+        ``(array, rate)``, or ``(array, rate, subtypes)`` if return_subtype.
+
+    Raises:
+        RuntimeError: If sampling rates or the non-joined axis differ.
+    """
+    if isinstance(wavs, str):
+        wavs = [wavs]
+    frames = -1 if end is None else end - start
+    read_dtype = "float32" if dtype == "float16" else dtype
+
+    arrays, subtypes = [], []
+    prev_rate = prev_wav = None
+    for wav in wavs:
+        with soundfile.SoundFile(wav) as f:
+            f.seek(start)
+            array = f.read(frames, dtype=read_dtype, always_2d=always_2d)
+            rate = f.samplerate
+            subtypes.append(f.subtype)
+        if dtype == "float16":
+            array = array.astype(dtype)
+        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+            # (Time,) -> (Time, Channel) so mono files stack as channels
+            array = array[:, None]
+
+        if prev_wav is not None:
+            if prev_rate != rate:
+                raise RuntimeError(
+                    f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
+                    f"{prev_rate} != {rate}"
+                )
+            # A 1-D array has no second axis; count it as a single channel.
+            other = 1 - concat_axis
+            dim1 = arrays[0].shape[other] if arrays[0].ndim > 1 else 1
+            dim2 = array.shape[other] if array.ndim > 1 else 1
+            if dim1 != dim2:
+                raise RuntimeError(
+                    f"Shapes must match with {other} axis, but got {dim1} and {dim2}"
+                )
+
+        prev_rate, prev_wav = rate, wav
+        arrays.append(array)
+
+    array = arrays[0] if len(arrays) == 1 else np.concatenate(arrays, axis=concat_axis)
+    if return_subtype:
+        return array, rate, subtypes
+    return array, rate
 
 
 class SoundScpReader(collections.abc.Mapping):
@@ -36,9 +103,8 @@
         always_2d: bool = False,
         normalize: bool = False,
         dest_sample_rate: int = 16000,
-        speed_perturb: Optional[list, tuple] = None,
+        speed_perturb: Union[list, tuple] = None,
     ):
-        assert check_argument_types()
         self.fname = fname
         self.dtype = dtype
         self.always_2d = always_2d
@@ -52,19 +118,23 @@
         if self.normalize:
             # soundfile.read normalizes data to [-1,1] if dtype is not given
             array, rate = librosa.load(
-                wav, sr=self.dest_sample_rate, mono=not self.always_2d
+                wav, sr=self.dest_sample_rate, mono=self.always_2d
             )
         else:
             array, rate = librosa.load(
-                wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype
+                wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
             )
 
         if self.speed_perturb is not None:
             speed = random.choice(self.speed_perturb)
             if speed != 1.0:
                 array, _ = torchaudio.sox_effects.apply_effects_tensor(
-                    array, rate,
+                    torch.tensor(array).view(1, -1), rate,
                     [['speed', str(speed)], ['rate', str(rate)]])
+                array = array.view(-1).numpy()
+
+        if array.ndim==2:
+            array=array.transpose((1, 0))
 
         return rate, array
 
@@ -107,7 +177,6 @@
         format="wav",
         dtype=None,
     ):
-        assert check_argument_types()
         self.dir = Path(outdir)
         self.dir.mkdir(parents=True, exist_ok=True)
         scpfile = Path(scpfile)

--
Gitblit v1.9.1