Merge pull request #159 from alibaba-damo-academy/dev_dzh
Dev dzh
| New file |
| | |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | import numpy as np |
| | | import os |
| | | |
| | | |
def test_wav_cpu_infer():
    """Run the SOND diarization pipeline on the unit-test scp inputs and print the result.

    The input is a list of ``path,name,type`` strings: the speech scp and the
    speaker-profile scp.
    """
    scp_inputs = [
        "data/unit_test/test_wav.scp,speech,sound",
        "data/unit_test/test_profile.scp,profile,kaldi_ark",
    ]
    pipeline_kwargs = dict(
        task=Tasks.speaker_diarization,
        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
        mode="sond",
        output_dir="./outputs",
        num_workers=0,
        log_level="WARNING",
    )
    # NOTE(review): nothing in this function selects the CPU explicitly — confirm
    # how the device is expected to be chosen for the "cpu" variant.
    run_diarization = pipeline(**pipeline_kwargs)
    print(run_diarization(scp_inputs))
| | | |
| | | |
def test_wav_gpu_infer():
    """Run the SOND diarization pipeline on the unit-test scp inputs and print the result.

    The input is a list of ``path,name,type`` strings: the speech scp and the
    speaker-profile scp.
    """
    scp_inputs = [
        "data/unit_test/test_wav.scp,speech,sound",
        "data/unit_test/test_profile.scp,profile,kaldi_ark",
    ]
    pipeline_kwargs = dict(
        task=Tasks.speaker_diarization,
        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
        mode="sond",
        output_dir="./outputs",
        num_workers=0,
        log_level="WARNING",
    )
    run_diarization = pipeline(**pipeline_kwargs)
    print(run_diarization(scp_inputs))
| | | |
| | | |
def test_without_profile_gpu_infer():
    """Run the ``sond_demo`` pipeline on local wav files and print the result.

    No pre-computed profile scp is passed; the pipeline is given a
    speaker-verification model (``sv_model``) alongside the raw wav paths.
    """
    # NOTE(review): presumably the first entry is the recording to diarize and the
    # rest are per-speaker enrollment audio — confirm against the pipeline docs.
    wav_paths = [
        "data/unit_test/raw_inputs/record.wav",
        "data/unit_test/raw_inputs/spk1.wav",
        "data/unit_test/raw_inputs/spk2.wav",
        "data/unit_test/raw_inputs/spk3.wav",
        "data/unit_test/raw_inputs/spk4.wav",
    ]
    pipeline_kwargs = dict(
        task=Tasks.speaker_diarization,
        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
        mode="sond_demo",
        sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
        sv_model_revision="master",
        num_workers=0,
        log_level="WARNING",
        param_dict={},
    )
    run_diarization = pipeline(**pipeline_kwargs)
    print(run_diarization(wav_paths))
| | | |
| | | |
def test_url_without_profile_gpu_infer():
    """Run the ``sond_demo`` pipeline on remote wav URLs and print the result.

    Same configuration as the local-file variant, but every input is an
    HTTPS URL instead of a path on disk.
    """
    wav_urls = [
        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav",
        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav",
        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk2.wav",
        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk3.wav",
        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk4.wav",
    ]
    pipeline_kwargs = dict(
        task=Tasks.speaker_diarization,
        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
        mode="sond_demo",
        sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
        sv_model_revision="master",
        num_workers=0,
        log_level="WARNING",
        param_dict={},
    )
    run_diarization = pipeline(**pipeline_kwargs)
    print(run_diarization(wav_urls))
| | | |
| | | |
if __name__ == '__main__':
    # Restrict CUDA to device 0 before any pipeline is constructed.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    # Run every demo in the original order.
    for demo in (
        test_wav_cpu_infer,
        test_wav_gpu_infer,
        test_without_profile_gpu_infer,
        test_url_without_profile_gpu_infer,
    ):
        demo()
| New file |
| | |
| | | model: sond |
| | | model_conf: |
| | | lsm_weight: 0.0 |
| | | length_normalized_loss: true |
| | | max_spk_num: 16 |
| | | |
| | | # speech encoder |
| | | encoder: ecapa_tdnn |
| | | encoder_conf: |
| | | # input_size is supplied by the model and equals the feature dimension |
| | | # input_size: 80 |
| | | pool_size: 20 |
| | | stride: 1 |
| | | speaker_encoder: conv |
| | | speaker_encoder_conf: |
| | | input_units: 256 |
| | | num_layers: 3 |
| | | num_units: 256 |
| | | kernel_size: 1 |
| | | dropout_rate: 0.0 |
| | | position_encoder: null |
| | | out_units: 256 |
| | | out_norm: false |
| | | auxiliary_states: false |
| | | tf2torch_tensor_name_prefix_torch: speaker_encoder |
| | | tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder |
| | | ci_scorer: dot |
| | | ci_scorer_conf: {} |
| | | cd_scorer: san |
| | | cd_scorer_conf: |
| | | input_size: 512 |
| | | output_size: 512 |
| | | out_units: 1 |
| | | attention_heads: 4 |
| | | linear_units: 1024 |
| | | num_blocks: 4 |
| | | dropout_rate: 0.0 |
| | | positional_dropout_rate: 0.0 |
| | | attention_dropout_rate: 0.0 |
| | | # use string "null" to remove input layer |
| | | input_layer: "null" |
| | | pos_enc_class: null |
| | | normalize_before: true |
| | | tf2torch_tensor_name_prefix_torch: cd_scorer |
| | | tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer |
| | | # post net |
| | | decoder: fsmn |
| | | decoder_conf: |
| | | in_units: 32 |
| | | out_units: 2517 |
| | | filter_size: 31 |
| | | fsmn_num_layers: 6 |
| | | dnn_num_layers: 1 |
| | | num_memory_units: 512 |
| | | ffn_inner_dim: 512 |
| | | dropout_rate: 0.0 |
| | | tf2torch_tensor_name_prefix_torch: decoder |
| | | tf2torch_tensor_name_prefix_tf: EAND/post_net |
| | | frontend: wav_frontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | window: povey |
| | | n_mels: 80 |
| | | frame_length: 25 |
| | | frame_shift: 10 |
| | | filter_length_min: -1 |
| | | filter_length_max: -1 |
| | | lfr_m: 1 |
| | | lfr_n: 1 |
| | | dither: 0.0 |
| | | snip_edges: false |
| | | |
| | | # minibatch related |
| | | batch_type: length |
| | | # 16s * 16k * 16 samples |
| | | batch_bins: 4096000 |
| | | num_workers: 8 |
| | | |
| | | # optimization related |
| | | accum_grad: 1 |
| | | grad_clip: 5 |
| | | max_epoch: 50 |
| | | val_scheduler_criterion: |
| | | - valid |
| | | - acc |
| | | best_model_criterion: |
| | | - - valid |
| | | - der |
| | | - min |
| | | - - valid |
| | | - forward_steps |
| | | - max |
| | | keep_nbest_models: 10 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.001 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 10000 |
| | | |
| | | # SpecAugment is disabled (specaug is null) |
| | | specaug: null |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | | time_warp_mode: bicubic |
| | | apply_freq_mask: true |
| | | freq_mask_width_range: |
| | | - 0 |
| | | - 30 |
| | | num_freq_mask: 2 |
| | | apply_time_mask: true |
| | | time_mask_width_range: |
| | | - 0 |
| | | - 40 |
| | | num_time_mask: 2 |
| | | |
| | | log_interval: 50 |
| | | # no feature normalization |
| | | normalize: None |
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
# Source the recipe's environment setup; abort if it is missing.
. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="6,7"
gpu_num=2
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
feats_dir="." # feature output directory
exp_dir="."
lang=zh
dumpdir=dump/raw
feats_type=raw
token_type=char
scp=wav.scp
type=kaldi_ark
# Stages to run: 3 = training, 4 = inference (both bounds inclusive).
stage=3
stop_stage=4

# feature configuration
# NOTE(review): feats_dim is empty here but is passed to --input_size in the
# training stage; it has to be given a value (e.g. via --feats_dim) — confirm.
feats_dim=
sample_frequency=16000
nj=32
speed_perturb=

# exp tag
tag="exp1"

# Let command-line options of the form --name value override the variables above.
. utils/parse_options.sh || exit 1;
| | | |
# Strict mode: exit on error (-e), on use of an undefined variable (-u),
# and when any command of a pipeline fails (-o pipefail).
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_asr_conformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"

inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pth

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
# Number of comma-separated GPU ids.
ngpu=$(echo "$gpuid_list" | awk -F "," '{print NF}')

if ${gpu_inference}; then
    # $(( )) replaces the deprecated $[ ] arithmetic form.
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi

# Make sure the per-split feature directories exist.
feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p "${feat_train_dir}"
feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p "${feat_dev_dir}"
feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p "${feat_test_dir}"
| | | |
# Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    # Fail fast with a clear message: under `set -u` an unset token_list would
    # otherwise only kill the background workers, and an empty feats_dim would
    # leave --input_size without a value.
    : "${token_list:?token_list (path to the token list) must be set before training}"
    : "${feats_dim:?feats_dim (input feature dimension) must be set before training}"

    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    # File used for the file:// distributed init method; remove a stale one first.
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"

    # One background training process per GPU; rank i uses the i-th id of
    # CUDA_VISIBLE_DEVICES.
    pids=()
    for ((i = 0; i < gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            asr_train.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
        pids+=($!)
    done
    # A bare `wait` always returns 0; wait on each worker so that a failed rank
    # stops the script instead of silently continuing to the next stage.
    for pid in "${pids[@]}"; do
        wait "${pid}" || { echo "$0: training process ${pid} failed, see ${exp_dir}/exp/${model_dir}/log" >&2; exit 1; }
    done
fi
| | | |
# Testing Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        # Refuse to overwrite earlier results. Note that this exits the whole
        # script, so the remaining test sets are not decoded either.
        if [ -d ${_dir} ]; then
            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${dumpdir}/${dset}"
        key_file=${_data}/${scp}
        # Never run more jobs than there are entries in the key file.
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        # The job range must be one word (JOB=1:N); with a space after the colon
        # run.pl does not see a job-range argument.
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}

        # Merge the per-job outputs into one sorted file per output type.
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        # Post-process hypothesis and reference text, score them, and print the
        # last three lines of the scoring report.
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
| | | |
| New file |
| | |
# Repository root, three levels above the directory this file is sourced from.
export FUNASR_DIR=$PWD/../../..

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# Put the funasr/bin entry points on PATH ahead of anything already there.
export PATH=$FUNASR_DIR/funasr/bin:$PATH
| New file |
| | |
| | | import logging |
| | | import numpy as np |
| | | import soundfile |
| | | import kaldiio |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import os |
| | | import argparse |
| | | from collections import OrderedDict |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner that writes one ``uttid dim0,dim1,...`` line per entry of an scp file."""

    def prepare(self, parser: argparse.ArgumentParser):
        """Parse the command line and load the task list.

        Returns:
            ``(task_list, None, args)`` where ``task_list`` is what
            ``load_scp_as_list`` returns for ``--input_scp``.
        """
        parser.add_argument("--input_scp", type=str, required=True)
        # out_path is used unconditionally below, so it is required.
        parser.add_argument("--out_path", type=str, required=True)
        args = parser.parse_args()

        # dirname is "" for a bare file name; os.makedirs("") would raise.
        out_dir = os.path.dirname(args.out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        task_list = load_scp_as_list(args.input_scp)
        return task_list, None, args

    def post(self, result_list, args):
        """Write every ``(uttid, shape)`` pair of every worker result to ``args.out_path``."""
        # The context manager closes the file even if a write fails.
        with open(args.out_path, "wt", encoding="utf-8") as fd:
            for results in result_list:
                for uttid, shape in results:
                    fd.write("{} {}\n".format(uttid, ",".join(shape)))
| | | |
| | | |
def process(task_args):
    """Load every matrix of the task list and return its shape as strings.

    Args:
        task_args: 4-tuple; only the second element, a sequence of
            ``(uttid, file_path)`` pairs, is used.

    Returns:
        List of ``(uttid, [str(dim), ...])`` tuples in task order.
    """
    _, utt_entries, _, _ = task_args
    return [
        (utt_id, [str(dim) for dim in kaldiio.load_mat(mat_path).shape])
        for utt_id, mat_path in utt_entries
    ]
| | | |
| | | |
if __name__ == '__main__':
    # Build the runner around the worker function and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import logging |
| | | import numpy as np |
| | | import soundfile |
| | | import kaldiio |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import os |
| | | import argparse |
| | | from collections import OrderedDict |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner that pairs every meeting's wav scp entry with its RTTM lines."""

    def prepare(self, parser: argparse.ArgumentParser):
        """Parse the command line and build one task per meeting.

        Returns:
            ``(task_list, None, args)`` where each task is
            ``(meeting_id, wav_scp_value, rttm_lines)``. Only meetings present
            in both the wav scp files and the RTTM files are kept.
        """
        parser.add_argument("--rttm_list", type=str, required=True)
        parser.add_argument("--wav_scp_list", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--n_spk", type=int, default=8)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        parser.add_argument("--max_overlap", default=0, type=int)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()

        meeting2rttm = OrderedDict()
        for rttm_path in self._read_path_list(args.rttm_list):
            meeting2rttm.update(self.load_rttm(rttm_path))

        meeting_scp = OrderedDict()
        for scp_path in self._read_path_list(args.wav_scp_list):
            meeting_scp.update(load_scp_as_dict(scp_path))

        if len(meeting_scp) != len(meeting2rttm):
            logging.warning("Number of wav and rttm mismatch {} != {}".format(
                len(meeting_scp), len(meeting2rttm)))
        # Compare the key sets, not just the counts: two sources of equal size
        # but different meetings would otherwise raise KeyError below.
        common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
        if len(common_keys) != len(meeting_scp) or len(common_keys) != len(meeting2rttm):
            logging.warning("Keep {} records.".format(len(common_keys)))
            meeting_scp = self._keep_common(
                meeting_scp, common_keys, "Keys are removed from wav scp: {}")
            meeting2rttm = self._keep_common(
                meeting2rttm, common_keys, "Keys are removed from rttm scp: {}")

        os.makedirs(args.out_dir, exist_ok=True)

        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
        return task_list, None, args

    @staticmethod
    def _read_path_list(list_path):
        """Return the non-empty, stripped lines of a text file listing paths."""
        with open(list_path, "rt", encoding="utf-8") as list_file:
            return [line.strip() for line in list_file if line.strip()]

    @staticmethod
    def _keep_common(mapping, common_keys, removed_msg):
        """Return ``mapping`` restricted to ``common_keys`` (order kept) and log the dropped keys."""
        kept, rm_keys = OrderedDict(), []
        for key, value in mapping.items():
            if key in common_keys:
                kept[key] = value
            else:
                rm_keys.append(key)
        logging.warning(removed_msg.format(" ".join(rm_keys)))
        return kept

    @staticmethod
    def load_rttm(rttm_path):
        """Group the lines of an RTTM file by meeting id (second field).

        Returns:
            ``OrderedDict`` mapping meeting id to the list of its stripped lines.
        """
        meeting2rttm = OrderedDict()
        with open(rttm_path, "rt", encoding="utf-8") as rttm_file:
            for one_line in rttm_file:
                fields = one_line.split()
                if not fields:
                    # Tolerate blank lines instead of failing with IndexError.
                    continue
                meeting2rttm.setdefault(fields[1], []).append(one_line.strip())

        return meeting2rttm

    def post(self, results_list, args):
        """Nothing to aggregate; the workers write their own output files."""
        pass
| | | |
| | | |
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
                sr=None, frame_shift=0.01):
    """Build a frame-level speaker-activity matrix from speaker turns.

    Args:
        spk_turns: iterable of ``(meeting_id, start_sec, dur_sec, speaker)``.
        spk_list: speakers; a speaker's position selects its output column.
        length: recording length in samples.
        n_spk: number of output columns.
        remove_sil: drop frames in which nobody speaks.
        max_overlap: when > 0, keep only frames with at most this many
            simultaneous speakers.
        sr: sample rate; must be a number, although the default is ``None``.
        frame_shift: frame shift in seconds.

    Returns:
        ``float32`` array of shape ``(num_frames, n_spk)``.
    """
    hop = int(frame_shift * sr)
    half_hop = float(hop) / 2

    def to_frame(sample_pos):
        # Sample position -> frame index, rounded to the nearest frame.
        return int((float(sample_pos) + half_hop) / hop)

    activity = np.zeros([n_spk, to_frame(length)], dtype=np.float32)
    for _, start_sec, dur_sec, speaker in spk_turns:
        row = spk_list.index(speaker)
        begin, span = int(start_sec * sr), int(dur_sec * sr)
        activity[row, to_frame(begin):to_frame(begin + span)] = 1

    if remove_sil:
        activity = activity[:, np.sum(activity, axis=0) != 0]

    if max_overlap > 0:
        activity = activity[:, np.sum(activity, axis=0) <= max_overlap]

    return activity.T  # (T, N)
| | | |
| | | |
def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
                 sr=16000, frame_shift=0.01):
    """Read a wav file and turn its RTTM lines into frame-level labels.

    Note that the ``sr`` argument is overwritten by the sample rate that
    ``soundfile.read`` reports for ``wav_path``.

    Args:
        wav_path: audio file; only its length and sample rate are used.
        rttms: RTTM lines, fields separated by single spaces.
        n_spk, remove_sil, max_overlap, frame_shift: forwarded to ``calc_labels``.

    Returns:
        ``(labels, spk_list)`` where ``spk_list`` holds the speakers in order
        of first appearance.
    """
    samples, sr = soundfile.read(wav_path)
    spk_turns, spk_list = [], []
    for rttm_line in rttms:
        fields = rttm_line.strip().split(" ")
        # RTTM fields used: 1 = meeting id, 3 = onset, 4 = duration, 7 = speaker.
        turn = (fields[1], float(fields[3]), float(fields[4]), fields[7])
        if turn[3] not in spk_list:
            spk_list.append(turn[3])
        spk_turns.append(turn)
    labels = calc_labels(spk_turns, spk_list, len(samples), n_spk,
                         remove_sil, max_overlap, sr, frame_shift)
    return labels, spk_list
| | | |
| | | |
def process(task_args):
    """Write the labels and speaker list of every meeting handled by this worker.

    Outputs, numbered by ``task_idx + 1`` inside ``args.out_dir``:
    ``labels.N.ark`` / ``labels.N.scp`` and ``spk_list.N.txt``
    (one ``meeting_id spk1 spk2 ...`` line per meeting).

    Args:
        task_args: 4-tuple ``(task_idx, task_list, _, args)``; ``task_list``
            holds ``(meeting_id, wav_path, rttm_lines)`` tuples.
    """
    task_idx, task_list, _, args = task_args
    spk_list_path = os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx + 1))
    out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
    # Context managers close both writers even if a meeting fails to process.
    # NOTE(review): args.sr is not declared in this file's prepare(); presumably
    # MultiProcessRunnerV3 adds it — confirm.
    with open(spk_list_path, "wt", encoding="utf-8") as spk_list_writer, \
            kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) as label_writer:
        for mid, wav_path, rttms in task_list:
            meeting_labels, spk_list = build_labels(
                wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
                args.sr, args.frame_shift)
            label_writer(mid, meeting_labels)
            spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))

    return None
| | | |
| | | |
if __name__ == '__main__':
    # Build the runner around the worker function and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import numpy as np |
| | | import os |
| | | import argparse |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import soundfile as sf |
| | | from tqdm import tqdm |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner that pairs every wav scp entry with the RTTM lines of its meeting."""

    def prepare(self, parser):
        """Parse the command line and build one task per wav scp entry.

        Returns:
            ``(task_list, None, args)`` where each task is
            ``(meeting_id, wav_path, rttm_lines)``.

        Raises:
            KeyError: if a meeting of the wav scp has no line in the RTTM file.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()

        os.makedirs(args.out_dir, exist_ok=True)

        wav_scp = load_scp_as_list(args.wav_scp)
        meeting2rttms = {}
        # The context manager closes the RTTM file once it has been read.
        with open(args.rttm, "rt") as rttm_file:
            for one_line in rttm_file:
                parts = [x for x in one_line.strip().split(" ") if x != ""]
                if not parts:
                    # Tolerate blank lines instead of failing with IndexError.
                    continue
                # Onset/duration are parsed only so that malformed lines fail here.
                mid, _st, _dur, _spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
                meeting2rttms.setdefault(mid, []).append(one_line)

        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args

    def post(self, result_list, args):
        """Sum the per-worker ``[extracted, found]`` counters and print the totals."""
        extracted = sum(result[0] for result in result_list)
        found = sum(result[1] for result in result_list)
        print("Found {} speakers, extracted {}.".format(found, extracted))
| | | |
| | | |
| | | # SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA> |
def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
    """Build a sample-level speaker-activity matrix from RTTM lines.

    Speaker names are normalised to ``<mid>_S<3-digit number>`` when the name
    (after removing "spk", the meeting id and "-") is numeric, otherwise to
    ``<mid>_<name>``.

    Args:
        rttms: RTTM lines; fields are split on spaces, empty fields ignored.
        length: number of samples of the recording.
        sr: sample rate used to convert seconds to samples.
        max_spk_num: number of rows of the output matrix.

    Returns:
        ``(labels, spk_list)``: an ``int`` array of shape
        ``(max_spk_num, length)`` and the normalised speaker names in order of
        first appearance (row ``i`` belongs to ``spk_list[i]``).
    """
    labels = np.zeros([max_spk_num, length], int)
    spk_list = []
    for record in rttms:
        fields = list(filter(None, record.strip().split(" ")))
        mid, onset, duration, raw_name = fields[1], float(fields[3]), float(fields[4]), fields[7]
        short_name = raw_name.replace("spk", "").replace(mid, "").replace("-", "")
        if short_name.isdigit():
            speaker = "{}_S{:03d}".format(mid, int(short_name))
        else:
            speaker = "{}_{}".format(mid, short_name)
        if speaker not in spk_list:
            spk_list.append(speaker)
        begin, span = int(onset * sr), int(duration * sr)
        labels[spk_list.index(speaker), begin:begin + span] = 1
    return labels, spk_list
| | | |
| | | |
def get_nonoverlap_turns(multi_label, spk_list):
    """Find the maximal runs in which exactly one speaker is active.

    Args:
        multi_label: activity matrix; axis 0 is the speaker, axis 1 the position.
        spk_list: speaker names, one per row of ``multi_label``.

    Returns:
        List of ``[start, end, speaker]`` with ``end`` exclusive. A run is
        also split where the single active speaker changes.
    """
    single = np.sum(multi_label, axis=0) == 1
    total = len(single)
    turns = []
    current_spk, start, active = None, 0, False
    for pos in range(total):
        if not single[pos]:
            # Silence or overlap: close the running turn, if any.
            if active:
                turns.append([start, pos, current_spk])
                active = False
            continue
        owner = spk_list[np.argmax(multi_label[:, pos], axis=0)]
        if not active:
            start, active, current_spk = pos, True, owner
        elif owner != current_spk:
            # Speaker change without a gap: end one turn and begin the next.
            turns.append([start, pos, current_spk])
            start, current_spk = pos, owner
    if active:
        turns.append([start, total, current_spk])
    return turns
| | | |
| | | |
def process(task_args):
    """Cut every sufficiently long single-speaker turn out of each recording.

    Segments are saved as 16-bit WAV files under
    ``<out_dir>/<meeting>/<speaker>/<speaker>_U<count>.wav``.

    Args:
        task_args: 4-tuple ``(task_id, task_list, _, args)``; ``task_list``
            holds ``(meeting_id, wav_path, rttm_lines)`` tuples.

    Returns:
        ``[extracted, found]``: speakers with at least one saved segment and
        speakers seen in the RTTM lines, summed over all meetings.
    """
    # NOTE(review): args.sr and args.no_pbar are not declared in this file's
    # prepare(); presumably MultiProcessRunnerV3 adds them — confirm.
    _, task_list, _, args = task_args
    n_extracted, n_found = 0, 0
    for mid, wav_path, rttms in task_list:
        wav, file_sr = sf.read(wav_path, dtype="int16")
        assert file_sr == args.sr, "args.sr {}, file sr {}".format(args.sr, file_sr)
        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        seg_no = 1
        for begin, end, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
            # Skip turns shorter than the minimum duration.
            if (end - begin) < args.min_dur * args.sr:
                continue
            save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, seg_no))
            save_dir = os.path.dirname(save_path)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            sf.write(save_path, wav[begin: end].astype(np.int16), args.sr,
                     subtype="PCM_16", endian="LITTLE", format="WAV", closefd=True)
            seg_no += 1
            if spk not in extracted_spk:
                extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            missing = " ".join([x for x in spk_list if x not in extracted_spk])
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), missing
            ))
        n_extracted += len(extracted_spk)
        n_found += len(spk_list)
    return [n_extracted, n_found]
| | | |
| | | |
if __name__ == '__main__':
    # Build the runner around the worker function and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import numpy as np |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import os |
| | | import librosa |
| | | import argparse |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner that pairs every ``rttm.scp`` entry with its ``meeting.scp`` wav."""

    def prepare(self, parser):
        """Parse the command line and build one task per ``rttm.scp`` entry.

        Returns:
            ``(task_list, None, args)`` where each task is
            ``(meeting_id, meeting_scp_value, rttm_path)``.
        """
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        args = parser.parse_args()

        wav_by_meeting = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
        rttm_entries = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        tasks = []
        for meeting_id, rttm_path in rttm_entries:
            tasks.append((meeting_id, wav_by_meeting[meeting_id], rttm_path))
        return tasks, None, args

    def post(self, results_list, args):
        """Nothing to aggregate; the workers write their own output files."""
        return None
| | | |
| | | |
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
    """Build a sample-level speaker-activity matrix from speaker turns.

    Args:
        spk_turns: iterable of ``(meeting_id, start_sec, dur_sec, speaker)``.
        spk_list: speakers; a speaker's position selects its output column.
        length: number of samples of the recording.
        n_spk: number of output columns.
        remove_sil: drop the samples in which nobody speaks.
        sr: sample rate used to convert seconds to samples.

    Returns:
        ``int`` array of shape ``(T, n_spk)``.
    """
    activity = np.zeros([n_spk, length], dtype=int)
    for _, start_sec, dur_sec, speaker in spk_turns:
        begin, span = int(start_sec * sr), int(dur_sec * sr)
        activity[spk_list.index(speaker), begin:begin + span] = 1

    if remove_sil:
        voiced = np.sum(activity, axis=0) != 0
        return activity[:, voiced].T  # (T, N)
    return activity.T
| | | |
| | | |
def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
    """Build sample-level labels for one meeting from its wav and RTTM files.

    Args:
        wav_path: audio file; only its duration is read.
        rttm_path: RTTM file whose speaker field (index 7) is an integer.
        n_spk: number of label columns.
        remove_sil: drop the samples in which nobody speaks.
        sr: sample rate used for the recording length and for the turns.

    Returns:
        Label array as returned by ``calc_labels``.
    """
    wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
    spk_turns = []
    spk_list = []
    # The context manager closes the RTTM file once it has been read.
    with open(rttm_path, "rt") as rttm_file:
        for one_line in rttm_file:
            parts = one_line.strip().split(" ")
            mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
            spk = "{}_S{:03d}".format(mid, spk)
            if spk not in spk_list:
                spk_list.append(spk)
            spk_turns.append((mid, st, dur, spk))
    # Forward sr so the turns use the same rate as wav_len; previously
    # calc_labels always fell back to its own default of 16000.
    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, sr)
    return labels
| | | |
| | | |
def process(task_args):
    """Compute and save the boolean label matrix of every meeting in the task list.

    Each matrix is stored with ``np.save`` under ``<out_dir>/<meeting>.lbl``
    and the meeting id is printed afterwards.

    Args:
        task_args: 4-tuple; the second element holds
            ``(meeting_id, wav_path, rttm_path)`` tuples and the fourth the
            parsed arguments.
    """
    _, meetings, _, args = task_args
    for meeting_id, wav_path, rttm_path in meetings:
        labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
        target = os.path.join(args.out_dir, "{}.lbl".format(meeting_id))
        np.save(target, labels.astype(bool))
        print(meeting_id)
    return None
| | | |
| | | |
if __name__ == '__main__':
    # Build the runner around the worker function and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import numpy as np |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import os |
| | | import librosa |
| | | import soundfile as sf |
| | | from tqdm import tqdm |
| | | import argparse |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner whose tasks are the entries of a wav scp file."""

    def prepare(self, parser):
        """Parse the command line, create the output directory and load the wav scp.

        Returns:
            ``(wav_scp, None, args)`` where ``wav_scp`` is what
            ``load_scp_as_list`` returns for the ``wav_scp`` argument.
        """
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--chunk_dur", type=float, default=16)
        parser.add_argument("--shift_dur", type=float, default=4)
        args = parser.parse_args()

        out_dir_missing = not os.path.exists(args.out_dir)
        if out_dir_missing:
            os.makedirs(args.out_dir)

        return load_scp_as_list(args.wav_scp), None, args

    def post(self, results_list, args):
        """Nothing to aggregate; the workers write their own output files."""
        return None
| | | |
| | | |
def process(task_args):
    """Split every recording into overlapping fixed-length chunks saved as WAV.

    Chunks are ``args.chunk_dur`` seconds long and start every
    ``args.shift_dur`` seconds; the last chunk may be shorter. Files are named
    ``<mid>_S<index>_<start>_<end>.wav`` with start/end in 1/100 s.

    Args:
        task_args: 4-tuple; the second element holds ``(meeting_id, wav_path)``
            pairs and the fourth the parsed arguments.
    """
    # NOTE(review): args.sr and args.no_pbar are not declared in this file's
    # prepare(); presumably MultiProcessRunnerV3 adds them — confirm.
    _, task_list, _, args = task_args
    chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
    for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
        meeting_dir = os.path.join(args.out_dir, mid)
        os.makedirs(meeting_dir, exist_ok=True)

        # Keyword arguments: librosa >= 0.10 no longer accepts sr/mono positionally.
        wav = librosa.load(wav_path, sr=args.sr, mono=True)[0] * 32767

        tail_len = len(wav) - chunk_len
        if tail_len <= 0:
            # A recording that fits in one chunk yields exactly one chunk; the
            # old formula produced none when it was shorter than
            # chunk_len - shift_len.
            n_chunk = 1 if len(wav) > 0 else 0
        else:
            n_chunk = tail_len // shift_len + 1
            if tail_len % shift_len > 0:
                n_chunk += 1

        for i in range(n_chunk):
            seg = wav[i*shift_len: i*shift_len + chunk_len]
            st = int(float(i*shift_len)/args.sr * 100)
            dur = int(float(len(seg))/args.sr * 100)
            file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
            save_path = os.path.join(meeting_dir, file_name)
            sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
    return None
| | | |
| | | |
if __name__ == '__main__':
    # Build the runner around the worker function and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import numpy as np |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import os |
| | | import argparse |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner that turns the RTTM files listed in an scp into one segments file."""

    def prepare(self, parser):
        """Parse the command line and load the RTTM scp as the task list.

        Returns:
            ``(task_list, None, args)`` where ``task_list`` is what
            ``load_scp_as_list`` returns for ``--rttm_scp``.
        """
        # Both options are used unconditionally, so they are required.
        parser.add_argument("--rttm_scp", type=str, required=True)
        parser.add_argument("--seg_file", type=str, required=True)
        args = parser.parse_args()

        # dirname is "" for a bare file name; os.makedirs("") would raise.
        seg_dir = os.path.dirname(args.seg_file)
        if seg_dir:
            os.makedirs(seg_dir, exist_ok=True)

        task_list = load_scp_as_list(args.rttm_scp)
        return task_list, None, args

    def post(self, results_list, args):
        """Write the segment lines of every worker, in order, to ``args.seg_file``."""
        with open(args.seg_file, "wt", encoding="utf-8") as fd:
            for results in results_list:
                fd.writelines(results)
| | | |
def process(task_args):
    """Convert RTTM files into speech segments, merging overlapping turns.

    Every maximal run of centiseconds covered by at least one speaker turn
    becomes one output line
    ``<mid>-<start>-<end> <mid> <start_sec> <end_sec>\\n`` with the first two
    numbers in 1/100 s, zero-padded to seven digits.

    Args:
        task_args: 4-tuple; only the second element, a sequence of
            ``(meeting_id, rttm_path)`` pairs, is used.

    Returns:
        List of segment lines for all meetings, in task order.
    """
    _, task_list, _, args = task_args
    outputs = []
    for mid, rttm_path in task_list:
        spk_turns = []
        length = 0
        # The context manager closes the RTTM file once it has been read.
        with open(rttm_path, 'rt', encoding="utf-8") as rttm_file:
            for one_line in rttm_file:
                # split() tolerates repeated whitespace between fields.
                parts = one_line.split()
                if not parts:
                    # Skip blank lines instead of failing with IndexError.
                    continue
                st, dur, spk_name = float(parts[3]), float(parts[4]), parts[7]
                st, ed = int(st*100), int((st + dur)*100)
                length = max(length, ed)
                spk_turns.append([mid, st, ed, spk_name])

        # One flag per centisecond; the extra final element always stays False,
        # so a segment that runs to the end is still closed by the loop below.
        is_sph = np.zeros((length+1, ), dtype=bool)
        for _, st, ed, _ in spk_turns:
            is_sph[st:ed] = True

        st, in_speech = 0, False
        for i in range(length+1):
            if not in_speech and is_sph[i]:
                st, in_speech = i, True
            if in_speech and not is_sph[i]:
                in_speech = False
                outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
                    mid, st, i, mid, float(st)/100, float(i)/100
                ))
    return outputs
| | | |
| | | |
if __name__ == '__main__':
    # Script entry point: give the per-task worker to the runner and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import soundfile |
| | | import kaldiio |
| | | from tqdm import tqdm |
| | | import json |
| | | import os |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import numpy as np |
| | | import argparse |
| | | import random |
| | | |
# Speakers already reported as having too little audio (warn once each).
short_spk_list = []


def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
    """Average the embeddings of randomly chosen utterances of one speaker.

    Utterances are drawn in random order until their accumulated frame
    count reaches ``total_len`` (or the speaker runs out of utterances).

    Args:
        spk: speaker id, a key of ``spk2utt``.
        spk2utt: mapping speaker id -> list of utterance ids.
        utt2ivc: mapping utterance id -> location readable by
            ``kaldiio.load_mat``.
        utt2frames: mapping utterance id -> frame count (int-convertible).
        total_len: number of frames to collect before stopping.

    Returns:
        Mean of the length-normalised embedding rows.
    """
    all_utts = spk2utt[spk]
    idx_list = list(range(len(all_utts)))
    random.shuffle(idx_list)
    count = 0
    utt_list = []
    for i in idx_list:
        utt_id = all_utts[i]
        utt_list.append(utt_id)
        count += int(utt2frames[utt_id])
        if count >= total_len:
            break
    if count < 300 and spk not in short_spk_list:
        print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
        short_spk_list.append(spk)

    ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
    # keepdims=True keeps the norm broadcastable against matrices with more
    # than one row; for single-row matrices the result is unchanged.
    ivc_list = [x / np.linalg.norm(x, axis=-1, keepdims=True) for x in ivc_list]
    ivc = np.concatenate(ivc_list, axis=0)
    ivc = np.mean(ivc, axis=0, keepdims=False)
    return ivc
| | | |
| | | |
def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
    """Cut meetings into fixed-size chunks and dump wav/profile/label archives.

    For every chunk three entries are written under the same key: the
    waveform slice, a speaker-profile matrix (one row per speaker of the
    meeting, zero rows appended up to ``args.n_spk``) and frame-level
    labels obtained by OR-ing the sample-level labels inside each window
    of ``win_shift`` samples.

    Args:
        meeting_scp: sequence of ``(meeting_id, wav_path)`` pairs.
        labels_scp: mapping meeting id -> path loadable with ``np.load``.
        spk2utt, utt2xvec, utt2frames: forwarded to ``calc_rand_ivc``.
        meeting2spk_list: mapping meeting id -> list of speaker ids.
        args: parsed arguments; reads ``out``, ``n_spk``, ``chunk_size``,
            ``chunk_shift``, ``sr`` and ``no_pbar``.

    Returns:
        List of ``(meeting_id, num_frames)``, one entry per written chunk.
    """
    out_prefix = args.out

    ivc_dim = 192
    win_shift = 160
    wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
    ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
    label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))

    frames_list = []
    chunk_size = int(args.chunk_size * args.sr)
    chunk_shift = int(args.chunk_shift * args.sr)
    try:
        for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
            meeting_wav, _ = soundfile.read(meeting_wav_path, dtype='float32')
            num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
            meeting_labels = np.load(labels_scp[mid])
            for chunk_idx in range(num_chunk):
                st = chunk_idx * chunk_shift
                ed = st + chunk_size
                seg_id = "{}-{:03d}-{:06d}-{:06d}".format(
                    mid, chunk_idx, int(st/args.sr*100), int(ed/args.sr*100))
                wav_writer(seg_id, meeting_wav[st: ed])

                # A fresh random profile is drawn for every chunk.
                xvec_list = [
                    calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
                    for spk in meeting2spk_list[mid]
                ]
                for _ in range(args.n_spk - len(xvec_list)):
                    xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
                # np.vstack is the non-deprecated spelling of np.row_stack.
                ivc_writer(seg_id, np.vstack(xvec_list))

                wav_label = meeting_labels[st:ed, :]
                frame_num = (ed - st) // win_shift
                feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
                # A dedicated index name avoids shadowing the chunk index.
                for frame_idx in range(frame_num):
                    frame_label = wav_label[frame_idx*win_shift: (frame_idx+1)*win_shift, :]
                    feat_label[frame_idx, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
                label_writer(seg_id, feat_label)

                frames_list.append((mid, feat_label.shape[0]))
    finally:
        # Close the archives even on failure so buffered data is flushed
        # and file handles are released.
        wav_writer.close()
        ivc_writer.close()
        label_writer.close()
    return frames_list
| | | |
| | | |
def calc_spk_list(rttm_path):
    """Return the speakers of an RTTM file in order of first appearance.

    Field 1 of each line is taken as the meeting id and field 7 as an
    integer speaker index; they are combined as ``"<mid>_S<index:03d>"``.

    Args:
        rttm_path: path of the RTTM file to scan.

    Returns:
        List of unique speaker names.
    """
    spk_list = []
    # Context manager closes the handle instead of leaking it.
    with open(rttm_path, "rt", encoding="utf-8") as rttm_file:
        for one_line in rttm_file:
            parts = one_line.split()
            if not parts:
                # Skip blank lines instead of failing on parts[1].
                continue
            spk = "{}_S{:03d}".format(parts[1], int(parts[7]))
            if spk not in spk_list:
                spk_list.append(spk)

    return spk_list
| | | |
| | | |
def main():
    """Command-line entry: load the data directory and dump chunked archives.

    Reads ``meetings_rmsil.scp``, ``labels.scp``, ``rttm.scp``, ``utt2spk``,
    ``utt2xvec`` and ``wav.scp`` from ``--dir``, derives per-utterance
    durations and per-meeting speaker lists, then calls ``process`` and
    prints chunk/frame totals.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=4)
    parser.add_argument("--use_lfr", default=False, action="store_true")
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=16000)
    parser.add_argument("--chunk_size", type=int, default=16)
    parser.add_argument("--chunk_shift", type=int, default=4)
    args = parser.parse_args()

    # A prefix without a directory part has an empty dirname, and
    # os.makedirs("") raises; only create a directory when one is named.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
    rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
    utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))

    # Utterance duration in 10 ms units.
    utt2frames = {}
    for uttid, wav_path in utt2wav.items():
        wav, sr = soundfile.read(wav_path, dtype="int16")
        utt2frames[uttid] = int(len(wav) / sr * 100)

    meeting2spk_list = {}
    for mid, rttm_path in rttm_scp:
        meeting2spk_list[mid] = calc_spk_list(rttm_path)

    # Keep only utterances that have an embedding and more than 25 units
    # (0.25 s) of audio.
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
            spk2utt.setdefault(spk, []).append(utt)

    meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
    total_frames = sum(x[1] for x in meeting_lens)
    print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))


if __name__ == '__main__':
    main()
| New file |
| | |
| | | from __future__ import print_function |
| | | import numpy as np |
| | | import os |
| | | import sys |
| | | import argparse |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import librosa |
| | | import soundfile as sf |
| | | from copy import deepcopy |
| | | import json |
| | | from tqdm import tqdm |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner for extracting non-overlapping single-speaker segments."""

    def prepare(self, parser):
        """Declare the CLI, create the output directory and pair wav with RTTM.

        Returns:
            ``(task_list, None, args)``; each task is
            ``(meeting_id, wav_path, rttm_path)``.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        for positional in ("wav_scp", "rttm_scp", "out_dir"):
            parser.add_argument(positional, type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        mid2rttm = load_scp_as_dict(args.rttm_scp)
        task_list = []
        for mid, wav_path in load_scp_as_list(args.wav_scp):
            task_list.append((mid, wav_path, mid2rttm[mid]))
        return task_list, None, args

    def post(self, result_list, args):
        """Sum the per-worker ``[extracted, found]`` counters and print them."""
        n_extracted, n_found = 0, 0
        for result in result_list:
            n_extracted += result[0]
            n_found += result[1]
        print("Found {} speakers, extracted {}.".format(n_found, n_extracted))
| | | |
| | | |
# Example RTTM line:
# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
    """Build a sample-level speaker activity matrix from an RTTM file.

    Args:
        rttm_path: path of the RTTM file.
        length: number of samples (columns) of the label matrix.
        sr: sampling rate used to convert seconds to sample indices.
        max_spk_num: number of rows of the label matrix.

    Returns:
        ``(labels, spk_list)``: an int array of shape
        ``(max_spk_num, length)`` with 1 where a speaker is active, and the
        speaker names in row order (order of first appearance). Purely
        numeric speaker fields are renamed to ``"<mid>_S<index:03d>"``.

    Raises:
        IndexError: if the file names more than ``max_spk_num`` speakers.
    """
    labels = np.zeros([max_spk_num, length], int)
    spk_list = []
    # Context manager closes the handle instead of leaking it.
    with open(rttm_path, 'rt') as rttm_file:
        for one_line in rttm_file:
            parts = one_line.split()
            if not parts:
                # Skip blank lines instead of failing on parts[1].
                continue
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if spk_name.isdigit():
                spk_name = "{}_S{:03d}".format(mid, int(spk_name))
            if spk_name not in spk_list:
                spk_list.append(spk_name)
            idx = spk_list.index(spk_name)
            if idx >= max_spk_num:
                # Same exception type numpy would raise, but with context.
                raise IndexError(
                    "{}: more than max_spk_num={} speakers (at {})".format(
                        rttm_path, max_spk_num, spk_name))
            st, dur = int(st*sr), int(dur*sr)
            labels[idx, st:st+dur] = 1
    return labels, spk_list
| | | |
| | | |
def get_nonoverlap_turns(multi_label, spk_list):
    """Find maximal runs where exactly one speaker is active.

    Args:
        multi_label: 2-D array, one row per speaker and one column per
            sample; a column counts as single-speaker when it sums to 1.
        spk_list: speaker names indexed by row.

    Returns:
        List of ``[start, end, speaker]`` with ``end`` exclusive, in
        temporal order. A run is split where the active speaker changes.
    """
    single = np.sum(multi_label, axis=0) == 1
    if not single.any():
        return []

    # Row index of the sole active speaker per column, -1 elsewhere.
    # Vectorised replacement for a per-sample Python loop.
    owner = np.where(single, np.argmax(multi_label, axis=0), -1)
    # Positions where the owner differs from the previous column.
    change = np.flatnonzero(np.diff(owner)) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [len(owner)]))

    return [
        [int(st), int(ed), spk_list[owner[st]]]
        for st, ed in zip(starts, ends)
        if owner[st] >= 0
    ]
| | | |
| | | |
def process(task_args):
    """Save every sufficiently long single-speaker turn as its own wav file.

    Args:
        task_args: ``(task_id, task_list, shared, args)``; tasks are
            ``(meeting_id, wav_path, rttm_path)`` triples.

    Returns:
        ``[extracted, found]``: speakers with at least one saved segment
        and speakers listed in the RTTM files, summed over all tasks.
    """
    _task_id, task_list, _, args = task_args
    n_extracted, n_found = 0, 0
    for mid, wav_path, rttm_path in task_list:
        wav, sr = sf.read(wav_path, dtype="int16")
        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
        multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)

        extracted_spk = []
        utt_no = 1  # running utterance number within this meeting
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
            if (ed - st) < args.min_dur * args.sr:
                # Too short to be kept.
                continue
            file_name = "{}_U{:04d}.wav".format(spk, utt_no)
            save_path = os.path.join(args.out_dir, mid, spk, file_name)
            spk_dir = os.path.dirname(save_path)
            if not os.path.exists(spk_dir):
                os.makedirs(spk_dir)
            sf.write(save_path, wav[st: ed].astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
            utt_no += 1
            if spk not in extracted_spk:
                extracted_spk.append(spk)

        if len(extracted_spk) != len(spk_list):
            missing = " ".join([x for x in spk_list if x not in extracted_spk])
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), missing
            ))
        n_extracted += len(extracted_spk)
        n_found += len(spk_list)
    return [n_extracted, n_found]
| | | |
| | | |
if __name__ == '__main__':
    # Script entry point: give the per-task worker to the runner and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import numpy as np |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import os |
| | | import librosa |
| | | import soundfile as sf |
| | | import argparse |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner that keeps only the VAD segments of each meeting recording."""

    def prepare(self, parser):
        """Parse arguments and group the VAD segments by meeting.

        Reads ``meeting.scp`` and ``segments`` from the ``dir`` argument.
        Each ``segments`` line is ``"<uid> <mid> <start_sec> <end_sec>"``.

        Returns:
            ``(task_list, None, args)``; each task is
            ``(meeting_id, wav_path, [(uid, start_sample, end_sample), ...])``.
        """
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()

        meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
        meeting2vad = {}
        # Context manager closes the handle instead of leaking it.
        with open(os.path.join(args.dir, "segments"), encoding="utf-8") as vad_file:
            for one_line in vad_file:
                one_line = one_line.strip()
                if not one_line:
                    # Skip blank lines instead of failing on the unpacking.
                    continue
                uid, mid, st, ed = one_line.split(" ")
                # NOTE(review): args.sr is not declared here; presumably the
                # base runner adds it - confirm.
                st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
                meeting2vad.setdefault(mid, []).append((uid, st, ed))

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
        return task_list, None, args

    def post(self, results_list, args):
        """Nothing to aggregate: workers write their own output files."""
        pass
| | | |
| | | |
def process(task_args):
    """Concatenate each meeting's VAD segments into one wav plus a map file.

    For every meeting ``<out_dir>/<mid>.wav`` receives the concatenated
    segments and ``<out_dir>/<mid>.pos`` one line per segment:
    ``"<uid> <orig_start> <orig_end> <new_start> <new_end>"`` in samples.

    Args:
        task_args: ``(task_id, task_list, shared, args)``; tasks are
            ``(meeting_id, wav_path, vad_list)`` triples.

    Returns:
        None.
    """
    _, task_list, _, args = task_args
    for mid, wav_path, vad_list in task_list:
        # sr/mono must be passed by keyword: recent librosa releases accept
        # them as keyword-only arguments.
        wav = librosa.load(wav_path, sr=args.sr, mono=True)[0] * 32767
        seg_list = []
        pos_map = []
        offset = 0  # write position in the concatenated signal
        for uid, st, ed in vad_list:
            seg_list.append(wav[st: ed])
            pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st))
            offset = offset + ed - st
        out = np.concatenate(seg_list, axis=0)
        save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
        sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
        map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
        with open(map_path, "wt", encoding="utf-8") as fd:
            fd.writelines(pos_map)
        print(mid)
    return None
| | | |
| | | |
if __name__ == '__main__':
    # Script entry point: give the per-task worker to the runner and start it.
    MyRunner(process).run()
| New file |
| | |
| | | import logging |
| | | import numpy as np |
| | | import soundfile |
| | | import kaldiio |
| | | from funasr.utils.job_runner import MultiProcessRunnerV3 |
| | | from funasr.utils.misc import load_scp_as_list, load_scp_as_dict |
| | | import os |
| | | import argparse |
| | | from collections import OrderedDict |
| | | import random |
| | | from typing import List, Dict |
| | | from copy import deepcopy |
| | | import json |
| | | logging.basicConfig( |
| | | level="INFO", |
| | | format=f"[{os.uname()[1].split('.')[0]}]" |
| | | f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | |
class MyRunner(MultiProcessRunnerV3):
    """Runner that simulates mixed-speech chunks from real meeting labels."""

    def prepare(self, parser: argparse.ArgumentParser):
        """Parse arguments, seed the RNGs and load (or rebuild) the metadata.

        When ``--data_json`` names an existing file the metadata is read
        from it; otherwise it is built from the scp files and, if
        ``--data_json`` was given, cached there.

        Returns:
            ``(label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args)``.
            ``args.chunk_size`` and ``args.chunk_shift`` are converted from
            seconds to frames before returning.
        """
        parser.add_argument("--label_scp", type=str, required=True)
        parser.add_argument("--wav_scp", type=str, required=True)
        parser.add_argument("--utt2spk", type=str, required=True)
        parser.add_argument("--spk2meeting", type=str, required=True)
        parser.add_argument("--utt2xvec", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--chunk_size", type=float, default=16)
        parser.add_argument("--chunk_shift", type=float, default=4)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        parser.add_argument("--embedding_dim", type=int, default=None)
        parser.add_argument("--average_emb_num", type=int, default=0)
        parser.add_argument("--subset", type=int, default=0)
        parser.add_argument("--data_json", type=str, default=None)
        parser.add_argument("--seed", type=int, default=1234)
        parser.add_argument("--log_interval", type=int, default=100)
        args = parser.parse_args()
        random.seed(args.seed)
        np.random.seed(args.seed)

        logging.info("Loading data...")
        # --data_json defaults to None; os.path.exists(None) raises, so the
        # cache is only consulted when a path was actually supplied.
        if args.data_json is not None and os.path.exists(args.data_json):
            with open(args.data_json, "rt", encoding="utf-8") as fd:
                data_dict = json.load(fd)
            label_list = data_dict["label_list"]
            wav_scp = data_dict["wav_scp"]
            utt2xvec = data_dict["utt2xvec"]
            spk2utts = data_dict["spk2utts"]
            meeting2spks = data_dict["meeting2spks"]
        else:
            label_list = load_scp_as_list(args.label_scp)
            wav_scp = load_scp_as_dict(args.wav_scp)
            utt2spk = load_scp_as_dict(args.utt2spk)
            utt2xvec = load_scp_as_dict(args.utt2xvec)
            spk2meeting = load_scp_as_dict(args.spk2meeting)

            meeting2spks = OrderedDict()
            for spk, meeting in spk2meeting.items():
                meeting2spks.setdefault(meeting, []).append(spk)

            spk2utts = OrderedDict()
            for utt, spk in utt2spk.items():
                spk2utts.setdefault(spk, []).append(utt)

            if args.data_json is not None:
                # An empty dirname (bare file name) must not reach makedirs.
                json_dir = os.path.dirname(args.data_json)
                if json_dir:
                    os.makedirs(json_dir, exist_ok=True)
                logging.info("Dump data...")
                with open(args.data_json, "wt", encoding="utf-8") as fd:
                    json.dump({
                        "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
                        "spk2utts": spk2utts, "meeting2spks": meeting2spks
                    }, fd, ensure_ascii=False, indent=4)

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        # Seconds -> number of frames.
        args.chunk_size = int(args.chunk_size / args.frame_shift)
        args.chunk_shift = int(args.chunk_shift / args.frame_shift)

        if args.embedding_dim is None:
            # Infer the dimension from the first available embedding.
            args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
            logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))

        logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
            len(wav_scp), len(spk2utts), len(meeting2spks)
        ))
        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args

    def post(self, results_list, args):
        """Log the total number of chunks produced by all workers."""
        logging.info("[main]: Got {} chunks.".format(sum(results_list)))
| | | |
| | | |
def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
    """Return ``sample_length`` consecutive samples of speech from ``spk``.

    Randomly chosen utterances of the speaker (with replacement) are
    concatenated until at least ``sample_length`` samples are available;
    a window of exactly that length is then cut at a random offset.
    """
    candidates = spk2utts[spk]
    pieces, collected = [], 0
    while collected < sample_length:
        picked = random.choice(candidates)
        audio, _ = soundfile.read(wav_scp[picked], dtype='float32')
        pieces.append(audio)
        collected += len(audio)
    pool = np.concatenate(pieces, axis=0)
    offset = random.randint(0, len(pool) - sample_length)
    return pool[offset: offset + sample_length]
| | | |
| | | |
def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
    """Return the profile embedding of ``spk``.

    The placeholder speaker ``"None"`` gets an all-zero ``(1, embedding_dim)``
    float32 array. For a real speaker the embeddings of all utterances
    (when ``average_emb_num`` is 0 or exceeds the number of utterances),
    or of ``average_emb_num`` randomly sampled ones, are length-normalised
    row by row and averaged.
    """
    if spk == "None":
        # Dummy speaker: zero profile.
        return np.zeros((1, embedding_dim), dtype=np.float32)

    utt_list = spk2utts[spk]
    take_all = average_emb_num == 0 or average_emb_num > len(utt_list)
    chosen = utt_list if take_all else random.sample(utt_list, average_emb_num)
    stacked = np.concatenate([kaldiio.load_mat(utt2xvec[utt]) for utt in chosen], axis=0)
    normalised = stacked / np.linalg.norm(stacked, axis=-1, keepdims=True)
    return np.mean(normalised, axis=0)
| | | |
| | | |
def simu_chunk(
        frame_label: np.ndarray,
        sample_label: np.ndarray,
        wav_scp: Dict[str, str],
        utt2xvec: Dict[str, str],
        spk2utts: Dict[str, List[str]],
        meeting2spks: Dict[str, List[str]],
        all_speaker_list: List[str],
        meeting_list: List[str],
        embedding_dim: int,
        average_emb_num: int,
):
    """Simulate one mixed chunk that follows a real activity pattern.

    Every label column with activity is assigned a speaker (preferably
    from one randomly chosen meeting) whose speech is gated by
    ``sample_label``; silent columns get either a random "negative"
    speaker or the placeholder ``"None"``. Columns are finally permuted.

    Returns:
        ``(mixed_wav, seperated_wav, profile, frame_label)`` with the
        per-speaker outputs in the permuted column order.
    """
    frame_length, max_spk_num = frame_label.shape
    sample_length = sample_label.shape[0]
    frames_per_column = frame_label.sum(axis=0)
    positive_speaker_num = int(np.sum(frames_per_column > 0))
    pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])

    # Positive speakers: subsample the meeting, or top it up from the pool.
    if len(pos_speaker_list) >= positive_speaker_num:
        pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
    else:
        while len(pos_speaker_list) < positive_speaker_num:
            candidate = random.choice(all_speaker_list)
            if candidate not in pos_speaker_list:
                pos_speaker_list.append(candidate)

    # Negative speakers: a random number of extra speakers, then dummies.
    free_slots = max_spk_num - positive_speaker_num
    negative_speaker_num = random.randint(0, free_slots)
    neg_speaker_list = []
    while len(neg_speaker_list) < negative_speaker_num:
        candidate = random.choice(all_speaker_list)
        if candidate not in pos_speaker_list and candidate not in neg_speaker_list:
            neg_speaker_list.append(candidate)
    neg_speaker_list.extend(["None"] * (free_slots - negative_speaker_num))

    random.shuffle(pos_speaker_list)
    random.shuffle(neg_speaker_list)
    seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
    this_spk_list = []
    for idx, frame_num in enumerate(frames_per_column):
        if frame_num > 0:
            spk = pos_speaker_list.pop(0)
            this_spk_list.append(spk)
            seperated_wav[:, idx] = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
        else:
            this_spk_list.append(neg_speaker_list.pop(0))

    # Mixture: sum of the per-speaker signals gated by their activity.
    mixed_wav = np.sum(seperated_wav * sample_label, axis=1)

    # Apply one random permutation to names, signals and labels alike.
    shuffle_idx = list(range(max_spk_num))
    random.shuffle(shuffle_idx)
    this_spk_list = [this_spk_list[x] for x in shuffle_idx]
    seperated_wav = seperated_wav[:, shuffle_idx]
    frame_label = frame_label[:, shuffle_idx]

    # One profile row per (permuted) column.
    profile = np.vstack([
        calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
        for spk in this_spk_list
    ])

    return mixed_wav, seperated_wav, profile, frame_label
| | | |
| | | |
def process(task_args):
    """Simulate chunks for the meetings of one task and write the archives.

    Three kaldi ark/scp pairs are written into ``args.out_dir``:
    ``wav_mix.<n>``, ``profile.<n>`` and ``frame_label.<n>`` where ``<n>``
    is the 1-based task index.

    Args:
        task_args: ``(task_idx, task_list, shared, args)``; ``task_list``
            holds ``(meeting_id, label_path)`` pairs and ``shared`` is
            ``(wav_scp, utt2xvec, spk2utts, meeting2spks)``.

    Returns:
        Number of simulated chunks.
    """
    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
    logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))

    def _open_writer(name):
        # One ark/scp pair per output stream and task.
        out_path = os.path.join(args.out_dir, "{}.{}".format(name, task_idx + 1))
        return kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))

    speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())

    # First pass: load labels and fix the random offsets so the total
    # number of chunks is known before simulation starts.
    labels_list = []
    total_chunks = 0
    for org_mid, label_path in task_list:
        whole_label = kaldiio.load_mat(label_path)
        # random offset to keep diversity
        rand_shift = random.randint(0, args.chunk_shift)
        num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
        # Meetings shorter than one chunk contribute nothing; a negative
        # count must not reduce the reported total.
        num_chunk = max(num_chunk, 0)
        labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
        total_chunks += num_chunk

    simu_chunk_count = 0
    writers = []
    try:
        for name in ("wav_mix", "profile", "frame_label"):
            writers.append(_open_writer(name))
        wav_mix_writer, profile_writer, label_writer = writers

        for org_mid, whole_label, rand_shift, num_chunk in labels_list:
            for i in range(num_chunk):
                st = i * args.chunk_shift + rand_shift
                ed = st + args.chunk_size
                utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
                    args.subset + 1, task_idx + 1, org_mid, st, ed
                )
                frame_label = whole_label[st: ed, :]
                # Repeat each frame label for every sample of the frame.
                sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
                mix_wav, seg_wav, profile, frame_label = simu_chunk(
                    frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
                    speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
                )
                wav_mix_writer(utt_id, mix_wav)
                profile_writer(utt_id, profile)
                label_writer(utt_id, frame_label)

                simu_chunk_count += 1
                if simu_chunk_count % args.log_interval == 0:
                    logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
                        task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
    finally:
        # Close the archives on failure too, so they are flushed and the
        # handles released.
        for writer in writers:
            writer.close()
    logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
    return simu_chunk_count
| | | |
| | | |
if __name__ == '__main__':
    # Script entry point: give the per-task worker to the runner and start it.
    MyRunner(process).run()
| | |
| | | def inference_launch(mode, **kwargs): |
| | | if mode == "sond": |
| | | from funasr.bin.sond_inference import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | return inference_modelscope(mode=mode, **kwargs) |
| | | elif mode == "sond_demo": |
| | | from funasr.bin.sond_inference import inference_modelscope |
| | | param_dict = { |
| | |
| | | "sv_train_config": "sv.yaml", |
| | | "sv_model_file": "sv.pth", |
| | | } |
| | | if "param_dict" in kwargs: |
| | | kwargs["param_dict"].update(param_dict) |
| | | if "param_dict" in kwargs and kwargs["param_dict"] is not None: |
| | | for key in param_dict: |
| | | if key not in kwargs["param_dict"]: |
| | | kwargs["param_dict"][key] = param_dict[key] |
| | | else: |
| | | kwargs["param_dict"] = param_dict |
| | | return inference_modelscope(**kwargs) |
| | | return inference_modelscope(mode=mode, **kwargs) |
| | | else: |
| | | logging.info("Unknown decoding mode: {}".format(mode)) |
| | | return None |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | |
| | | import os |
| | | |
| | | from funasr.tasks.diar import DiarTask |
| | | |
| | | |
# Command-line parsing for training through DiarTask
def parse_args():
    """Parse the command line with DiarTask's parser extended by ``--gpu_id``."""
    parser = DiarTask.get_parser()
    parser.add_argument("--gpu_id", type=int, default=0, help="local gpu id.")
    return parser.parse_args()
| | | |
| | | |
def main(args=None, cmd=None):
    """Forward ``args`` and ``cmd`` unchanged to ``DiarTask.main``."""
    DiarTask.main(args=args, cmd=cmd)
| | | |
| | | |
if __name__ == '__main__':
    args = parse_args()

    # Restrict this process to the requested local GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # Distributed training is enabled exactly when several GPUs are used.
    args.distributed = args.ngpu > 1
    assert args.num_worker_count == 1

    # For the "small" dataset type, scale the configured batch settings by
    # the number of GPUs.
    if args.dataset_type == "small":
        for field in ("batch_size", "batch_bins"):
            value = getattr(args, field)
            if value is not None:
                setattr(args, field, value * args.ngpu)

    main(args=args)
| | |
| | | from funasr.utils.types import str_or_none |
| | | from scipy.ndimage import median_filter |
| | | from funasr.utils.misc import statistic_model_parameters |
| | | from funasr.datasets.iterable_dataset import load_bytes |
| | | |
| | | |
| | | class Speech2Diarization: |
| | | """Speech2Xvector class |
| | |
| | | dur_threshold: int = 10, |
| | | out_format: str = "vad", |
| | | param_dict: Optional[dict] = None, |
| | | mode: str = "sond", |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2a. Build speech2xvec [Optional] |
| | | if param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]: |
| | | if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]: |
| | | assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict." |
| | | assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict." |
| | | sv_train_config = param_dict["sv_train_config"] |
| | | sv_model_file = param_dict["sv_model_file"] |
| | | if "model_dir" in param_dict: |
| | | sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config) |
| | | sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file) |
| | | from funasr.bin.sv_inference import Speech2Xvector |
| | | speech2xvector_kwargs = dict( |
| | | sv_train_config=sv_train_config, |
| | |
| | | |
| | | def _forward( |
| | | data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None, |
| | | raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None, |
| | | raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None, |
| | | output_dir_v2: Optional[str] = None, |
| | | param_dict: Optional[dict] = None, |
| | | ): |
| | | logging.info("param_dict: {}".format(param_dict)) |
| | | if data_path_and_name_and_type is None and raw_inputs is not None: |
| | | if isinstance(raw_inputs, (list, tuple)): |
| | | if not isinstance(raw_inputs[0], List): |
| | | raw_inputs = [raw_inputs] |
| | | |
| | | assert all([len(example) >= 2 for example in raw_inputs]), \ |
| | | "The length of test case in raw_inputs must larger than 1 (>=2)." |
| | | |
| | | def prepare_dataset(): |
| | | for idx, example in enumerate(raw_inputs): |
| | | # read waveform file |
| | | example = [soundfile.read(x)[0] if isinstance(example[0], str) else x |
| | | example = [load_bytes(x) if isinstance(x, bytes) else x |
| | | for x in example] |
| | | example = [soundfile.read(x)[0] if isinstance(x, str) else x |
| | | for x in example] |
| | | # convert torch tensor to numpy array |
| | | example = [x.numpy() if isinstance(example[0], torch.Tensor) else x |
| | |
| | | else: |
| | | olens = None |
| | | |
| | | return output, olens |
| | | return output.to(input.dtype), olens |
| | |
| | | |
| | | import torch |
| | | from torch import nn |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | |
| | | |
| | | class LabelSmoothingLoss(nn.Module): |
| | |
| | | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) |
| | | denom = total if self.normalize_length else batch_size |
| | | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom |
| | | |
| | | |
class SequenceBinaryCrossEntropy(nn.Module):
    """Binary cross-entropy over padded sequences, ignoring padded steps.

    Args:
        normalize_length: if True divide the summed loss by the number of
            non-padded time steps, otherwise by the batch size.
        criterion: element-wise loss module (must not reduce). Defaults to
            a fresh ``nn.BCEWithLogitsLoss(reduction="none")`` per instance.
    """

    def __init__(
            self,
            normalize_length=False,
            criterion=None
    ):
        super().__init__()
        self.normalize_length = normalize_length
        # Build the default inside __init__: a module instance used as a
        # default argument would be shared by every object of this class.
        if criterion is None:
            criterion = nn.BCEWithLogitsLoss(reduction="none")
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        """Return the masked, normalised loss.

        Args:
            pred: predictions; dimension 1 is treated as time.
            label: targets with the same shape as ``pred``.
            lengths: valid length of every sequence in the batch.
        """
        # Move the mask to pred's device so masked_fill also works when the
        # mask was built on another device.
        # NOTE(review): assumes make_pad_mask returns a (batch, time) bool
        # tensor that is True at padded positions - confirm.
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        # Append singleton dims so the mask broadcasts over any trailing
        # (e.g. per-speaker) dimensions of the element-wise loss.
        while pad_mask.dim() < loss.dim():
            pad_mask = pad_mask.unsqueeze(-1)
        return loss.masked_fill(pad_mask, 0).sum() / denom
| | |
| | | from itertools import permutations |
| | | from typing import Dict |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Tuple, List |
| | | |
| | | import numpy as np |
| | | import torch |
| | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy |
| | | from funasr.utils.misc import int2vec |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | encoder: AbsEncoder, |
| | | speaker_encoder: AbsEncoder, |
| | | encoder: torch.nn.Module, |
| | | speaker_encoder: Optional[torch.nn.Module], |
| | | ci_scorer: torch.nn.Module, |
| | | cd_scorer: torch.nn.Module, |
| | | cd_scorer: Optional[torch.nn.Module], |
| | | decoder: torch.nn.Module, |
| | | token_list: list, |
| | | lsm_weight: float = 0.1, |
| | | length_normalized_loss: bool = False, |
| | | max_spk_num: int = 16, |
| | | label_aggregator: Optional[torch.nn.Module] = None, |
| | | normlize_speech_speaker: bool = False, |
| | | normalize_speech_speaker: bool = False, |
| | | ignore_id: int = -1, |
| | | speaker_discrimination_loss_weight: float = 1.0, |
| | | inter_score_loss_weight: float = 0.0 |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | |
| | | self.decoder = decoder |
| | | self.token_list = token_list |
| | | self.max_spk_num = max_spk_num |
| | | self.normalize_speech_speaker = normlize_speech_speaker |
| | | self.normalize_speech_speaker = normalize_speech_speaker |
| | | self.ignore_id = ignore_id |
| | | self.criterion_diar = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=ignore_id, |
| | | smoothing=lsm_weight, |
| | | normalize_length=length_normalized_loss, |
| | | ) |
| | | self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss) |
| | | self.pse_embedding = self.generate_pse_embedding() |
| | | # self.register_buffer("pse_embedding", pse_embedding) |
| | | self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() |
| | | # self.register_buffer("power_weight", power_weight) |
| | | self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() |
| | | # self.register_buffer("int_token_arr", int_token_arr) |
| | | self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight |
| | | self.inter_score_loss_weight = inter_score_loss_weight |
| | | self.forward_steps = 0 |
| | | |
| | | def generate_pse_embedding(self): |
| | | embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float) |
| | | for idx, pse_label in enumerate(self.token_list): |
| | | emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float) |
| | | embedding[idx] = emb |
| | | return torch.from_numpy(embedding) |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | speech_lengths: torch.Tensor = None, |
| | | profile: torch.Tensor = None, |
| | | profile_lengths: torch.Tensor = None, |
| | | spk_labels: torch.Tensor = None, |
| | | spk_labels_lengths: torch.Tensor = None, |
| | | binary_labels: torch.Tensor = None, |
| | | binary_labels_lengths: torch.Tensor = None, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss |
| | | |
| | | Args: |
| | | speech: (Batch, samples) |
| | | speech: (Batch, samples) or (Batch, frames, input_size) |
| | | speech_lengths: (Batch,) default None for chunk interator, |
| | | because the chunk-iterator does not |
| | | have the speech_lengths returned. |
| | |
| | | espnet2/iterators/chunk_iter_factory.py |
| | | profile: (Batch, N_spk, dim) |
| | | profile_lengths: (Batch,) |
| | | spk_labels: (Batch, ) |
| | | binary_labels: (Batch, frames, max_spk_num) |
| | | binary_labels_lengths: (Batch,) |
| | | """ |
| | | assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape) |
| | | assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape) |
| | | batch_size = speech.shape[0] |
| | | self.forward_steps = self.forward_steps + 1 |
| | | # 1. Network forward |
| | | pred, inter_outputs = self.prediction_forward( |
| | | speech, speech_lengths, |
| | | profile, profile_lengths, |
| | | return_inter_outputs=True |
| | | ) |
| | | (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | if self.attractor is None: |
| | | # 2a. Decoder (baiscally a predction layer after encoder_out) |
| | | pred = self.decoder(encoder_out, encoder_out_lens) |
| | | else: |
| | | # 2b. Encoder Decoder Attractors |
| | | # Shuffle the chronological order of encoder_out, then calculate attractor |
| | | encoder_out_shuffled = encoder_out.clone() |
| | | for i in range(len(encoder_out_lens)): |
| | | encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[ |
| | | i, torch.randperm(encoder_out_lens[i]), : |
| | | ] |
| | | attractor, att_prob = self.attractor( |
| | | encoder_out_shuffled, |
| | | encoder_out_lens, |
| | | to_device( |
| | | self, |
| | | torch.zeros( |
| | | encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2) |
| | | ), |
| | | ), |
| | | ) |
| | | # Remove the final attractor which does not correspond to a speaker |
| | | # Then multiply the attractors and encoder_out |
| | | pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1)) |
| | | # 3. Aggregate time-domain labels |
| | | # 2. Aggregate time-domain labels to match forward outputs |
| | | if self.label_aggregator is not None: |
| | | spk_labels, spk_labels_lengths = self.label_aggregator( |
| | | spk_labels, spk_labels_lengths |
| | | binary_labels, binary_labels_lengths = self.label_aggregator( |
| | | binary_labels, binary_labels_lengths |
| | | ) |
| | | # 2. Calculate power-set encoding (PSE) labels |
| | | raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True) |
| | | pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2) |
| | | |
| | | # If encoder uses conv* as input_layer (i.e., subsampling), |
| | | # the sequence length of 'pred' might be slighly less than the |
| | | # the sequence length of 'pred' might be slightly less than the |
| | | # length of 'spk_labels'. Here we force them to be equal. |
| | | length_diff_tolerance = 2 |
| | | length_diff = spk_labels.shape[1] - pred.shape[1] |
| | | if length_diff > 0 and length_diff <= length_diff_tolerance: |
| | | spk_labels = spk_labels[:, 0 : pred.shape[1], :] |
| | | length_diff = pse_labels.shape[1] - pred.shape[1] |
| | | if 0 < length_diff <= length_diff_tolerance: |
| | | pse_labels = pse_labels[:, 0: pred.shape[1]] |
| | | |
| | | if self.attractor is None: |
| | | loss_pit, loss_att = None, None |
| | | loss, perm_idx, perm_list, label_perm = self.pit_loss( |
| | | pred, spk_labels, encoder_out_lens |
| | | ) |
| | | else: |
| | | loss_pit, perm_idx, perm_list, label_perm = self.pit_loss( |
| | | pred, spk_labels, encoder_out_lens |
| | | ) |
| | | loss_att = self.attractor_loss(att_prob, spk_labels) |
| | | loss = loss_pit + self.attractor_weight * loss_att |
| | | loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths) |
| | | loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths) |
| | | loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths) |
| | | label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device) |
| | | loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis |
| | | + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd)) |
| | | |
| | | ( |
| | | correct, |
| | | num_frames, |
| | |
| | | speaker_miss, |
| | | speaker_falarm, |
| | | speaker_error, |
| | | ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens) |
| | | ) = self.calc_diarization_error( |
| | | pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding), |
| | | label=F.embedding(pse_labels * (~label_mask), self.pse_embedding), |
| | | length=binary_labels_lengths |
| | | ) |
| | | |
| | | if speech_scored > 0 and num_frames > 0: |
| | | sad_mr, sad_fr, mi, fa, cf, acc, der = ( |
| | |
| | | |
| | | stats = dict( |
| | | loss=loss.detach(), |
| | | loss_att=loss_att.detach() if loss_att is not None else None, |
| | | loss_pit=loss_pit.detach() if loss_pit is not None else None, |
| | | loss_diar=loss_diar.detach() if loss_diar is not None else None, |
| | | loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None, |
| | | loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None, |
| | | loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None, |
| | | sad_mr=sad_mr, |
| | | sad_fr=sad_fr, |
| | | mi=mi, |
| | |
| | | cf=cf, |
| | | acc=acc, |
| | | der=der, |
| | | forward_steps=self.forward_steps, |
| | | ) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | def classification_loss( |
| | | self, |
| | | predictions: torch.Tensor, |
| | | labels: torch.Tensor, |
| | | prediction_lengths: torch.Tensor |
| | | ) -> torch.Tensor: |
| | | mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1]) |
| | | pad_labels = labels.masked_fill( |
| | | mask.to(predictions.device), |
| | | value=self.ignore_id |
| | | ) |
| | | loss = self.criterion_diar(predictions.contiguous(), pad_labels) |
| | | |
| | | return loss |
| | | |
| | | def speaker_discrimination_loss( |
| | | self, |
| | | profile: torch.Tensor, |
| | | profile_lengths: torch.Tensor |
| | | ) -> torch.Tensor: |
| | | profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float() # (B, N, 1) |
| | | mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2)) # (B, N, N) |
| | | mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask)) |
| | | |
| | | eps = 1e-12 |
| | | coding_norm = torch.linalg.norm( |
| | | profile * profile_mask + (1 - profile_mask) * eps, |
| | | dim=2, keepdim=True |
| | | ) * profile_mask |
| | | # profile: Batch, N, dim |
| | | cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask |
| | | cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps) |
| | | loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum() |
| | | |
| | | return loss |
| | | |
| | | def calculate_multi_labels(self, pse_labels, pse_labels_lengths): |
| | | mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1]) |
| | | padding_labels = pse_labels.masked_fill( |
| | | mask.to(pse_labels.device), |
| | | value=0 |
| | | ).to(pse_labels) |
| | | multi_labels = F.embedding(padding_labels, self.pse_embedding) |
| | | |
| | | return multi_labels |
| | | |
| | | def internal_score_loss( |
| | | self, |
| | | cd_score: torch.Tensor, |
| | | ci_score: torch.Tensor, |
| | | pse_labels: torch.Tensor, |
| | | pse_labels_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths) |
| | | ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths) |
| | | cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths) |
| | | return ci_loss, cd_loss |
| | | |
| | | def collect_feats( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | spk_labels: torch.Tensor = None, |
| | | spk_labels_lengths: torch.Tensor = None, |
| | | profile: torch.Tensor = None, |
| | | profile_lengths: torch.Tensor = None, |
| | | binary_labels: torch.Tensor = None, |
| | | binary_labels_lengths: torch.Tensor = None, |
| | | ) -> Dict[str, torch.Tensor]: |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | return {"feats": feats, "feats_lengths": feats_lengths} |
| | |
| | | speaker_encoder_outputs: torch.Tensor, |
| | | seq_len: torch.Tensor = None, |
| | | spk_len: torch.Tensor = None, |
| | | ) -> torch.Tensor: |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1] |
| | | d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2] |
| | | if self.normalize_speech_speaker: |
| | |
| | | ci_simi = self.ci_scorer(ge_in, ge_len)[0] |
| | | else: |
| | | ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs) |
| | | simi = torch.cat([cd_simi, ci_simi], dim=2) |
| | | |
| | | return simi |
| | | return ci_simi, cd_simi |
| | | |
| | | def post_net_forward(self, simi, seq_len): |
| | | logits = self.decoder(simi, seq_len)[0] |
| | |
| | | speech_lengths: torch.Tensor, |
| | | profile: torch.Tensor, |
| | | profile_lengths: torch.Tensor, |
| | | ) -> torch.Tensor: |
| | | return_inter_outputs: bool = False, |
| | | ) -> [torch.Tensor, Optional[list]]: |
| | | # speech encoding |
| | | speech, speech_lengths = self.encode_speech(speech, speech_lengths) |
| | | # speaker encoding |
| | | profile, profile_lengths = self.encode_speaker(profile, profile_lengths) |
| | | # calculating similarity |
| | | similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths) |
| | | ci_simi, cd_simi = self.calc_similarity(speech, profile, speech_lengths, profile_lengths) |
| | | similarity = torch.cat([cd_simi, ci_simi], dim=2) |
| | | # post net forward |
| | | logits = self.post_net_forward(similarity, speech_lengths) |
| | | |
| | | if return_inter_outputs: |
| | | return logits, [(speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi)] |
| | | return logits |
| | | |
| | | def encode( |
| | |
| | | # 4. Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim) |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | encoder_outputs = self.encoder(feats, feats_lengths) |
| | | encoder_out, encoder_out_lens = encoder_outputs[:2] |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | |
| | | |
| | | (batch_size, max_len, num_output) = label.size() |
| | | # mask the padding part |
| | | mask = np.zeros((batch_size, max_len, num_output)) |
| | | for i in range(batch_size): |
| | | mask[i, : length[i], :] = 1 |
| | | mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy() |
| | | |
| | | # pred and label have the shape (batch_size, max_len, num_output) |
| | | label_np = label.data.cpu().numpy().astype(int) |
| New file |
| | |
| | | import math |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | |
| | | |
| | | class _BatchNorm1d(nn.Module): |
| | | def __init__( |
| | | self, |
| | | input_shape=None, |
| | | input_size=None, |
| | | eps=1e-05, |
| | | momentum=0.1, |
| | | affine=True, |
| | | track_running_stats=True, |
| | | combine_batch_time=False, |
| | | skip_transpose=False, |
| | | ): |
| | | super().__init__() |
| | | self.combine_batch_time = combine_batch_time |
| | | self.skip_transpose = skip_transpose |
| | | |
| | | if input_size is None and skip_transpose: |
| | | input_size = input_shape[1] |
| | | elif input_size is None: |
| | | input_size = input_shape[-1] |
| | | |
| | | self.norm = nn.BatchNorm1d( |
| | | input_size, |
| | | eps=eps, |
| | | momentum=momentum, |
| | | affine=affine, |
| | | track_running_stats=track_running_stats, |
| | | ) |
| | | |
| | | def forward(self, x): |
| | | shape_or = x.shape |
| | | if self.combine_batch_time: |
| | | if x.ndim == 3: |
| | | x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) |
| | | else: |
| | | x = x.reshape( |
| | | shape_or[0] * shape_or[1], shape_or[3], shape_or[2] |
| | | ) |
| | | |
| | | elif not self.skip_transpose: |
| | | x = x.transpose(-1, 1) |
| | | |
| | | x_n = self.norm(x) |
| | | |
| | | if self.combine_batch_time: |
| | | x_n = x_n.reshape(shape_or) |
| | | elif not self.skip_transpose: |
| | | x_n = x_n.transpose(1, -1) |
| | | |
| | | return x_n |
| | | |
| | | |
| | | class _Conv1d(nn.Module): |
| | | def __init__( |
| | | self, |
| | | out_channels, |
| | | kernel_size, |
| | | input_shape=None, |
| | | in_channels=None, |
| | | stride=1, |
| | | dilation=1, |
| | | padding="same", |
| | | groups=1, |
| | | bias=True, |
| | | padding_mode="reflect", |
| | | skip_transpose=False, |
| | | ): |
| | | super().__init__() |
| | | self.kernel_size = kernel_size |
| | | self.stride = stride |
| | | self.dilation = dilation |
| | | self.padding = padding |
| | | self.padding_mode = padding_mode |
| | | self.unsqueeze = False |
| | | self.skip_transpose = skip_transpose |
| | | |
| | | if input_shape is None and in_channels is None: |
| | | raise ValueError("Must provide one of input_shape or in_channels") |
| | | |
| | | if in_channels is None: |
| | | in_channels = self._check_input_shape(input_shape) |
| | | |
| | | self.conv = nn.Conv1d( |
| | | in_channels, |
| | | out_channels, |
| | | self.kernel_size, |
| | | stride=self.stride, |
| | | dilation=self.dilation, |
| | | padding=0, |
| | | groups=groups, |
| | | bias=bias, |
| | | ) |
| | | |
| | | def forward(self, x): |
| | | if not self.skip_transpose: |
| | | x = x.transpose(1, -1) |
| | | |
| | | if self.unsqueeze: |
| | | x = x.unsqueeze(1) |
| | | |
| | | if self.padding == "same": |
| | | x = self._manage_padding( |
| | | x, self.kernel_size, self.dilation, self.stride |
| | | ) |
| | | |
| | | elif self.padding == "causal": |
| | | num_pad = (self.kernel_size - 1) * self.dilation |
| | | x = F.pad(x, (num_pad, 0)) |
| | | |
| | | elif self.padding == "valid": |
| | | pass |
| | | |
| | | else: |
| | | raise ValueError( |
| | | "Padding must be 'same', 'valid' or 'causal'. Got " |
| | | + self.padding |
| | | ) |
| | | |
| | | wx = self.conv(x) |
| | | |
| | | if self.unsqueeze: |
| | | wx = wx.squeeze(1) |
| | | |
| | | if not self.skip_transpose: |
| | | wx = wx.transpose(1, -1) |
| | | |
| | | return wx |
| | | |
| | | def _manage_padding( |
| | | self, x, kernel_size: int, dilation: int, stride: int, |
| | | ): |
| | | # Detecting input shape |
| | | L_in = x.shape[-1] |
| | | |
| | | # Time padding |
| | | padding = get_padding_elem(L_in, stride, kernel_size, dilation) |
| | | |
| | | # Applying padding |
| | | x = F.pad(x, padding, mode=self.padding_mode) |
| | | |
| | | return x |
| | | |
| | | def _check_input_shape(self, shape): |
| | | """Checks the input shape and returns the number of input channels. |
| | | """ |
| | | |
| | | if len(shape) == 2: |
| | | self.unsqueeze = True |
| | | in_channels = 1 |
| | | elif self.skip_transpose: |
| | | in_channels = shape[1] |
| | | elif len(shape) == 3: |
| | | in_channels = shape[2] |
| | | else: |
| | | raise ValueError( |
| | | "conv1d expects 2d, 3d inputs. Got " + str(len(shape)) |
| | | ) |
| | | |
| | | # Kernel size must be odd |
| | | if self.kernel_size % 2 == 0: |
| | | raise ValueError( |
| | | "The field kernel size must be an odd number. Got %s." |
| | | % (self.kernel_size) |
| | | ) |
| | | return in_channels |
| | | |
| | | |
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Return the ``[left, right]`` padding amounts for a 1-D convolution.

    For ``stride > 1`` both sides get ``kernel_size // 2``. Otherwise both
    sides get half of the length the un-padded convolution would lose.

    Args:
        L_in: input length along the convolved dimension.
        stride: convolution stride.
        kernel_size: convolution kernel size.
        dilation: convolution dilation.
    """
    if stride > 1:
        # L_in and dilation do not influence the strided case.
        return [kernel_size // 2, kernel_size // 2]

    L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
    half = (L_in - L_out) // 2
    return [half, half]
| | | |
| | | |
| | | # Skip transpose as much as possible for efficiency |
class Conv1d(_Conv1d):
    """``_Conv1d`` with ``skip_transpose`` always enabled."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, skip_transpose=True, **kwargs)
| | | |
| | | |
class BatchNorm1d(_BatchNorm1d):
    """``_BatchNorm1d`` with ``skip_transpose`` always enabled."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, skip_transpose=True, **kwargs)
| | | |
| | | |
def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Turn a 1-D tensor of lengths into a (len(length), max_len) mask.

    Entry ``[b, t]`` is non-zero exactly when ``t < length[b]``.

    Args:
        length: 1-D tensor of lengths.
        max_len: mask width; defaults to the largest value in ``length``.
        dtype: dtype of the returned mask; defaults to ``length.dtype``.
        device: device of the returned mask; defaults to ``length.device``.
    """
    assert length.dim() == 1

    if max_len is None:
        max_len = length.max().long().item()

    # Compare every position index with each sequence's length.
    steps = torch.arange(max_len, device=length.device, dtype=length.dtype)
    inside = steps.expand(len(length), max_len) < length.unsqueeze(1)

    out_dtype = length.dtype if dtype is None else dtype
    out_device = length.device if device is None else device
    return torch.as_tensor(inside, dtype=out_dtype, device=out_device)
| | | |
| | | |
class TDNNBlock(nn.Module):
    """Convolution, then activation, then batch normalization.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the convolution.
        dilation: dilation of the convolution.
        activation: class (or factory) called with no arguments to build the
            activation layer.
        groups: ``groups`` argument of the convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super().__init__()
        conv_options = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.conv = Conv1d(**conv_options)
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        """Apply conv -> activation -> norm to ``x``."""
        convolved = self.conv(x)
        activated = self.activation(convolved)
        return self.norm(activated)
| | | |
| | | |
class Res2NetBlock(torch.nn.Module):
    """Res2Net block with dilation.

    The input is split into ``scale`` groups along dimension 1. The first
    group is passed through unchanged; each later group goes through its own
    ``TDNNBlock``, and from the third group on the previous group's output is
    added to the input first.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block.
    kernel_size: int
        The kernel size of the Res2Net block.
    dilation : int
        The dilation of the Res2Net block.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
    ):
        super().__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        group_in = in_channels // scale
        group_out = out_channels // scale

        # One TDNN per group except the first, which is an identity path.
        branches = []
        for _ in range(scale - 1):
            branches.append(
                TDNNBlock(
                    group_in,
                    group_out,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
            )
        self.blocks = nn.ModuleList(branches)
        self.scale = scale

    def forward(self, x):
        """Process each channel group and concatenate the results on dim 1."""
        groups = torch.chunk(x, self.scale, dim=1)
        outputs = [groups[0]]
        for position, (branch, group) in enumerate(zip(self.blocks, groups[1:])):
            # The first processed group has no predecessor output to add.
            branch_in = group if position == 0 else group + outputs[-1]
            outputs.append(branch(branch_in))
        return torch.cat(outputs, dim=1)
| | | |
| | | |
class SEBlock(nn.Module):
    """Squeeze-and-excitation block.

    The input is averaged over its last dimension (optionally restricted by a
    length mask), passed through two kernel-size-1 convolutions with ReLU and
    sigmoid, and the result rescales the input.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of output channels after squeeze.
    out_channels : int
        The number of output channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> se_layer = SEBlock(64, 16, 64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1
        )
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        """Rescale ``x`` by gates computed from its (masked) mean over dim 2.

        ``lengths`` is multiplied by the size of the last dimension before the
        mask is built.
        """
        n_frames = x.shape[-1]
        if lengths is None:
            squeezed = x.mean(dim=2, keepdim=True)
        else:
            valid = length_to_mask(lengths * n_frames, max_len=n_frames, device=x.device)
            valid = valid.unsqueeze(1)
            n_valid = valid.sum(dim=2, keepdim=True)
            squeezed = (x * valid).sum(dim=2, keepdim=True) / n_valid

        hidden = self.relu(self.conv1(squeezed))
        gate = self.sigmoid(self.conv2(hidden))

        return gate * x
| | | |
| | | |
class AttentiveStatisticsPooling(nn.Module):
    """Attentive statistics pooling, computed per channel.

    Returns the attention-weighted mean and standard deviation of the input,
    concatenated along the channel dimension.

    Arguments
    ---------
    channels: int
        The number of input channels.
    attention_channels: int
        The number of attention channels.
    global_context: bool
        If True, the masked mean and std of the whole input are appended to
        every frame before the attention scores are computed.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> asp_layer = AttentiveStatisticsPooling(64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 1, 128])
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        # With global context the attention input is [x, mean, std].
        tdnn_in = channels * 3 if global_context else channels
        self.tdnn = TDNNBlock(tdnn_in, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels, out_channels=channels, kernel_size=1
        )

    @staticmethod
    def _weighted_stats(x, weights, eps, dim=2):
        """Return the weighted mean and std of ``x`` along ``dim``."""
        mean = (weights * x).sum(dim)
        variance = (weights * (x - mean.unsqueeze(dim)).pow(2)).sum(dim)
        return mean, torch.sqrt(variance.clamp(eps))

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        lengths : torch.Tensor, optional
            Multiplied by L before the mask is built; all ones when omitted.
        """
        n_frames = x.shape[-1]

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * n_frames, max_len=n_frames, device=x.device)
        mask = mask.unsqueeze(1)

        if self.global_context:
            # Manual statistics: torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            n_valid = mask.sum(dim=2, keepdim=True).float()
            g_mean, g_std = self._weighted_stats(x, mask / n_valid, self.eps)
            g_mean = g_mean.unsqueeze(2).repeat(1, 1, n_frames)
            g_std = g_std.unsqueeze(2).repeat(1, 1, n_frames)
            attn_in = torch.cat([x, g_mean, g_std], dim=1)
        else:
            attn_in = x

        scores = self.conv(self.tanh(self.tdnn(attn_in)))

        # Masked-out frames get zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=2)

        mean, std = self._weighted_stats(x, weights, self.eps)
        return torch.cat((mean, std), dim=1).unsqueeze(2)
| | | |
| | | |
class SERes2NetBlock(nn.Module):
    """Building block of ECAPA-TDNN: TDNN-Res2Net-TDNN-SEBlock with a residual.

    Arguments
    ----------
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    res2net_scale: int
        The scale of the Res2Net block.
    se_channels: int
        The number of channels after the squeeze step of the SE block.
    kernel_size: int
        The kernel size of the Res2Net block.
    dilation: int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers.
    groups: int
        Number of blocked connections from input channels to output channels.

    Example
    -------
    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
    >>> out = conv(x).transpose(1, 2)
    >>> out.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels

        # Both surrounding TDNN layers are point-wise (kernel 1, dilation 1).
        pointwise = dict(
            kernel_size=1, dilation=1, activation=activation, groups=groups
        )
        self.tdnn1 = TDNNBlock(in_channels, out_channels, **pointwise)
        self.res2net_block = Res2NetBlock(
            out_channels, out_channels, res2net_scale, kernel_size, dilation
        )
        self.tdnn2 = TDNNBlock(out_channels, out_channels, **pointwise)
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        # A projection is only needed when the residual changes channel count.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        """Run the block on ``x`` and add the (possibly projected) input."""
        residual = x if self.shortcut is None else self.shortcut(x)

        out = self.tdnn1(x)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out, lengths)

        return out + residual
| | | |
| | | |
| | | class ECAPA_TDNN(torch.nn.Module): |
| | | """An implementation of the speaker embedding model in a paper. |
| | | "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in |
| | | TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143). |
| | | |
| | | Arguments |
| | | --------- |
| | | activation : torch class |
| | | A class for constructing the activation layers. |
| | | channels : list of ints |
| | | Output channels for TDNN/SERes2Net layer. |
| | | kernel_sizes : list of ints |
| | | List of kernel sizes for each layer. |
| | | dilations : list of ints |
| | | List of dilations for kernels in each layer. |
| | | lin_neurons : int |
| | | Number of neurons in linear layers. |
| | | groups : list of ints |
| | | List of groups for kernels in each layer. |
| | | |
| | | Example |
| | | ------- |
| | | >>> input_feats = torch.rand([5, 120, 80]) |
| | | >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192) |
| | | >>> outputs = compute_embedding(input_feats) |
| | | >>> outputs.shape |
| | | torch.Size([5, 1, 192]) |
| | | """ |
| | | |
def __init__(
    self,
    input_size,
    lin_neurons=192,
    activation=torch.nn.ReLU,
    channels=[512, 512, 512, 512, 1536],
    kernel_sizes=[5, 3, 3, 3, 1],
    dilations=[1, 2, 3, 4, 1],
    attention_channels=128,
    res2net_scale=8,
    se_channels=128,
    global_context=True,
    groups=[1, 1, 1, 1, 1],
    window_size=20,
    window_shift=1,
):
    """Assemble the encoder: TDNN front layer, SE-Res2Net stack, MFA, pooling, projection.

    Arguments
    ---------
    input_size : int
        First argument of the initial TDNNBlock (presumably the input
        feature dimension -- confirm against TDNNBlock).
    lin_neurons : int
        ``out_channels`` of the final kernel-size-1 ``Conv1d`` (``self.fc``).
    activation : torch class
        Forwarded to every TDNNBlock / SERes2NetBlock.
    channels : list of ints
        Per-layer sizes. ``channels[0]`` feeds the first TDNN block,
        ``channels[1:-1]`` the SE-Res2Net blocks, and ``channels[-1]`` is used
        as both input and output size of the MFA block.
        NOTE(review): ``forward`` feeds the MFA block the concatenation of
        all SE-Res2Net outputs, so ``channels[-1]`` presumably has to equal
        ``sum(channels[1:-1])`` -- confirm.
    kernel_sizes : list of ints
        Per-layer kernel sizes; must have the same length as ``channels``.
    dilations : list of ints
        Per-layer dilations; must have the same length as ``channels``.
    attention_channels : int
        Forwarded to AttentiveStatisticsPooling.
    res2net_scale : int
        Forwarded to every SERes2NetBlock.
    se_channels : int
        Forwarded to every SERes2NetBlock.
    global_context : bool
        Forwarded to AttentiveStatisticsPooling.
    groups : list of ints
        ``groups[i]`` is forwarded to layer ``i``; ``groups[-1]`` to the MFA
        block.
    window_size : int or None
        Window length used by ``windowed_pooling``. ``forward`` pools over
        the whole input instead when this is ``None``.
    window_shift : int
        Hop between consecutive pooling windows.
    """
    super().__init__()
    assert len(channels) == len(kernel_sizes)
    assert len(channels) == len(dilations)
    # Keep a private copy: the default for ``channels`` is a list object
    # shared by every call, so storing it directly would let an in-place edit
    # of ``model.channels`` leak into all later constructions.
    self.channels = list(channels)
    self.blocks = nn.ModuleList()
    self.window_size = window_size
    self.window_shift = window_shift

    # The initial TDNN layer
    self.blocks.append(
        TDNNBlock(
            input_size,
            channels[0],
            kernel_sizes[0],
            dilations[0],
            activation,
            groups[0],
        )
    )

    # SE-Res2Net layers (every layer except the first and the last)
    for i in range(1, len(channels) - 1):
        self.blocks.append(
            SERes2NetBlock(
                channels[i - 1],
                channels[i],
                res2net_scale=res2net_scale,
                se_channels=se_channels,
                kernel_size=kernel_sizes[i],
                dilation=dilations[i],
                activation=activation,
                groups=groups[i],
            )
        )

    # Multi-layer feature aggregation
    self.mfa = TDNNBlock(
        channels[-1],
        channels[-1],
        kernel_sizes[-1],
        dilations[-1],
        activation,
        groups=groups[-1],
    )

    # Attentive Statistical Pooling
    self.asp = AttentiveStatisticsPooling(
        channels[-1],
        attention_channels=attention_channels,
        global_context=global_context,
    )
    # Normalisation and projection both take ``channels[-1] * 2`` inputs.
    self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

    # Final linear transformation
    self.fc = Conv1d(
        in_channels=channels[-1] * 2,
        out_channels=lin_neurons,
        kernel_size=1,
    )
| | | |
def windowed_pooling(self, x, lengths=None):
    """Pool ``x`` over sliding windows along the time axis.

    The time axis is reflect-padded by ``window_size // 2`` on both sides,
    then ``ceil(T / window_shift)`` windows of ``window_size`` frames are
    each passed through ``self.asp``, ``self.asp_bn`` and ``self.fc``. The
    per-window results are concatenated along dim 2.

    Arguments
    ---------
    x : torch.Tensor
        Tensor of shape (batch, channel, time).
    lengths : torch.Tensor or None
        Forwarded to ``self.asp`` after shifting and clamping per window;
        ``None`` is forwarded unchanged.

    Returns
    -------
    torch.Tensor
        Per-window outputs concatenated on dim 2 (presumably
        (batch, lin_neurons, num_windows), one step per window -- confirm
        against AttentiveStatisticsPooling).
    """
    # x: Batch, Channel, Time
    tt = x.shape[2]
    num_chunk = int(math.ceil(tt / self.window_shift))
    pad = self.window_size // 2
    # Keep the padded input under its own name. Every window must be sliced
    # from this tensor; reusing ``x`` for the pooled result would make the
    # second and later iterations slice the previous window's output.
    padded = F.pad(x, (pad, pad, 0, 0), "reflect")
    stat_list = []

    for i in range(num_chunk):
        st = i * self.window_shift
        ed = st + self.window_size
        # NOTE(review): ``lengths - i`` ignores ``window_shift`` and the
        # padding offset; it only lines up with ``st`` when window_shift == 1.
        # Left unchanged -- confirm the intended semantics of ``lengths``.
        win_lengths = (
            torch.clamp(lengths - i, 0, self.window_size)
            if lengths is not None
            else None
        )
        stat = self.asp(padded[:, :, st:ed], lengths=win_lengths)
        stat = self.asp_bn(stat)
        stat = self.fc(stat)
        stat_list.append(stat)

    return torch.cat(stat_list, dim=2)
| | | |
def forward(self, x, lengths=None):
    """Compute embeddings for a batch of feature sequences.

    Arguments
    ---------
    x : torch.Tensor
        Tensor of shape (batch, time, channel).
    lengths : torch.Tensor
        Tensor of shape (batch, )

    Returns
    -------
    torch.Tensor
        (batch, channel) when ``self.window_size`` is None, i.e. a single
        pooled embedding per input; otherwise (batch, time, channel) from
        the windowed pooling path.
    """
    # Work channel-first throughout so only one transpose is needed up front.
    feats = x.transpose(1, 2)

    layer_outputs = []
    for block in self.blocks:
        # Blocks that do not take ``lengths`` raise TypeError on the keyword;
        # fall back to the plain call for those.
        try:
            feats = block(feats, lengths=lengths)
        except TypeError:
            feats = block(feats)
        layer_outputs.append(feats)

    # Multi-layer feature aggregation: concatenate every block output except
    # the first along the channel axis.
    aggregated = self.mfa(torch.cat(layer_outputs[1:], dim=1))

    if self.window_size is not None:
        windowed = self.windowed_pooling(aggregated, lengths)
        return windowed.transpose(1, 2)  # -> B, T, C

    # Attentive Statistical Pooling over the whole input
    pooled = self.asp(aggregated, lengths=lengths)
    pooled = self.asp_bn(pooled)
    # Final linear transformation
    projected = self.fc(pooled)
    return projected.squeeze(2)  # -> B, C
| | |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | | from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder |
| | | from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder |
| | | from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder |
| | |
| | | resnet34=ResNet34Diar, |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | ecapa_tdnn=ECAPA_TDNN, |
| | | ), |
| | | type_check=AbsEncoder, |
| | | type_check=torch.nn.Module, |
| | | default="resnet34", |
| | | ) |
| | | speaker_encoder_choices = ClassChoices( |
| | |
| | | specaug_choices, |
| | | # --normalize and --normalize_conf |
| | | normalize_choices, |
| | | # --label_aggregator and --label_aggregator_conf |
| | | label_aggregator_choices, |
| | | # --model and --model_conf |
| | | model_choices, |
| | | # --encoder and --encoder_conf |
| | |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | if not inference: |
| | | retval = ("speech", "profile", "label") |
| | | retval = ("speech", "profile", "binary_labels") |
| | | else: |
| | | # Recognition mode |
| | | retval = ("speech", "profile") |