From 842df33fa23331e819965324df5d9e790eccbf9f Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期四, 16 二月 2023 15:58:05 +0800
Subject: [PATCH] add scripts for simu data
---
egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py | 119 +++++++++++++++++++++++++++++++++++++++
egs/mars/sd/scripts/dump_rttm_to_labels.py | 10 ++-
2 files changed, 126 insertions(+), 3 deletions(-)
diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py
index d3a58a2..09ea723 100644
--- a/egs/mars/sd/scripts/dump_rttm_to_labels.py
+++ b/egs/mars/sd/scripts/dump_rttm_to_labels.py
@@ -37,18 +37,22 @@
common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
logging.warning("Keep {} records.".format(len(common_keys)))
new_meeting_scp = OrderedDict()
+ rm_keys = []
for key in meeting_scp:
if key not in common_keys:
- logging.warning("Pop {} from wav scp".format(key))
+ rm_keys.append(key)
else:
new_meeting_scp[key] = meeting_scp[key]
+ logging.warning("Keys are removed from wav scp:", " ".format(rm_keys))
+
new_meeting2rttm = OrderedDict()
+ rm_keys = []
for key in meeting2rttm:
if key not in common_keys:
- logging.warning("Pop {} from rttm scp".format(key))
+ rm_keys.append(key)
else:
new_meeting2rttm[key] = meeting2rttm[key]
-
+ logging.warning("Keys are removed from rttm scp:", " ".format(rm_keys))
meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)
diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
new file mode 100644
index 0000000..56ad787
--- /dev/null
+++ b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
@@ -0,0 +1,119 @@
+import numpy as np
+import os
+import sys
+import argparse
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import librosa
+import soundfile as sf
+from copy import deepcopy
+import json
+from tqdm import tqdm
+
+
+class MyRunner(MultiProcessRunnerV3):
+ def prepare(self, parser):
+ assert isinstance(parser, argparse.ArgumentParser)
+ parser.add_argument("wav_scp", type=str)
+ parser.add_argument("rttm", type=str)
+ parser.add_argument("out_dir", type=str)
+ parser.add_argument("--min_dur", type=float, default=2.0)
+ parser.add_argument("--max_spk_num", type=int, default=4)
+ args = parser.parse_args()
+
+ if not os.path.exists(args.out_dir):
+ os.makedirs(args.out_dir)
+
+ wav_scp = load_scp_as_list(args.wav_scp)
+ meeting2rttms = {}
+ for one_line in open(args.rttm, "rt"):
+ parts = [x for x in one_line.strip().split(" ") if x != ""]
+ mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
+ if mid not in meeting2rttms:
+ meeting2rttms[mid] = []
+ meeting2rttms[mid].append(one_line)
+
+ task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
+ return task_list, None, args
+
+ def post(self, result_list, args):
+ count = [0, 0]
+ for result in result_list:
+ count[0] += result[0]
+ count[1] += result[1]
+ print("Found {} speakers, extracted {}.".format(count[1], count[0]))
+
+
+# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
+def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
+ labels = np.zeros([max_spk_num, length], int)
+ spk_list = []
+ for one_line in rttms:
+ parts = [x for x in one_line.strip().split(" ") if x != ""]
+ mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
+ spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
+ if spk_name.isdigit():
+ spk_name = "{}_S{:03d}".format(mid, int(spk_name))
+ else:
+ spk_name = "{}_{}".format(mid, spk_name)
+ if spk_name not in spk_list:
+ spk_list.append(spk_name)
+ st, dur = int(st*sr), int(dur*sr)
+ idx = spk_list.index(spk_name)
+ labels[idx, st:st+dur] = 1
+ return labels, spk_list
+
+
+def get_nonoverlap_turns(multi_label, spk_list):
+ turns = []
+ label = np.sum(multi_label, axis=0) == 1
+ spk, in_turn, st = None, False, 0
+ for i in range(len(label)):
+ if not in_turn and label[i]:
+ st, in_turn = i, True
+ spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
+ if in_turn:
+ if not label[i]:
+ in_turn = False
+ turns.append([st, i, spk])
+ elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
+ turns.append([st, i, spk])
+ st, in_turn = i, True
+ spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
+ if in_turn:
+ turns.append([st, len(label), spk])
+ return turns
+
+
+def process(task_args):
+ task_id, task_list, _, args = task_args
+ spk_count = [0, 0]
+ for mid, wav_path, rttms in task_list:
+ wav, sr = sf.read(wav_path, dtype="int16")
+ assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
+ multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
+ turns = get_nonoverlap_turns(multi_label, spk_list)
+ extracted_spk = []
+ count = 1
+ for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
+ if (ed - st) >= args.min_dur * args.sr:
+ seg = wav[st: ed]
+ save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
+ if not os.path.exists(os.path.dirname(save_path)):
+ os.makedirs(os.path.dirname(save_path))
+ sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
+ count += 1
+ if spk not in extracted_spk:
+ extracted_spk.append(spk)
+ if len(extracted_spk) != len(spk_list):
+ print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
+ mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
+ ))
+ spk_count[0] += len(extracted_spk)
+ spk_count[1] += len(spk_list)
+ return spk_count
+
+
+if __name__ == '__main__':
+ my_runner = MyRunner(process)
+ my_runner.run()
--
Gitblit v1.9.1