zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/eend/utils/kaldi_data.py
@@ -20,11 +20,9 @@
        return None
    return np.loadtxt(
            segments_file,
            dtype=[('utt', 'object'),
                   ('rec', 'object'),
                   ('st', 'f'),
                   ('et', 'f')],
            ndmin=1)
        dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
        ndmin=1,
    )
def load_segments_hash(segments_file):
@@ -45,7 +43,7 @@
        utt, rec, st, et = line.strip().split()
        if rec not in ret:
            ret[rec] = []
        ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
        ret[rec].append({"utt": utt, "st": float(st), "et": float(et)})
    return ret
@@ -63,17 +61,15 @@
        OPTIMIZE: controls lru_cache size for random access,
        considering memory size
    """
    if wav_rxfilename.endswith('|'):
    if wav_rxfilename.endswith("|"):
        # input piped command
        p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
                             stdout=subprocess.PIPE)
        data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
                                   dtype='float32')
        p = subprocess.Popen(wav_rxfilename[:-1], shell=True, stdout=subprocess.PIPE)
        data, samplerate = sf.load(io.BytesIO(p.stdout.read()), dtype="float32")
        # cannot seek
        data = data[start:end]
    elif wav_rxfilename == '-':
    elif wav_rxfilename == "-":
        # stdin
        data, samplerate = sf.load(sys.stdin, dtype='float32')
        data, samplerate = sf.load(sys.stdin, dtype="float32")
        # cannot seek
        data = data[start:end]
    else:
@@ -113,7 +109,7 @@
    Returns:
        wav_rxfilename: output piped command
    """
    if wav_rxfilename.endswith('|'):
    if wav_rxfilename.endswith("|"):
        # input piped command
        return wav_rxfilename + process + "|"
    else:
@@ -129,11 +125,11 @@
    if segments is not None:
        # segments should be sorted by rec-id
        for seg in segments:
            wav = wavs[seg['rec']]
            wav = wavs[seg["rec"]]
            data, samplerate = load_wav(wav)
            st_sample = np.rint(seg['st'] * samplerate).astype(int)
            et_sample = np.rint(seg['et'] * samplerate).astype(int)
            yield seg['utt'], data[st_sample:et_sample]
            st_sample = np.rint(seg["st"] * samplerate).astype(int)
            et_sample = np.rint(seg["et"] * samplerate).astype(int)
            yield seg["utt"], data[st_sample:et_sample]
    else:
        # segments file not found,
        # wav.scp is used as segmented audio list
@@ -145,18 +141,12 @@
class KaldiData:
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.segments = load_segments_rechash(
                os.path.join(self.data_dir, 'segments'))
        self.utt2spk = load_utt2spk(
                os.path.join(self.data_dir, 'utt2spk'))
        self.wavs = load_wav_scp(
                os.path.join(self.data_dir, 'wav.scp'))
        self.reco2dur = load_reco2dur(
                os.path.join(self.data_dir, 'reco2dur'))
        self.spk2utt = load_spk2utt(
                os.path.join(self.data_dir, 'spk2utt'))
        self.segments = load_segments_rechash(os.path.join(self.data_dir, "segments"))
        self.utt2spk = load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
        self.wavs = load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
        self.reco2dur = load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
        self.spk2utt = load_spk2utt(os.path.join(self.data_dir, "spk2utt"))
    def load_wav(self, recid, start=0, end=None):
        data, rate = load_wav(
            self.wavs[recid], start, end)
        data, rate = load_wav(self.wavs[recid], start, end)
        return data, rate