From 0109889f1cbbd7ff703383bfacb204d45f5d37a9 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 21 七月 2023 02:45:54 +0800
Subject: [PATCH] update

---
 egs/callhome/eend_ola/run_test.sh           |   23 +++++--
 egs/callhome/eend_ola/local/dump_feature.py |  105 ++++++++++++++++++++--------------
 2 files changed, 77 insertions(+), 51 deletions(-)

diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py
index 169615e..332edd2 100644
--- a/egs/callhome/eend_ola/local/dump_feature.py
+++ b/egs/callhome/eend_ola/local/dump_feature.py
@@ -1,10 +1,11 @@
 import argparse
 import os
 
-import numpy as np
+from kaldiio import WriteHelper
 
 import funasr.modules.eend_ola.utils.feature as feature
-import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data
+from funasr.modules.eend_ola.utils.kaldi_data import load_segments_rechash, load_utt2spk, load_wav_scp, load_reco2dur, \
+    load_spk2utt, load_wav
 
 
 def _count_frames(data_len, size, step):
@@ -24,10 +25,34 @@
             yield (i + 1) * step, data_length
 
 
+class KaldiData:
+    def __init__(self, data_dir, idx):
+        self.data_dir = data_dir
+        segment_file = os.path.join(self.data_dir, 'segments.{}'.format(idx))
+        self.segments = load_segments_rechash(segment_file)
+
+        utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx))
+        self.utt2spk = load_utt2spk(utt2spk_file)
+
+        wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx))
+        self.wavs = load_wav_scp(wav_file)
+
+        reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx))
+        self.reco2dur = load_reco2dur(reco2dur_file)
+
+        spk2utt_file = os.path.join(self.data_dir, 'spk2utt.{}'.format(idx))
+        self.spk2utt = load_spk2utt(spk2utt_file)
+
+    def load_wav(self, recid, start=0, end=None):
+        data, rate = load_wav(self.wavs[recid], start, end)
+        return data, rate
+
+
 class KaldiDiarizationDataset():
     def __init__(
             self,
             data_dir,
+            index,
             chunk_size=2000,
             context_size=0,
             frame_size=1024,
@@ -40,6 +65,7 @@
             n_speakers=None,
     ):
         self.data_dir = data_dir
+        self.index = index
         self.chunk_size = chunk_size
         self.context_size = context_size
         self.frame_size = frame_size
@@ -50,9 +76,8 @@
         self.chunk_indices = []
         self.label_delay = label_delay
 
-        self.data = kaldi_data.KaldiData(self.data_dir)
+        self.data = KaldiData(self.data_dir, index)
 
-        # make chunk indices: filepath, start_frame, end_frame
         for rec, path in self.data.wavs.items():
             data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
             data_len = int(data_len / self.subsampling)
@@ -66,62 +91,54 @@
 
 
 def convert(args):
-    f = open(out_wav_file, 'w')
     dataset = KaldiDiarizationDataset(
         data_dir=args.data_dir,
+        index=args.index,
         chunk_size=args.num_frames,
         context_size=args.context_size,
-        input_transform=args.input_transform,
+        input_transform="logmel23_mn",
         frame_size=args.frame_size,
         frame_shift=args.frame_shift,
         subsampling=args.subsampling,
         rate=8000,
         use_last_samples=True,
     )
-    length = len(dataset.chunk_indices)
-    for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
-        Y, T = feature.get_labeledSTFT(
-            dataset.data,
-            rec,
-            st,
-            ed,
-            dataset.frame_size,
-            dataset.frame_shift,
-            dataset.n_speakers)
-        Y = feature.transform(Y, dataset.input_transform)
-        Y_spliced = feature.splice(Y, dataset.context_size)
-        Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
-        st = '{:0>7d}'.format(st)
-        ed = '{:0>7d}'.format(ed)
-        suffix = '_' + st + '_' + ed
 
-        parts = os.readlink('/'.join(path.split('/')[:-1])).split('/')
-        # print('parts: ', parts)
-        parts = parts[:4] + ['numpy_data'] + parts[4:]
-        cur_path = '/'.join(parts)
-        # print('cur path: ', cur_path)
-        out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz')
-        # print(out_path)
-        # print(cur_path)
-        if not os.path.exists(cur_path):
-            os.makedirs(cur_path)
-        np.savez(out_path, Y=Y_ss, T=T_ss)
-        if idx == length - 1:
-            f.write(rec + suffix + ' ' + out_path)
-        else:
-            f.write(rec + suffix + ' ' + out_path + '\n')
+    feature_ark_file = os.path.join(args.output_dir, "feature.ark.{}".format(args.index))
+    feature_scp_file = os.path.join(args.output_dir, "feature.scp.{}".format(args.index))
+    label_ark_file = os.path.join(args.output_dir, "label.ark.{}".format(args.index))
+    label_scp_file = os.path.join(args.output_dir, "label.scp.{}".format(args.index))
+    with WriteHelper('ark,scp:{},{}'.format(feature_ark_file, feature_scp_file)) as feature_writer, \
+            WriteHelper('ark,scp:{},{}'.format(label_ark_file, label_scp_file)) as label_writer:
+        for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
+            Y, T = feature.get_labeledSTFT(
+                dataset.data,
+                rec,
+                st,
+                ed,
+                dataset.frame_size,
+                dataset.frame_shift,
+                dataset.n_speakers)
+            Y = feature.transform(Y, dataset.input_transform)
+            Y_spliced = feature.splice(Y, dataset.context_size)
+            Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
+            st = '{:0>7d}'.format(st)
+            ed = '{:0>7d}'.format(ed)
+            key = "{}_{}_{}".format(rec, st, ed)
+            feature_writer(key, Y_ss)
+            label_writer(key, T_ss.reshape(-1))
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("data_dir")
-    parser.add_argument("num_frames")
-    parser.add_argument("context_size")
-    parser.add_argument("frame_size")
-    parser.add_argument("frame_shift")
-    parser.add_argument("subsampling")
-
-
+    parser.add_argument("output_dir")
+    parser.add_argument("index")
+    parser.add_argument("num_frames", default=500)
+    parser.add_argument("context_size", default=7)
+    parser.add_argument("frame_size", default=200)
+    parser.add_argument("frame_shift", default=80)
+    parser.add_argument("subsampling", default=10)
 
     args = parser.parse_args()
     convert(args)
diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh
index 8ba8d57..c6a3a71 100644
--- a/egs/callhome/eend_ola/run_test.sh
+++ b/egs/callhome/eend_ola/run_test.sh
@@ -78,17 +78,26 @@
     for dset in swb_sre_tr swb_sre_cv; do
         if [ "$dset" == "swb_sre_tr" ]; then
             n_mixtures=${simu_opts_num_train}
+            dataset=train
         else
             n_mixtures=500
+            dataset=dev
         fi
         simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures}
-        mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work
-        split_scps=
-        for n in $(seq $nj); do
-            split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
-        done
-        utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
-        python local/split.py ${data_dir}/simu/data/${simu_data_dir}
+#        mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work
+#        split_scps=
+#        for n in $(seq $nj); do
+#            split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
+#        done
+#        utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
+#        python local/split.py ${data_dir}/simu/data/${simu_data_dir}
+        output_dir=${data_dir}/ark_data/dump/simu_data/$dataset
+        mkdir -p $output_dir/.logs
+        $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \
+        python local/dump_feature.py \
+              --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \
+              --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \
+              --index JOB
     done
 fi
 

--
Gitblit v1.9.1