From ab828bcf7badb228fdc59647f5c9c75e33acce9d Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期五, 17 二月 2023 00:04:14 +0800
Subject: [PATCH] add scripts for simu data

---
 egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py |    4 --
 egs/mars/sd/scripts/simu_chunk_with_labels.py         |   83 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
index 56ad787..cd1ec7b 100644
--- a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
+++ b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
@@ -1,13 +1,9 @@
 import numpy as np
 import os
-import sys
 import argparse
 from funasr.utils.job_runner import MultiProcessRunnerV3
 from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import librosa
 import soundfile as sf
-from copy import deepcopy
-import json
 from tqdm import tqdm
 
 
diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py
new file mode 100644
index 0000000..d1d9a2f
--- /dev/null
+++ b/egs/mars/sd/scripts/simu_chunk_with_labels.py
@@ -0,0 +1,83 @@
+import logging
+import numpy as np
+import soundfile
+import kaldiio
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import argparse
+from collections import OrderedDict
+import random
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser: argparse.ArgumentParser):
+        parser.add_argument("--label_scp", type=str, required=True)
+        parser.add_argument("--wav_scp", type=str, required=True)
+        parser.add_argument("--utt2spk", type=str, required=True)
+        parser.add_argument("--spk2meeting", type=str, required=True)
+        parser.add_argument("--utt2xvec", type=str, required=True)
+        parser.add_argument("--out_dir", type=str, required=True)
+        parser.add_argument("--chunk_size", type=int, default=16)
+        parser.add_argument("--chunk_shift", type=int, default=4)
+        parser.add_argument("--frame_shift", type=float, default=0.01)
+        args = parser.parse_args()
+
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        label_list = load_scp_as_list(args.label_scp)
+        wav_scp = load_scp_as_dict(args.wav_scp)
+        utt2spk = load_scp_as_dict(args.utt2spk)
+        utt2xvec = load_scp_as_dict(args.utt2xvec)
+        spk2meeting = load_scp_as_dict(args.spk2meeting)
+
+        meeting2spks = OrderedDict()
+        for spk, meeting in spk2meeting.items():
+            if meeting not in meeting2spks:
+                meeting2spks[meeting] = []
+            meeting2spks[meeting].append(spk)
+
+        spk2utts = OrderedDict()
+        for utt, spk in utt2spk.items():
+            if spk not in spk2utts:
+                spk2utts[spk] = []
+            spk2utts[spk].append(utt)
+
+        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
+
+    def post(self, results_list, args):
+        pass
+
+
+def process(task_args):
+    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
+    out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
+    wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+
+    out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
+    wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+
+    out_path = os.path.join(args.out_dir, "label.{}".format(task_idx + 1))
+    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+
+    idx = 0
+    for _, label_path in task_list:
+        rand_shift = random.randint(0, int(args.chunk_shift / args.frame_shift))
+        whole_label = kaldiio.load_mat(label_path)
+        whole_label = whole_label[rand_shift:]
+        num_chunk = (whole_label.shape[0] - args.chunk_size) // args.chunk_shift
+        for i in range(num_chunk):
+            utt_id = "part{}_utt{:10d}".format(task_idx + 1, idx + 1)
+
+
+    wav_mix_writer.close()
+    wav_sep_writer.close()
+    label_writer.close()
+    return None
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()

--
Gitblit v1.9.1