From 81123acf88ab0ef5eb6659049bb9fbb17dda5c49 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Fri, 12 May 2023 14:48:21 +0800
Subject: [PATCH] update repo
---
/dev/null | 261 ----------------------------------------------------
1 files changed, 0 insertions(+), 261 deletions(-)
diff --git a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
deleted file mode 100644
index 459a741..0000000
--- a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-model: sond
-model_conf:
- lsm_weight: 0.0
- length_normalized_loss: true
- max_spk_num: 16
-
-# speech encoder
-encoder: ecapa_tdnn
-encoder_conf:
- # pass by model, equal to feature dim
- # input_size: 80
- pool_size: 20
- stride: 1
-speaker_encoder: conv
-speaker_encoder_conf:
- input_units: 256
- num_layers: 3
- num_units: 256
- kernel_size: 1
- dropout_rate: 0.0
- position_encoder: null
- out_units: 256
- out_norm: false
- auxiliary_states: false
- tf2torch_tensor_name_prefix_torch: speaker_encoder
- tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
-ci_scorer: dot
-ci_scorer_conf: {}
-cd_scorer: san
-cd_scorer_conf:
- input_size: 512
- output_size: 512
- out_units: 1
- attention_heads: 4
- linear_units: 1024
- num_blocks: 4
- dropout_rate: 0.0
- positional_dropout_rate: 0.0
- attention_dropout_rate: 0.0
- # use string "null" to remove input layer
- input_layer: "null"
- pos_enc_class: null
- normalize_before: true
- tf2torch_tensor_name_prefix_torch: cd_scorer
- tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
-# post net
-decoder: fsmn
-decoder_conf:
- in_units: 32
- out_units: 2517
- filter_size: 31
- fsmn_num_layers: 6
- dnn_num_layers: 1
- num_memory_units: 512
- ffn_inner_dim: 512
- dropout_rate: 0.0
- tf2torch_tensor_name_prefix_torch: decoder
- tf2torch_tensor_name_prefix_tf: EAND/post_net
-frontend: wav_frontend
-frontend_conf:
- fs: 16000
- window: povey
- n_mels: 80
- frame_length: 25
- frame_shift: 10
- filter_length_min: -1
- filter_length_max: -1
- lfr_m: 1
- lfr_n: 1
- dither: 0.0
- snip_edges: false
-
-# minibatch related
-batch_type: length
-# 16s * 16k * 16 samples
-batch_bins: 4096000
-num_workers: 8
-
-# optimization related
-accum_grad: 1
-grad_clip: 5
-max_epoch: 50
-val_scheduler_criterion:
- - valid
- - acc
-best_model_criterion:
-- - valid
- - der
- - min
-- - valid
- - forward_steps
- - max
-keep_nbest_models: 10
-
-optim: adam
-optim_conf:
- lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
- warmup_steps: 10000
-
-# without spec aug
-specaug: null
-specaug_conf:
- apply_time_warp: true
- time_warp_window: 5
- time_warp_mode: bicubic
- apply_freq_mask: true
- freq_mask_width_range:
- - 0
- - 30
- num_freq_mask: 2
- apply_time_mask: true
- time_mask_width_range:
- - 0
- - 40
- num_time_mask: 2
-
-log_interval: 50
-# without normalize
-normalize: None
diff --git a/egs/mars/sd/local_run.sh b/egs/mars/sd/local_run.sh
deleted file mode 100755
index 4516e9f..0000000
--- a/egs/mars/sd/local_run.sh
+++ /dev/null
@@ -1,171 +0,0 @@
-#!/usr/bin/env bash
-
-. ./path.sh || exit 1;
-
-# machines configuration
-CUDA_VISIBLE_DEVICES="6,7"
-gpu_num=2
-count=1
-gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
-# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
-train_cmd=utils/run.pl
-infer_cmd=utils/run.pl
-
-# general configuration
-feats_dir="." #feature output dictionary
-exp_dir="."
-lang=zh
-dumpdir=dump/raw
-feats_type=raw
-token_type=char
-scp=wav.scp
-type=kaldi_ark
-stage=3
-stop_stage=4
-
-# feature configuration
-feats_dim=
-sample_frequency=16000
-nj=32
-speed_perturb=
-
-# exp tag
-tag="exp1"
-
-. utils/parse_options.sh || exit 1;
-
-# Set bash to 'debug' mode, it will exit on :
-# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
-set -e
-set -u
-set -o pipefail
-
-train_set=train
-valid_set=dev
-test_sets="dev test"
-
-asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
-
-inference_config=conf/decode_asr_transformer.yaml
-inference_asr_model=valid.acc.ave_10best.pb
-
-# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
-ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
-
-if ${gpu_inference}; then
- inference_nj=$[${ngpu}*${njob}]
- _ngpu=1
-else
- inference_nj=$njob
- _ngpu=0
-fi
-
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
-
-# Training Stage
-world_size=$gpu_num # run on one machine
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
- mkdir -p ${exp_dir}/exp/${model_dir}
- mkdir -p ${exp_dir}/exp/${model_dir}/log
- INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
- if [ -f $INIT_FILE ];then
- rm -f $INIT_FILE
- fi
- init_method=file://$(readlink -f $INIT_FILE)
- echo "$0: init method is $init_method"
- for ((i = 0; i < $gpu_num; ++i)); do
- {
- rank=$i
- local_rank=$i
- gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
- --gpu_id $gpu_id \
- --use_preprocessor true \
- --token_type char \
- --token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
- --resume true \
- --output_dir ${exp_dir}/exp/${model_dir} \
- --config $asr_config \
- --input_size $feats_dim \
- --ngpu $gpu_num \
- --num_worker_count $count \
- --multiprocessing_distributed true \
- --dist_init_method $init_method \
- --dist_world_size $world_size \
- --dist_rank $rank \
- --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
- } &
- done
- wait
-fi
-
-# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
- for dset in ${test_sets}; do
- asr_exp=${exp_dir}/exp/${model_dir}
- inference_tag="$(basename "${inference_config}" .yaml)"
- _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
- _logdir="${_dir}/logdir"
- if [ -d ${_dir} ]; then
- echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
- exit 0
- fi
- mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
- key_file=${_data}/${scp}
- num_scp_file="$(<${key_file} wc -l)"
- _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
- split_scps=
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/keys.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
- _opts=
- if [ -n "${inference_config}" ]; then
- _opts+="--config ${inference_config} "
- fi
- ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1: "${_nj}" "${_logdir}"/asr_inference.JOB.log \
- python -m funasr.bin.asr_inference_launch \
- --batch_size 1 \
- --ngpu "${_ngpu}" \
- --njob ${njob} \
- --gpuid_list ${gpuid_list} \
- --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
- --key_file "${_logdir}"/keys.JOB.scp \
- --asr_train_config "${asr_exp}"/config.yaml \
- --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
- --output_dir "${_logdir}"/output.JOB \
- --mode asr \
- ${_opts}
-
- for f in token token_int score text; do
- if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
- for i in $(seq "${_nj}"); do
- cat "${_logdir}/output.${i}/1best_recog/${f}"
- done | sort -k1 >"${_dir}/${f}"
- fi
- done
- python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
- python utils/proce_text.py ${_data}/text ${_data}/text.proc
- python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
- tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
- cat ${_dir}/text.cer.txt
- done
-fi
-
diff --git a/egs/mars/sd/path.sh b/egs/mars/sd/path.sh
deleted file mode 100755
index 7972642..0000000
--- a/egs/mars/sd/path.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-export FUNASR_DIR=$PWD/../../..
-
-# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
-export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/mars/sd/scripts/calculate_shapes.py b/egs/mars/sd/scripts/calculate_shapes.py
deleted file mode 100644
index b207f2d..0000000
--- a/egs/mars/sd/scripts/calculate_shapes.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--input_scp", type=str, required=True)
- parser.add_argument("--out_path")
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.out_path)):
- os.makedirs(os.path.dirname(args.out_path))
-
- task_list = load_scp_as_list(args.input_scp)
- return task_list, None, args
-
- def post(self, result_list, args):
- fd = open(args.out_path, "wt", encoding="utf-8")
- for results in result_list:
- for uttid, shape in results:
- fd.write("{} {}\n".format(uttid, ",".join(shape)))
- fd.close()
-
-
-def process(task_args):
- task_idx, task_list, _, args = task_args
- rst = []
- for uttid, file_path in task_list:
- data = kaldiio.load_mat(file_path)
- shape = [str(x) for x in data.shape]
- rst.append((uttid, shape))
- return rst
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py
deleted file mode 100644
index ec1c765..0000000
--- a/egs/mars/sd/scripts/dump_rttm_to_labels.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--rttm_list", type=str, required=True)
- parser.add_argument("--wav_scp_list", type=str, required=True)
- parser.add_argument("--out_dir", type=str, required=True)
- parser.add_argument("--n_spk", type=int, default=8)
- parser.add_argument("--remove_sil", default=False, action="store_true")
- parser.add_argument("--max_overlap", default=0, type=int)
- parser.add_argument("--frame_shift", type=float, default=0.01)
- args = parser.parse_args()
-
- rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
- meeting2rttm = OrderedDict()
- for rttm_path in rttm_list:
- meeting2rttm.update(self.load_rttm(rttm_path))
-
- wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
- meeting_scp = OrderedDict()
- for scp_path in wav_scp_list:
- meeting_scp.update(load_scp_as_dict(scp_path))
-
- if len(meeting_scp) != len(meeting2rttm):
- logging.warning("Number of wav and rttm mismatch {} != {}".format(
- len(meeting_scp), len(meeting2rttm)))
- common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
- logging.warning("Keep {} records.".format(len(common_keys)))
- new_meeting_scp = OrderedDict()
- rm_keys = []
- for key in meeting_scp:
- if key not in common_keys:
- rm_keys.append(key)
- else:
- new_meeting_scp[key] = meeting_scp[key]
- logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
-
- new_meeting2rttm = OrderedDict()
- rm_keys = []
- for key in meeting2rttm:
- if key not in common_keys:
- rm_keys.append(key)
- else:
- new_meeting2rttm[key] = meeting2rttm[key]
- logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
- meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
- return task_list, None, args
-
- @staticmethod
- def load_rttm(rttm_path):
- meeting2rttm = OrderedDict()
- for one_line in open(rttm_path, "rt", encoding="utf-8"):
- mid = one_line.strip().split(" ")[1]
- if mid not in meeting2rttm:
- meeting2rttm[mid] = []
- meeting2rttm[mid].append(one_line.strip())
-
- return meeting2rttm
-
- def post(self, results_list, args):
- pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
- sr=None, frame_shift=0.01):
- frame_shift = int(frame_shift * sr)
- num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
- multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
- for _, st, dur, spk in spk_turns:
- idx = spk_list.index(spk)
-
- st, dur = int(st * sr), int(dur * sr)
- frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift)
- frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift)
- multi_label[idx, frame_st:frame_ed] = 1
-
- if remove_sil:
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count)[0]
- multi_label = multi_label[:, idx]
-
- if max_overlap > 0:
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count <= max_overlap)[0]
- multi_label = multi_label[:, idx]
-
- label = multi_label.T
- return label # (T, N)
-
-
-def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
- sr=16000, frame_shift=0.01):
- wav, sr = soundfile.read(wav_path)
- wav_len = len(wav)
- spk_turns = []
- spk_list = []
- for one_line in rttms:
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if spk not in spk_list:
- spk_list.append(spk)
- spk_turns.append((mid, st, dur, spk))
- labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
- return labels, spk_list
-
-
-def process(task_args):
- task_idx, task_list, _, args = task_args
- spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
- "wt", encoding="utf-8")
- out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
- label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
- for mid, wav_path, rttms in task_list:
- meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
- args.sr, args.frame_shift)
- label_writer(mid, meeting_labels)
- spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
-
- spk_list_writer.close()
- label_writer.close()
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
deleted file mode 100644
index cd1ec7b..0000000
--- a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import numpy as np
-import os
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import soundfile as sf
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
- def prepare(self, parser):
- assert isinstance(parser, argparse.ArgumentParser)
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("rttm", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--min_dur", type=float, default=2.0)
- parser.add_argument("--max_spk_num", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- meeting2rttms = {}
- for one_line in open(args.rttm, "rt"):
- parts = [x for x in one_line.strip().split(" ") if x != ""]
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if mid not in meeting2rttms:
- meeting2rttms[mid] = []
- meeting2rttms[mid].append(one_line)
-
- task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
- return task_list, None, args
-
- def post(self, result_list, args):
- count = [0, 0]
- for result in result_list:
- count[0] += result[0]
- count[1] += result[1]
- print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
- labels = np.zeros([max_spk_num, length], int)
- spk_list = []
- for one_line in rttms:
- parts = [x for x in one_line.strip().split(" ") if x != ""]
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
- if spk_name.isdigit():
- spk_name = "{}_S{:03d}".format(mid, int(spk_name))
- else:
- spk_name = "{}_{}".format(mid, spk_name)
- if spk_name not in spk_list:
- spk_list.append(spk_name)
- st, dur = int(st*sr), int(dur*sr)
- idx = spk_list.index(spk_name)
- labels[idx, st:st+dur] = 1
- return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
- turns = []
- label = np.sum(multi_label, axis=0) == 1
- spk, in_turn, st = None, False, 0
- for i in range(len(label)):
- if not in_turn and label[i]:
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- if not label[i]:
- in_turn = False
- turns.append([st, i, spk])
- elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
- turns.append([st, i, spk])
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- turns.append([st, len(label), spk])
- return turns
-
-
-def process(task_args):
- task_id, task_list, _, args = task_args
- spk_count = [0, 0]
- for mid, wav_path, rttms in task_list:
- wav, sr = sf.read(wav_path, dtype="int16")
- assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
- multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
- turns = get_nonoverlap_turns(multi_label, spk_list)
- extracted_spk = []
- count = 1
- for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
- if (ed - st) >= args.min_dur * args.sr:
- seg = wav[st: ed]
- save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
- if not os.path.exists(os.path.dirname(save_path)):
- os.makedirs(os.path.dirname(save_path))
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- count += 1
- if spk not in extracted_spk:
- extracted_spk.append(spk)
- if len(extracted_spk) != len(spk_list):
- print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
- mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
- ))
- spk_count[0] += len(extracted_spk)
- spk_count[1] += len(spk_list)
- return spk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py b/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
deleted file mode 100644
index e579f51..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("dir", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--n_spk", type=int, default=4)
- parser.add_argument("--remove_sil", default=False, action="store_true")
- args = parser.parse_args()
-
- meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
- rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp]
- return task_list, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
- multi_label = np.zeros([n_spk, length], dtype=int)
- for _, st, dur, spk in spk_turns:
- st, dur = int(st * sr), int(dur * sr)
- idx = spk_list.index(spk)
- multi_label[idx, st:st+dur] = 1
- if not remove_sil:
- return multi_label.T
-
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count)[0]
- label = multi_label[:, idx].T
- return label # (T, N)
-
-
-def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
- wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
- spk_turns = []
- spk_list = []
- for one_line in open(rttm_path, "rt"):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
- spk = "{}_S{:03d}".format(mid, spk)
- if spk not in spk_list:
- spk_list.append(spk)
- spk_turns.append((mid, st, dur, spk))
- labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil)
- return labels
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- for mid, wav_path, rttm_path in task_list:
- meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
- save_path = os.path.join(args.out_dir, "{}.lbl".format(mid))
- np.save(save_path, meeting_labels.astype(bool))
- print(mid)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py b/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
deleted file mode 100644
index 11bc395..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-from tqdm import tqdm
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--chunk_dur", type=float, default=16)
- parser.add_argument("--shift_dur", type=float, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- return wav_scp, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
- for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
- if not os.path.exists(os.path.join(args.out_dir, mid)):
- os.makedirs(os.path.join(args.out_dir, mid))
-
- wav = librosa.load(wav_path, args.sr, True)[0] * 32767
- n_chunk = (len(wav) - chunk_len) // shift_len + 1
- if (len(wav) - chunk_len) % shift_len > 0:
- n_chunk += 1
- for i in range(n_chunk):
- seg = wav[i*shift_len: i*shift_len + chunk_len]
- st = int(float(i*shift_len)/args.sr * 100)
- dur = int(float(len(seg))/args.sr * 100)
- file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
- save_path = os.path.join(args.out_dir, mid, file_name)
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py b/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
deleted file mode 100644
index 011bd7c..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("--rttm_scp", type=str)
- parser.add_argument("--seg_file", type=str)
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.seg_file)):
- os.makedirs(os.path.dirname(args.seg_file))
-
- task_list = load_scp_as_list(args.rttm_scp)
- return task_list, None, args
-
- def post(self, results_list, args):
- with open(args.seg_file, "wt", encoding="utf-8") as fd:
- for results in results_list:
- fd.writelines(results)
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- outputs = []
- for mid, rttm_path in task_list:
- spk_turns = []
- length = 0
- for one_line in open(rttm_path, 'rt', encoding="utf-8"):
- parts = one_line.strip().split(" ")
- _, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- st, ed = int(st*100), int((st + dur)*100)
- length = ed if ed > length else length
- spk_turns.append([mid, st, ed, spk_name])
- is_sph = np.zeros((length+1, ), dtype=bool)
- for _, st, ed, _ in spk_turns:
- is_sph[st:ed] = True
-
- st, in_speech = 0, False
- for i in range(length+1):
- if not in_speech and is_sph[i]:
- st, in_speech = i, True
- if in_speech and not is_sph[i]:
- in_speech = False
- outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
- mid, st, i, mid, float(st)/100, float(i)/100
- ))
- return outputs
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py b/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
deleted file mode 100644
index a2bcd39..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import soundfile
-import kaldiio
-from tqdm import tqdm
-import json
-import os
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import numpy as np
-import argparse
-import random
-
-short_spk_list = []
-def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
- all_utts = spk2utt[spk]
- idx_list = list(range(len(all_utts)))
- random.shuffle(idx_list)
- count = 0
- utt_list = []
- for i in idx_list:
- utt_id = all_utts[i]
- utt_list.append(utt_id)
- count += int(utt2frames[utt_id])
- if count >= total_len:
- break
- if count < 300 and spk not in short_spk_list:
- print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
- short_spk_list.append(spk)
-
- ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
- ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list]
- ivc = np.concatenate(ivc_list, axis=0)
- ivc = np.mean(ivc, axis=0, keepdims=False)
- return ivc
-
-
-def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
- out_prefix = args.out
-
- ivc_dim = 192
- win_len, win_shift = 400, 160
- label_weights = 2 ** np.array(list(range(args.n_spk)))
- wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
- ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
- label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))
-
-
- frames_list = []
- chunk_size = int(args.chunk_size * args.sr)
- chunk_shift = int(args.chunk_shift * args.sr)
- for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
- meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
- num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
- meeting_labels = np.load(labels_scp[mid])
- for i in range(num_chunk):
- st, ed = i*chunk_shift, i*chunk_shift+chunk_size
- seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
- wav_writer(seg_id, meeting_wav[st: ed])
-
- xvec_list = []
- for spk in meeting2spk_list[mid]:
- spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
- xvec_list.append(spk_xvec)
- for _ in range(args.n_spk - len(xvec_list)):
- xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
- xvec = np.row_stack(xvec_list)
- ivc_writer(seg_id, xvec)
-
- wav_label = meeting_labels[st:ed, :]
- frame_num = (ed-st) // win_shift
- # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
- feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
- for i in range(frame_num):
- frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
- feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
- label_writer(seg_id, feat_label)
-
- frames_list.append((mid, feat_label.shape[0]))
- return frames_list
-
-
-def calc_spk_list(rttm_path):
- spk_list = []
- for one_line in open(rttm_path, "rt"):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
- spk = "{}_S{:03d}".format(mid, spk)
- if spk not in spk_list:
- spk_list.append(spk)
-
- return spk_list
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--dir", required=True, type=str, default=None,
- help="feats.scp")
- parser.add_argument("--out", required=True, type=str, default=None,
- help="The prefix of dumpped files.")
- parser.add_argument("--n_spk", type=int, default=4)
- parser.add_argument("--use_lfr", default=False, action="store_true")
- parser.add_argument("--no_pbar", default=False, action="store_true")
- parser.add_argument("--sr", type=int, default=16000)
- parser.add_argument("--chunk_size", type=int, default=16)
- parser.add_argument("--chunk_shift", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.out)):
- os.makedirs(os.path.dirname(args.out))
-
- meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
- labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
- rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
- utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
- utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
- utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
- utt2frames = {}
- for uttid, wav_path in utt2wav.items():
- wav, sr = soundfile.read(wav_path, dtype="int16")
- utt2frames[uttid] = int(len(wav) / sr * 100)
-
- meeting2spk_list = {}
- for mid, rttm_path in rttm_scp:
- meeting2spk_list[mid] = calc_spk_list(rttm_path)
-
- spk2utt = {}
- for utt, spk in utt2spk.items():
- if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
- if spk not in spk2utt:
- spk2utt[spk] = []
- spk2utt[spk].append(utt)
-
- # random.shuffle(feat_scp)
- meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
- total_frames = sum([x[1] for x in meeting_lens])
- print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))
-
-
-if __name__ == '__main__':
- main()
diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
deleted file mode 100644
index 1d6f53e..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from __future__ import print_function
-import numpy as np
-import os
-import sys
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import librosa
-import soundfile as sf
-from copy import deepcopy
-import json
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
- def prepare(self, parser):
- assert isinstance(parser, argparse.ArgumentParser)
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("rttm_scp", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--min_dur", type=float, default=2.0)
- parser.add_argument("--max_spk_num", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- rttm_scp = load_scp_as_dict(args.rttm_scp)
- task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp]
- return task_list, None, args
-
- def post(self, result_list, args):
- count = [0, 0]
- for result in result_list:
- count[0] += result[0]
- count[1] += result[1]
- print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
- labels = np.zeros([max_spk_num, length], int)
- spk_list = []
- for one_line in open(rttm_path, 'rt'):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if spk_name.isdigit():
- spk_name = "{}_S{:03d}".format(mid, int(spk_name))
- if spk_name not in spk_list:
- spk_list.append(spk_name)
- st, dur = int(st*sr), int(dur*sr)
- idx = spk_list.index(spk_name)
- labels[idx, st:st+dur] = 1
- return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
- turns = []
- label = np.sum(multi_label, axis=0) == 1
- spk, in_turn, st = None, False, 0
- for i in range(len(label)):
- if not in_turn and label[i]:
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- if not label[i]:
- in_turn = False
- turns.append([st, i, spk])
- elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
- turns.append([st, i, spk])
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- turns.append([st, len(label), spk])
- return turns
-
-
-def process(task_args):
- task_id, task_list, _, args = task_args
- spk_count = [0, 0]
- for mid, wav_path, rttm_path in task_list:
- wav, sr = sf.read(wav_path, dtype="int16")
- assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
- multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
- turns = get_nonoverlap_turns(multi_label, spk_list)
- extracted_spk = []
- count = 1
- for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
- if (ed - st) >= args.min_dur * args.sr:
- seg = wav[st: ed]
- save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
- if not os.path.exists(os.path.dirname(save_path)):
- os.makedirs(os.path.dirname(save_path))
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- count += 1
- if spk not in extracted_spk:
- extracted_spk.append(spk)
- if len(extracted_spk) != len(spk_list):
- print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
- mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
- ))
- spk_count[0] += len(extracted_spk)
- spk_count[1] += len(spk_list)
- return spk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py b/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
deleted file mode 100644
index 8b3195f..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("dir", type=str)
- parser.add_argument("out_dir", type=str)
- args = parser.parse_args()
-
- meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
- vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8")
- meeting2vad = {}
- for one_line in vad_file:
- uid, mid, st, ed = one_line.strip().split(" ")
- st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
- if mid not in meeting2vad:
- meeting2vad[mid] = []
- meeting2vad[mid].append((uid, st, ed))
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
- return task_list, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- for mid, wav_path, vad_list in task_list:
- wav = librosa.load(wav_path, args.sr, True)[0] * 32767
- seg_list = []
- pos_map = []
- offset = 0
- for uid, st, ed in vad_list:
- seg_list.append(wav[st: ed])
- pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st))
- offset = offset + ed - st
- out = np.concatenate(seg_list, axis=0)
- save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
- sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
- with open(map_path, "wt", encoding="utf-8") as fd:
- fd.writelines(pos_map)
- print(mid)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py
deleted file mode 100644
index f61b808..0000000
--- a/egs/mars/sd/scripts/simu_chunk_with_labels.py
+++ /dev/null
@@ -1,261 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-import random
-from typing import List, Dict
-from copy import deepcopy
-import json
-logging.basicConfig(
- level="INFO",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--label_scp", type=str, required=True)
- parser.add_argument("--wav_scp", type=str, required=True)
- parser.add_argument("--utt2spk", type=str, required=True)
- parser.add_argument("--spk2meeting", type=str, required=True)
- parser.add_argument("--utt2xvec", type=str, required=True)
- parser.add_argument("--out_dir", type=str, required=True)
- parser.add_argument("--chunk_size", type=float, default=16)
- parser.add_argument("--chunk_shift", type=float, default=4)
- parser.add_argument("--frame_shift", type=float, default=0.01)
- parser.add_argument("--embedding_dim", type=int, default=None)
- parser.add_argument("--average_emb_num", type=int, default=0)
- parser.add_argument("--subset", type=int, default=0)
- parser.add_argument("--data_json", type=str, default=None)
- parser.add_argument("--seed", type=int, default=1234)
- parser.add_argument("--log_interval", type=int, default=100)
- args = parser.parse_args()
- random.seed(args.seed)
- np.random.seed(args.seed)
-
- logging.info("Loading data...")
- if not os.path.exists(args.data_json):
- label_list = load_scp_as_list(args.label_scp)
- wav_scp = load_scp_as_dict(args.wav_scp)
- utt2spk = load_scp_as_dict(args.utt2spk)
- utt2xvec = load_scp_as_dict(args.utt2xvec)
- spk2meeting = load_scp_as_dict(args.spk2meeting)
-
- meeting2spks = OrderedDict()
- for spk, meeting in spk2meeting.items():
- if meeting not in meeting2spks:
- meeting2spks[meeting] = []
- meeting2spks[meeting].append(spk)
-
- spk2utts = OrderedDict()
- for utt, spk in utt2spk.items():
- if spk not in spk2utts:
- spk2utts[spk] = []
- spk2utts[spk].append(utt)
-
- os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
- logging.info("Dump data...")
- json.dump({
- "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
- "spk2utts": spk2utts, "meeting2spks": meeting2spks
- }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
- else:
- data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
- label_list = data_dict["label_list"]
- wav_scp = data_dict["wav_scp"]
- utt2xvec = data_dict["utt2xvec"]
- spk2utts = data_dict["spk2utts"]
- meeting2spks = data_dict["meeting2spks"]
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- args.chunk_size = int(args.chunk_size / args.frame_shift)
- args.chunk_shift = int(args.chunk_shift / args.frame_shift)
-
- if args.embedding_dim is None:
- args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
- logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
-
- logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
- len(wav_scp), len(spk2utts), len(meeting2spks)
- ))
- return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
-
- def post(self, results_list, args):
- logging.info("[main]: Got {} chunks.".format(sum(results_list)))
-
-
-def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
- utt_list = spk2utts[spk]
- wav_list = []
- cur_length = 0
- while cur_length < sample_length:
- uttid = random.choice(utt_list)
- wav, fs = soundfile.read(wav_scp[uttid], dtype='float32')
- wav_list.append(wav)
- cur_length += len(wav)
- concat_wav = np.concatenate(wav_list, axis=0)
- start = random.randint(0, len(concat_wav) - sample_length)
- return concat_wav[start: start+sample_length]
-
-
-def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
- # process for dummy speaker
- if spk == "None":
- return np.zeros((1, embedding_dim), dtype=np.float32)
-
- # calculate averaged speaker embeddings
- utt_list = spk2utts[spk]
- if average_emb_num == 0 or average_emb_num > len(utt_list):
- xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
- else:
- xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
- xvec = np.concatenate(xvec_list, axis=0)
- xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
- xvec = np.mean(xvec, axis=0)
-
- return xvec
-
-
-def simu_chunk(
- frame_label: np.ndarray,
- sample_label: np.ndarray,
- wav_scp: Dict[str, str],
- utt2xvec: Dict[str, str],
- spk2utts: Dict[str, List[str]],
- meeting2spks: Dict[str, List[str]],
- all_speaker_list: List[str],
- meeting_list: List[str],
- embedding_dim: int,
- average_emb_num: int,
-):
- frame_length, max_spk_num = frame_label.shape
- sample_length = sample_label.shape[0]
- positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
- pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])
-
- # get positive speakers
- if len(pos_speaker_list) >= positive_speaker_num:
- pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
- else:
- while len(pos_speaker_list) < positive_speaker_num:
- _spk = random.choice(all_speaker_list)
- if _spk not in pos_speaker_list:
- pos_speaker_list.append(_spk)
-
- # get negative speakers
- negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
- neg_speaker_list = []
- while len(neg_speaker_list) < negative_speaker_num:
- _spk = random.choice(all_speaker_list)
- if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
- neg_speaker_list.append(_spk)
- neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))
-
- random.shuffle(pos_speaker_list)
- random.shuffle(neg_speaker_list)
- seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
- this_spk_list = []
- for idx, frame_num in enumerate(frame_label.sum(axis=0)):
- if frame_num > 0:
- spk = pos_speaker_list.pop(0)
- this_spk_list.append(spk)
- simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
- seperated_wav[:, idx] = simu_spk_wav
- else:
- spk = neg_speaker_list.pop(0)
- this_spk_list.append(spk)
-
- # calculate mixed wav
- mixed_wav = np.sum(seperated_wav * sample_label, axis=1)
-
- # shuffle the order of speakers
- shuffle_idx = list(range(max_spk_num))
- random.shuffle(shuffle_idx)
- this_spk_list = [this_spk_list[x] for x in shuffle_idx]
- seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
- frame_label = frame_label.transpose()[shuffle_idx].transpose()
-
- # calculate profile
- profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
- for spk in this_spk_list]
- profile = np.vstack(profile)
- # pse_weights = 2 ** np.arange(max_spk_num)
- # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
- # pse_label = pse_label.astype(str).tolist()
-
- return mixed_wav, seperated_wav, profile, frame_label
-
-
-def process(task_args):
- task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
- logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
-
- out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
- wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
- # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
- profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
- label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
-
- labels_list = []
- total_chunks = 0
- for org_mid, label_path in task_list:
- whole_label = kaldiio.load_mat(label_path)
- # random offset to keep diversity
- rand_shift = random.randint(0, args.chunk_shift)
- num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
- labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
- total_chunks += num_chunk
-
- idx = 0
- simu_chunk_count = 0
- for org_mid, whole_label, rand_shift, num_chunk in labels_list:
- for i in range(num_chunk):
- idx = idx + 1
- st = i * args.chunk_shift + rand_shift
- ed = i * args.chunk_shift + args.chunk_size + rand_shift
- utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
- args.subset + 1, task_idx + 1, org_mid, st, ed
- )
- frame_label = whole_label[st: ed, :]
- sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
- mix_wav, seg_wav, profile, frame_label = simu_chunk(
- frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
- speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
- )
- wav_mix_writer(utt_id, mix_wav)
- # wav_sep_writer(utt_id, seg_wav)
- profile_writer(utt_id, profile)
- label_writer(utt_id, frame_label)
-
- simu_chunk_count += 1
- if simu_chunk_count % args.log_interval == 0:
- logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
- task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
- wav_mix_writer.close()
- # wav_sep_writer.close()
- profile_writer.close()
- label_writer.close()
- logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
- return simu_chunk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
--
Gitblit v1.9.1