From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/eend/utils/kaldi_data.py |   80 +++++++++++++++++----------------------
 1 files changed, 35 insertions(+), 45 deletions(-)

diff --git a/funasr/models/eend/utils/kaldi_data.py b/funasr/models/eend/utils/kaldi_data.py
index 53f6230..59e7a16 100644
--- a/funasr/models/eend/utils/kaldi_data.py
+++ b/funasr/models/eend/utils/kaldi_data.py
@@ -15,16 +15,14 @@
 
 
 def load_segments(segments_file):
-    """ load segments file as array """
+    """load segments file as array"""
     if not os.path.exists(segments_file):
         return None
     return np.loadtxt(
-            segments_file,
-            dtype=[('utt', 'object'),
-                   ('rec', 'object'),
-                   ('st', 'f'),
-                   ('et', 'f')],
-            ndmin=1)
+        segments_file,
+        dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
+        ndmin=1,
+    )
 
 
 def load_segments_hash(segments_file):
@@ -45,35 +43,33 @@
         utt, rec, st, et = line.strip().split()
         if rec not in ret:
             ret[rec] = []
-        ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
+        ret[rec].append({"utt": utt, "st": float(st), "et": float(et)})
     return ret
 
 
 def load_wav_scp(wav_scp_file):
-    """ return dictionary { rec: wav_rxfilename } """
+    """return dictionary { rec: wav_rxfilename }"""
     lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
     return {x[0]: x[1] for x in lines}
 
 
 @lru_cache(maxsize=1)
 def load_wav(wav_rxfilename, start=0, end=None):
-    """ This function reads audio file and return data in numpy.float32 array.
-        "lru_cache" holds recently loaded audio so that can be called
-        many times on the same audio file.
-        OPTIMIZE: controls lru_cache size for random access,
-        considering memory size
+    """This function reads audio file and return data in numpy.float32 array.
+    "lru_cache" holds recently loaded audio so that can be called
+    many times on the same audio file.
+    OPTIMIZE: controls lru_cache size for random access,
+    considering memory size
     """
-    if wav_rxfilename.endswith('|'):
+    if wav_rxfilename.endswith("|"):
         # input piped command
-        p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
-                             stdout=subprocess.PIPE)
-        data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
-                                   dtype='float32')
+        p = subprocess.Popen(wav_rxfilename[:-1], shell=True, stdout=subprocess.PIPE)
+        data, samplerate = sf.load(io.BytesIO(p.stdout.read()), dtype="float32")
         # cannot seek
         data = data[start:end]
-    elif wav_rxfilename == '-':
+    elif wav_rxfilename == "-":
         # stdin
-        data, samplerate = sf.load(sys.stdin, dtype='float32')
+        data, samplerate = sf.load(sys.stdin, dtype="float32")
         # cannot seek
         data = data[start:end]
     else:
@@ -83,13 +79,13 @@
 
 
 def load_utt2spk(utt2spk_file):
-    """ returns dictionary { uttid: spkid } """
+    """returns dictionary { uttid: spkid }"""
     lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
     return {x[0]: x[1] for x in lines}
 
 
 def load_spk2utt(spk2utt_file):
-    """ returns dictionary { spkid: list of uttids } """
+    """returns dictionary { spkid: list of uttids }"""
     if not os.path.exists(spk2utt_file):
         return None
     lines = [line.strip().split() for line in open(spk2utt_file)]
@@ -97,7 +93,7 @@
 
 
 def load_reco2dur(reco2dur_file):
-    """ returns dictionary { recid: duration }  """
+    """returns dictionary { recid: duration }"""
     if not os.path.exists(reco2dur_file):
         return None
     lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
@@ -105,7 +101,7 @@
 
 
 def process_wav(wav_rxfilename, process):
-    """ This function returns preprocessed wav_rxfilename
+    """This function returns preprocessed wav_rxfilename
     Args:
         wav_rxfilename: input
         process: command which can be connected via pipe,
@@ -113,7 +109,7 @@
     Returns:
         wav_rxfilename: output piped command
     """
-    if wav_rxfilename.endswith('|'):
+    if wav_rxfilename.endswith("|"):
         # input piped command
         return wav_rxfilename + process + "|"
     else:
@@ -122,18 +118,18 @@
 
 
 def extract_segments(wavs, segments=None):
-    """ This function returns generator of segmented audio as
-        (utterance id, numpy.float32 array)
-        TODO?: sampling rate is not converted.
+    """This function returns generator of segmented audio as
+    (utterance id, numpy.float32 array)
+    TODO?: sampling rate is not converted.
     """
     if segments is not None:
         # segments should be sorted by rec-id
         for seg in segments:
-            wav = wavs[seg['rec']]
+            wav = wavs[seg["rec"]]
             data, samplerate = load_wav(wav)
-            st_sample = np.rint(seg['st'] * samplerate).astype(int)
-            et_sample = np.rint(seg['et'] * samplerate).astype(int)
-            yield seg['utt'], data[st_sample:et_sample]
+            st_sample = np.rint(seg["st"] * samplerate).astype(int)
+            et_sample = np.rint(seg["et"] * samplerate).astype(int)
+            yield seg["utt"], data[st_sample:et_sample]
     else:
         # segments file not found,
         # wav.scp is used as segmented audio list
@@ -145,18 +141,12 @@
 class KaldiData:
     def __init__(self, data_dir):
         self.data_dir = data_dir
-        self.segments = load_segments_rechash(
-                os.path.join(self.data_dir, 'segments'))
-        self.utt2spk = load_utt2spk(
-                os.path.join(self.data_dir, 'utt2spk'))
-        self.wavs = load_wav_scp(
-                os.path.join(self.data_dir, 'wav.scp'))
-        self.reco2dur = load_reco2dur(
-                os.path.join(self.data_dir, 'reco2dur'))
-        self.spk2utt = load_spk2utt(
-                os.path.join(self.data_dir, 'spk2utt'))
+        self.segments = load_segments_rechash(os.path.join(self.data_dir, "segments"))
+        self.utt2spk = load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
+        self.wavs = load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
+        self.reco2dur = load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
+        self.spk2utt = load_spk2utt(os.path.join(self.data_dir, "spk2utt"))
 
     def load_wav(self, recid, start=0, end=None):
-        data, rate = load_wav(
-            self.wavs[recid], start, end)
+        data, rate = load_wav(self.wavs[recid], start, end)
         return data, rate

--
Gitblit v1.9.1