From 8cc5bbf99a59694228aafcbe8712e09b9a4cb26b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 27 二月 2023 17:01:48 +0800
Subject: [PATCH] Merge pull request #159 from alibaba-damo-academy/dev_dzh

---
 egs/mars/sd/path.sh                                                        |    5 
 funasr/models/e2e_diar_sond.py                                             |  219 +++-
 egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py        |   60 +
 egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py   |   53 +
 funasr/bin/sond_inference.py                                               |   17 
 egs/mars/sd/local_run.sh                                                   |  171 ++++
 funasr/bin/diar_train.py                                                   |   46 +
 funasr/tasks/diar.py                                                       |    8 
 funasr/bin/diar_inference_launch.py                                        |   10 
 egs/mars/sd/scripts/dump_rttm_to_labels.py                                 |  140 +++
 funasr/losses/label_smoothing_loss.py                                      |   18 
 egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml |  121 ++
 egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py    |  110 ++
 egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py       |   73 +
 egs/mars/sd/scripts/calculate_shapes.py                                    |   45 +
 funasr/layers/label_aggregation.py                                         |    2 
 egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py                      |  115 ++
 egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py       |   57 +
 egs/mars/sd/scripts/simu_chunk_with_labels.py                              |  261 ++++++
 funasr/models/encoder/ecapa_tdnn_encoder.py                                |  686 ++++++++++++++++
 egs/alimeeting/diarization/sond/unit_test_modelscope.py                    |   92 ++
 egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py       |  138 +++
 22 files changed, 2,367 insertions(+), 80 deletions(-)

diff --git a/egs/alimeeting/diarization/sond/unit_test_modelscope.py b/egs/alimeeting/diarization/sond/unit_test_modelscope.py
new file mode 100644
index 0000000..ea543e1
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/unit_test_modelscope.py
@@ -0,0 +1,92 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+import os
+
+
+def test_wav_cpu_infer():
+    output_dir = "./outputs"
+    data_path_and_name_and_type = [
+        "data/unit_test/test_wav.scp,speech,sound",
+        "data/unit_test/test_profile.scp,profile,kaldi_ark",
+    ]
+    diar_pipeline = pipeline(
+        task=Tasks.speaker_diarization,
+        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
+        mode="sond",
+        output_dir=output_dir,
+        num_workers=0,
+        log_level="WARNING",
+    )
+    results = diar_pipeline(data_path_and_name_and_type)
+    print(results)
+
+
+def test_wav_gpu_infer():
+    output_dir = "./outputs"
+    data_path_and_name_and_type = [
+        "data/unit_test/test_wav.scp,speech,sound",
+        "data/unit_test/test_profile.scp,profile,kaldi_ark",
+    ]
+    diar_pipeline = pipeline(
+        task=Tasks.speaker_diarization,
+        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
+        mode="sond",
+        output_dir=output_dir,
+        num_workers=0,
+        log_level="WARNING",
+    )
+    results = diar_pipeline(data_path_and_name_and_type)
+    print(results)
+
+
+def test_without_profile_gpu_infer():
+    raw_inputs = [
+        "data/unit_test/raw_inputs/record.wav",
+        "data/unit_test/raw_inputs/spk1.wav",
+        "data/unit_test/raw_inputs/spk2.wav",
+        "data/unit_test/raw_inputs/spk3.wav",
+        "data/unit_test/raw_inputs/spk4.wav"
+    ]
+    diar_pipeline = pipeline(
+        task=Tasks.speaker_diarization,
+        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
+        mode="sond_demo",
+        sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
+        sv_model_revision="master",
+        num_workers=0,
+        log_level="WARNING",
+        param_dict={},
+    )
+    results = diar_pipeline(raw_inputs)
+    print(results)
+
+
+def test_url_without_profile_gpu_infer():
+    raw_inputs = [
+        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav",
+        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav",
+        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk2.wav",
+        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk3.wav",
+        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk4.wav",
+    ]
+    diar_pipeline = pipeline(
+        task=Tasks.speaker_diarization,
+        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
+        mode="sond_demo",
+        sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
+        sv_model_revision="master",
+        num_workers=0,
+        log_level="WARNING",
+        param_dict={},
+    )
+    results = diar_pipeline(raw_inputs)
+    print(results)
+
+
+if __name__ == '__main__':
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    test_wav_cpu_infer()
+    test_wav_gpu_infer()
+    test_without_profile_gpu_infer()
+    test_url_without_profile_gpu_infer()
diff --git a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
new file mode 100644
index 0000000..459a741
--- /dev/null
+++ b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
@@ -0,0 +1,121 @@
+model: sond
+model_conf:
+    lsm_weight: 0.0
+    length_normalized_loss: true
+    max_spk_num: 16
+
+# speech encoder
+encoder: ecapa_tdnn
+encoder_conf:
+    # pass by model, equal to feature dim
+    # input_size: 80
+    pool_size: 20
+    stride: 1
+speaker_encoder: conv
+speaker_encoder_conf:
+    input_units: 256
+    num_layers: 3
+    num_units: 256
+    kernel_size: 1
+    dropout_rate: 0.0
+    position_encoder: null
+    out_units: 256
+    out_norm: false
+    auxiliary_states: false
+    tf2torch_tensor_name_prefix_torch: speaker_encoder
+    tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
+ci_scorer: dot
+ci_scorer_conf: {}
+cd_scorer: san
+cd_scorer_conf:
+    input_size: 512
+    output_size: 512
+    out_units: 1
+    attention_heads: 4
+    linear_units: 1024
+    num_blocks: 4
+    dropout_rate: 0.0
+    positional_dropout_rate: 0.0
+    attention_dropout_rate: 0.0
+    # use string "null" to remove input layer
+    input_layer: "null"
+    pos_enc_class: null
+    normalize_before: true
+    tf2torch_tensor_name_prefix_torch: cd_scorer
+    tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
+# post net
+decoder: fsmn
+decoder_conf:
+    in_units: 32
+    out_units: 2517
+    filter_size: 31
+    fsmn_num_layers: 6
+    dnn_num_layers: 1
+    num_memory_units: 512
+    ffn_inner_dim: 512
+    dropout_rate: 0.0
+    tf2torch_tensor_name_prefix_torch: decoder
+    tf2torch_tensor_name_prefix_tf: EAND/post_net
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: povey
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    filter_length_min: -1
+    filter_length_max: -1
+    lfr_m: 1
+    lfr_n: 1
+    dither: 0.0
+    snip_edges: false
+
+# minibatch related
+batch_type: length
+# 16s * 16k * 16 samples
+batch_bins: 4096000
+num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 50
+val_scheduler_criterion:
+    - valid
+    - acc
+best_model_criterion:
+-   - valid
+    - der
+    - min
+-   - valid
+    - forward_steps
+    - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 10000
+
+# without spec aug
+specaug: null
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+log_interval: 50
+# without normalize
+normalize: None
diff --git a/egs/mars/sd/local_run.sh b/egs/mars/sd/local_run.sh
new file mode 100755
index 0000000..3b319f4
--- /dev/null
+++ b/egs/mars/sd/local_run.sh
@@ -0,0 +1,171 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="6,7"
+gpu_num=2
+count=1
+gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="." #feature output dictionary
+exp_dir="."
+lang=zh
+dumpdir=dump/raw
+feats_type=raw
+token_type=char
+scp=wav.scp
+type=kaldi_ark
+stage=3
+stop_stage=4
+
+# feature configuration
+feats_dim=
+sample_frequency=16000
+nj=32
+speed_perturb=
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_asr_conformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer.yaml
+inference_asr_model=valid.acc.ave_10best.pth
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$[${ngpu}*${njob}]
+    _ngpu=1
+else
+    inference_nj=$njob
+    _ngpu=0
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
+feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
+
+# Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: Training"
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi 
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+            asr_train.py \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --token_type char \
+                --token_list $token_list \
+                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
+                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
+                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
+                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
+                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
+                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
+                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
+                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --input_size $feats_dim \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --multiprocessing_distributed true \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+        done
+        wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    echo "stage 4: Inference"
+    for dset in ${test_sets}; do
+        asr_exp=${exp_dir}/exp/${model_dir}
+        inference_tag="$(basename "${inference_config}" .yaml)"
+        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+        _logdir="${_dir}/logdir"
+        if [ -d ${_dir} ]; then
+            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+            exit 0
+        fi
+        mkdir -p "${_logdir}"
+        _data="${feats_dir}/${dumpdir}/${dset}"
+        key_file=${_data}/${scp}
+        num_scp_file="$(<${key_file} wc -l)"
+        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+        split_scps=
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/keys.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1: "${_nj}" "${_logdir}"/asr_inference.JOB.log \
+            python -m funasr.bin.asr_inference_launch \
+                --batch_size 1 \
+                --ngpu "${_ngpu}" \
+                --njob ${njob} \
+                --gpuid_list ${gpuid_list} \
+                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --key_file "${_logdir}"/keys.JOB.scp \
+                --asr_train_config "${asr_exp}"/config.yaml \
+                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+                --output_dir "${_logdir}"/output.JOB \
+                --mode asr \
+                ${_opts}
+
+        for f in token token_int score text; do
+            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | sort -k1 >"${_dir}/${f}"
+            fi
+        done
+        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+        python utils/proce_text.py ${_data}/text ${_data}/text.proc
+        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+        cat ${_dir}/text.cer.txt
+    done
+fi
+
diff --git a/egs/mars/sd/path.sh b/egs/mars/sd/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/egs/mars/sd/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/mars/sd/scripts/calculate_shapes.py b/egs/mars/sd/scripts/calculate_shapes.py
new file mode 100644
index 0000000..b207f2d
--- /dev/null
+++ b/egs/mars/sd/scripts/calculate_shapes.py
@@ -0,0 +1,45 @@
+import logging
+import numpy as np
+import soundfile
+import kaldiio
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import argparse
+from collections import OrderedDict
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser: argparse.ArgumentParser):
+        parser.add_argument("--input_scp", type=str, required=True)
+        parser.add_argument("--out_path")
+        args = parser.parse_args()
+
+        if not os.path.exists(os.path.dirname(args.out_path)):
+            os.makedirs(os.path.dirname(args.out_path))
+
+        task_list = load_scp_as_list(args.input_scp)
+        return task_list, None, args
+
+    def post(self, result_list, args):
+        fd = open(args.out_path, "wt", encoding="utf-8")
+        for results in result_list:
+            for uttid, shape in results:
+                fd.write("{} {}\n".format(uttid, ",".join(shape)))
+        fd.close()
+
+
+def process(task_args):
+    task_idx, task_list, _, args = task_args
+    rst = []
+    for uttid, file_path in task_list:
+        data = kaldiio.load_mat(file_path)
+        shape = [str(x) for x in data.shape]
+        rst.append((uttid, shape))
+    return rst
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py
new file mode 100644
index 0000000..ec1c765
--- /dev/null
+++ b/egs/mars/sd/scripts/dump_rttm_to_labels.py
@@ -0,0 +1,140 @@
+import logging
+import numpy as np
+import soundfile
+import kaldiio
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import argparse
+from collections import OrderedDict
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser: argparse.ArgumentParser):
+        parser.add_argument("--rttm_list", type=str, required=True)
+        parser.add_argument("--wav_scp_list", type=str, required=True)
+        parser.add_argument("--out_dir", type=str, required=True)
+        parser.add_argument("--n_spk", type=int, default=8)
+        parser.add_argument("--remove_sil", default=False, action="store_true")
+        parser.add_argument("--max_overlap", default=0, type=int)
+        parser.add_argument("--frame_shift", type=float, default=0.01)
+        args = parser.parse_args()
+
+        rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
+        meeting2rttm = OrderedDict()
+        for rttm_path in rttm_list:
+            meeting2rttm.update(self.load_rttm(rttm_path))
+
+        wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
+        meeting_scp = OrderedDict()
+        for scp_path in wav_scp_list:
+            meeting_scp.update(load_scp_as_dict(scp_path))
+
+        if len(meeting_scp) != len(meeting2rttm):
+            logging.warning("Number of wav and rttm mismatch {} != {}".format(
+                len(meeting_scp), len(meeting2rttm)))
+            common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
+            logging.warning("Keep {} records.".format(len(common_keys)))
+            new_meeting_scp = OrderedDict()
+            rm_keys = []
+            for key in meeting_scp:
+                if key not in common_keys:
+                    rm_keys.append(key)
+                else:
+                    new_meeting_scp[key] = meeting_scp[key]
+            logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
+
+            new_meeting2rttm = OrderedDict()
+            rm_keys = []
+            for key in meeting2rttm:
+                if key not in common_keys:
+                    rm_keys.append(key)
+                else:
+                    new_meeting2rttm[key] = meeting2rttm[key]
+            logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
+            meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
+        return task_list, None, args
+
+    @staticmethod
+    def load_rttm(rttm_path):
+        meeting2rttm = OrderedDict()
+        for one_line in open(rttm_path, "rt", encoding="utf-8"):
+            mid = one_line.strip().split(" ")[1]
+            if mid not in meeting2rttm:
+                meeting2rttm[mid] = []
+            meeting2rttm[mid].append(one_line.strip())
+
+        return meeting2rttm
+
+    def post(self, results_list, args):
+        pass
+
+
+def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
+                sr=None, frame_shift=0.01):
+    frame_shift = int(frame_shift * sr)
+    num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
+    multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
+    for _, st, dur, spk in spk_turns:
+        idx = spk_list.index(spk)
+
+        st, dur = int(st * sr), int(dur * sr)
+        frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift)
+        frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift)
+        multi_label[idx, frame_st:frame_ed] = 1
+
+    if remove_sil:
+        speech_count = np.sum(multi_label, axis=0)
+        idx = np.nonzero(speech_count)[0]
+        multi_label = multi_label[:, idx]
+
+    if max_overlap > 0:
+        speech_count = np.sum(multi_label, axis=0)
+        idx = np.nonzero(speech_count <= max_overlap)[0]
+        multi_label = multi_label[:, idx]
+
+    label = multi_label.T
+    return label  # (T, N)
+
+
+def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
+                 sr=16000, frame_shift=0.01):
+    wav, sr = soundfile.read(wav_path)
+    wav_len = len(wav)
+    spk_turns = []
+    spk_list = []
+    for one_line in rttms:
+        parts = one_line.strip().split(" ")
+        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
+        if spk not in spk_list:
+            spk_list.append(spk)
+        spk_turns.append((mid, st, dur, spk))
+    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
+    return labels, spk_list
+
+
+def process(task_args):
+    task_idx, task_list, _, args = task_args
+    spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
+                           "wt", encoding="utf-8")
+    out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
+    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+    for mid, wav_path, rttms in task_list:
+        meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
+                                                args.sr, args.frame_shift)
+        label_writer(mid, meeting_labels)
+        spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
+
+    spk_list_writer.close()
+    label_writer.close()
+    return None
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
new file mode 100644
index 0000000..cd1ec7b
--- /dev/null
+++ b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
@@ -0,0 +1,115 @@
+import numpy as np
+import os
+import argparse
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import soundfile as sf
+from tqdm import tqdm
+
+
+class MyRunner(MultiProcessRunnerV3):
+    def prepare(self, parser):
+        assert isinstance(parser, argparse.ArgumentParser)
+        parser.add_argument("wav_scp", type=str)
+        parser.add_argument("rttm", type=str)
+        parser.add_argument("out_dir", type=str)
+        parser.add_argument("--min_dur", type=float, default=2.0)
+        parser.add_argument("--max_spk_num", type=int, default=4)
+        args = parser.parse_args()
+
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        wav_scp = load_scp_as_list(args.wav_scp)
+        meeting2rttms = {}
+        for one_line in open(args.rttm, "rt"):
+            parts = [x for x in one_line.strip().split(" ") if x != ""]
+            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
+            if mid not in meeting2rttms:
+                meeting2rttms[mid] = []
+            meeting2rttms[mid].append(one_line)
+
+        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
+        return task_list, None, args
+
+    def post(self, result_list, args):
+        count = [0, 0]
+        for result in result_list:
+            count[0] += result[0]
+            count[1] += result[1]
+        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
+
+
+# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
+def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
+    labels = np.zeros([max_spk_num, length], int)
+    spk_list = []
+    for one_line in rttms:
+        parts = [x for x in one_line.strip().split(" ") if x != ""]
+        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
+        spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
+        if spk_name.isdigit():
+            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
+        else:
+            spk_name = "{}_{}".format(mid, spk_name)
+        if spk_name not in spk_list:
+            spk_list.append(spk_name)
+        st, dur = int(st*sr), int(dur*sr)
+        idx = spk_list.index(spk_name)
+        labels[idx, st:st+dur] = 1
+    return labels, spk_list
+
+
+def get_nonoverlap_turns(multi_label, spk_list):
+    turns = []
+    label = np.sum(multi_label, axis=0) == 1
+    spk, in_turn, st = None, False, 0
+    for i in range(len(label)):
+        if not in_turn and label[i]:
+            st, in_turn = i, True
+            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
+        if in_turn:
+            if not label[i]:
+                in_turn = False
+                turns.append([st, i, spk])
+            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
+                turns.append([st, i, spk])
+                st, in_turn = i, True
+                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
+    if in_turn:
+        turns.append([st, len(label), spk])
+    return turns
+
+
+def process(task_args):
+    task_id, task_list, _, args = task_args
+    spk_count = [0, 0]
+    for mid, wav_path, rttms in task_list:
+        wav, sr = sf.read(wav_path, dtype="int16")
+        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
+        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
+        turns = get_nonoverlap_turns(multi_label, spk_list)
+        extracted_spk = []
+        count = 1
+        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
+            if (ed - st) >= args.min_dur * args.sr:
+                seg = wav[st: ed]
+                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
+                if not os.path.exists(os.path.dirname(save_path)):
+                    os.makedirs(os.path.dirname(save_path))
+                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
+                count += 1
+                if spk not in extracted_spk:
+                    extracted_spk.append(spk)
+        if len(extracted_spk) != len(spk_list):
+            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
+                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
+            ))
+        spk_count[0] += len(extracted_spk)
+        spk_count[1] += len(spk_list)
+    return spk_count
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py b/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
new file mode 100644
index 0000000..e579f51
--- /dev/null
+++ b/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
@@ -0,0 +1,73 @@
+import numpy as np
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import librosa
+import argparse
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser):
+        parser.add_argument("dir", type=str)
+        parser.add_argument("out_dir", type=str)
+        parser.add_argument("--n_spk", type=int, default=4)
+        parser.add_argument("--remove_sil", default=False, action="store_true")
+        args = parser.parse_args()
+
+        meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
+        rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
+
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp]
+        return task_list, None, args
+
+    def post(self, results_list, args):
+        pass
+
+
+def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
+    multi_label = np.zeros([n_spk, length], dtype=int)
+    for _, st, dur, spk in spk_turns:
+        st, dur = int(st * sr), int(dur * sr)
+        idx = spk_list.index(spk)
+        multi_label[idx, st:st+dur] = 1
+    if not remove_sil:
+        return multi_label.T
+
+    speech_count = np.sum(multi_label, axis=0)
+    idx = np.nonzero(speech_count)[0]
+    label = multi_label[:, idx].T
+    return label  # (T, N)
+
+
+def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
+    wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
+    spk_turns = []
+    spk_list = []
+    for one_line in open(rttm_path, "rt"):
+        parts = one_line.strip().split(" ")
+        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
+        spk = "{}_S{:03d}".format(mid, spk)
+        if spk not in spk_list:
+            spk_list.append(spk)
+        spk_turns.append((mid, st, dur, spk))
+    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil)
+    return labels
+
+
+def process(task_args):
+    _, task_list, _, args = task_args
+    for mid, wav_path, rttm_path in task_list:
+        meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
+        save_path = os.path.join(args.out_dir, "{}.lbl".format(mid))
+        np.save(save_path, meeting_labels.astype(bool))
+        print(mid)
+    return None
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py b/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
new file mode 100644
index 0000000..11bc395
--- /dev/null
+++ b/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
@@ -0,0 +1,53 @@
+import numpy as np
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import librosa
+import soundfile as sf
+from tqdm import tqdm
+import argparse
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser):
+        parser.add_argument("wav_scp", type=str)
+        parser.add_argument("out_dir", type=str)
+        parser.add_argument("--chunk_dur", type=float, default=16)
+        parser.add_argument("--shift_dur", type=float, default=4)
+        args = parser.parse_args()
+
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        wav_scp = load_scp_as_list(args.wav_scp)
+        return wav_scp, None, args
+
+    def post(self, results_list, args):
+        pass
+
+
+def process(task_args):
+    _, task_list, _, args = task_args
+    chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
+    for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
+        if not os.path.exists(os.path.join(args.out_dir, mid)):
+            os.makedirs(os.path.join(args.out_dir, mid))
+
+        wav = librosa.load(wav_path, args.sr, True)[0] * 32767
+        n_chunk = (len(wav) - chunk_len) // shift_len + 1
+        if (len(wav) - chunk_len) % shift_len > 0:
+            n_chunk += 1
+        for i in range(n_chunk):
+            seg = wav[i*shift_len: i*shift_len + chunk_len]
+            st = int(float(i*shift_len)/args.sr * 100)
+            dur = int(float(len(seg))/args.sr * 100)
+            file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
+            save_path = os.path.join(args.out_dir, mid, file_name)
+            sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
+    return None
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py b/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
new file mode 100644
index 0000000..011bd7c
--- /dev/null
+++ b/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
@@ -0,0 +1,57 @@
+import numpy as np
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import argparse
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser):
+        parser.add_argument("--rttm_scp", type=str)
+        parser.add_argument("--seg_file", type=str)
+        args = parser.parse_args()
+
+        if not os.path.exists(os.path.dirname(args.seg_file)):
+            os.makedirs(os.path.dirname(args.seg_file))
+
+        task_list = load_scp_as_list(args.rttm_scp)
+        return task_list, None, args
+
+    def post(self, results_list, args):
+        with open(args.seg_file, "wt", encoding="utf-8") as fd:
+            for results in results_list:
+                fd.writelines(results)
+
+
+def process(task_args):
+    _, task_list, _, args = task_args
+    outputs = []
+    for mid, rttm_path in task_list:
+        spk_turns = []
+        length = 0
+        for one_line in open(rttm_path, 'rt', encoding="utf-8"):
+            parts = one_line.strip().split(" ")
+            _, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
+            st, ed = int(st*100), int((st + dur)*100)
+            length = ed if ed > length else length
+            spk_turns.append([mid, st, ed, spk_name])
+        is_sph = np.zeros((length+1, ), dtype=bool)
+        for _, st, ed, _ in spk_turns:
+            is_sph[st:ed] = True
+
+        st, in_speech = 0, False
+        for i in range(length+1):
+            if not in_speech and is_sph[i]:
+                st, in_speech = i, True
+            if in_speech and not is_sph[i]:
+                in_speech = False
+                outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
+                    mid, st, i, mid, float(st)/100, float(i)/100
+                ))
+    return outputs
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py b/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
new file mode 100644
index 0000000..a2bcd39
--- /dev/null
+++ b/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
@@ -0,0 +1,138 @@
+import soundfile
+import kaldiio
+from tqdm import tqdm
+import json
+import os
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import numpy as np
+import argparse
+import random
+
+short_spk_list = []
+def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
+    all_utts = spk2utt[spk]
+    idx_list = list(range(len(all_utts)))
+    random.shuffle(idx_list)
+    count = 0
+    utt_list = []
+    for i in idx_list:
+        utt_id = all_utts[i]
+        utt_list.append(utt_id)
+        count += int(utt2frames[utt_id])
+        if count >= total_len:
+            break
+    if count < 300 and spk not in short_spk_list:
+        print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
+        short_spk_list.append(spk)
+
+    ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
+    ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list]
+    ivc = np.concatenate(ivc_list, axis=0)
+    ivc = np.mean(ivc, axis=0, keepdims=False)
+    return ivc
+
+
+def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
+    out_prefix = args.out
+
+    ivc_dim = 192
+    win_len, win_shift = 400, 160
+    label_weights = 2 ** np.array(list(range(args.n_spk)))
+    wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
+    ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
+    label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))
+
+
+    frames_list = []
+    chunk_size = int(args.chunk_size * args.sr)
+    chunk_shift = int(args.chunk_shift * args.sr)
+    for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
+        meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
+        num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
+        meeting_labels = np.load(labels_scp[mid])
+        for i in range(num_chunk):
+            st, ed = i*chunk_shift, i*chunk_shift+chunk_size
+            seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
+            wav_writer(seg_id, meeting_wav[st: ed])
+
+            xvec_list = []
+            for spk in meeting2spk_list[mid]:
+                spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
+                xvec_list.append(spk_xvec)
+            for _ in range(args.n_spk - len(xvec_list)):
+                xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
+            xvec = np.row_stack(xvec_list)
+            ivc_writer(seg_id, xvec)
+
+            wav_label = meeting_labels[st:ed, :]
+            frame_num = (ed-st) // win_shift
+            # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
+            feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
+            for i in range(frame_num):
+                frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
+                feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
+            label_writer(seg_id, feat_label)
+
+            frames_list.append((mid, feat_label.shape[0]))
+    return frames_list
+
+
+def calc_spk_list(rttm_path):
+    spk_list = []
+    for one_line in open(rttm_path, "rt"):
+        parts = one_line.strip().split(" ")
+        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
+        spk = "{}_S{:03d}".format(mid, spk)
+        if spk not in spk_list:
+            spk_list.append(spk)
+
+    return spk_list
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dir", required=True, type=str, default=None,
+                        help="feats.scp")
+    parser.add_argument("--out", required=True, type=str, default=None,
+                        help="The prefix of dumpped files.")
+    parser.add_argument("--n_spk", type=int, default=4)
+    parser.add_argument("--use_lfr", default=False, action="store_true")
+    parser.add_argument("--no_pbar", default=False, action="store_true")
+    parser.add_argument("--sr", type=int, default=16000)
+    parser.add_argument("--chunk_size", type=int, default=16)
+    parser.add_argument("--chunk_shift", type=int, default=4)
+    args = parser.parse_args()
+
+    if not os.path.exists(os.path.dirname(args.out)):
+        os.makedirs(os.path.dirname(args.out))
+
+    meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
+    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
+    rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
+    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
+    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
+    utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
+    utt2frames = {}
+    for uttid, wav_path in utt2wav.items():
+        wav, sr = soundfile.read(wav_path, dtype="int16")
+        utt2frames[uttid] = int(len(wav) / sr * 100)
+
+    meeting2spk_list = {}
+    for mid, rttm_path in rttm_scp:
+        meeting2spk_list[mid] = calc_spk_list(rttm_path)
+
+    spk2utt = {}
+    for utt, spk in utt2spk.items():
+        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
+            if spk not in spk2utt:
+                spk2utt[spk] = []
+            spk2utt[spk].append(utt)
+
+    # random.shuffle(feat_scp)
+    meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
+    total_frames = sum([x[1] for x in meeting_lens])
+    print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
new file mode 100644
index 0000000..1d6f53e
--- /dev/null
+++ b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
@@ -0,0 +1,110 @@
+from __future__ import print_function
+import numpy as np
+import os
+import sys
+import argparse
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import librosa
+import soundfile as sf
+from copy import deepcopy
+import json
+from tqdm import tqdm
+
+
+class MyRunner(MultiProcessRunnerV3):
+    def prepare(self, parser):
+        assert isinstance(parser, argparse.ArgumentParser)
+        parser.add_argument("wav_scp", type=str)
+        parser.add_argument("rttm_scp", type=str)
+        parser.add_argument("out_dir", type=str)
+        parser.add_argument("--min_dur", type=float, default=2.0)
+        parser.add_argument("--max_spk_num", type=int, default=4)
+        args = parser.parse_args()
+
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        wav_scp = load_scp_as_list(args.wav_scp)
+        rttm_scp = load_scp_as_dict(args.rttm_scp)
+        task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp]
+        return task_list, None, args
+
+    def post(self, result_list, args):
+        count = [0, 0]
+        for result in result_list:
+            count[0] += result[0]
+            count[1] += result[1]
+        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
+
+
+# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
+def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
+    labels = np.zeros([max_spk_num, length], int)
+    spk_list = []
+    for one_line in open(rttm_path, 'rt'):
+        parts = one_line.strip().split(" ")
+        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
+        if spk_name.isdigit():
+            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
+        if spk_name not in spk_list:
+            spk_list.append(spk_name)
+        st, dur = int(st*sr), int(dur*sr)
+        idx = spk_list.index(spk_name)
+        labels[idx, st:st+dur] = 1
+    return labels, spk_list
+
+
+def get_nonoverlap_turns(multi_label, spk_list):
+    turns = []
+    label = np.sum(multi_label, axis=0) == 1
+    spk, in_turn, st = None, False, 0
+    for i in range(len(label)):
+        if not in_turn and label[i]:
+            st, in_turn = i, True
+            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
+        if in_turn:
+            if not label[i]:
+                in_turn = False
+                turns.append([st, i, spk])
+            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
+                turns.append([st, i, spk])
+                st, in_turn = i, True
+                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
+    if in_turn:
+        turns.append([st, len(label), spk])
+    return turns
+
+
+def process(task_args):
+    task_id, task_list, _, args = task_args
+    spk_count = [0, 0]
+    for mid, wav_path, rttm_path in task_list:
+        wav, sr = sf.read(wav_path, dtype="int16")
+        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
+        multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
+        turns = get_nonoverlap_turns(multi_label, spk_list)
+        extracted_spk = []
+        count = 1
+        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
+            if (ed - st) >= args.min_dur * args.sr:
+                seg = wav[st: ed]
+                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
+                if not os.path.exists(os.path.dirname(save_path)):
+                    os.makedirs(os.path.dirname(save_path))
+                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
+                count += 1
+                if spk not in extracted_spk:
+                    extracted_spk.append(spk)
+        if len(extracted_spk) != len(spk_list):
+            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
+                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
+            ))
+        spk_count[0] += len(extracted_spk)
+        spk_count[1] += len(spk_list)
+    return spk_count
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py b/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
new file mode 100644
index 0000000..8b3195f
--- /dev/null
+++ b/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
@@ -0,0 +1,60 @@
+import numpy as np
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import librosa
+import soundfile as sf
+import argparse
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser):
+        parser.add_argument("dir", type=str)
+        parser.add_argument("out_dir", type=str)
+        args = parser.parse_args()
+
+        meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
+        vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8")
+        meeting2vad = {}
+        for one_line in vad_file:
+            uid, mid, st, ed = one_line.strip().split(" ")
+            st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
+            if mid not in meeting2vad:
+                meeting2vad[mid] = []
+            meeting2vad[mid].append((uid, st, ed))
+
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
+        return task_list, None, args
+
+    def post(self, results_list, args):
+        pass
+
+
+def process(task_args):
+    _, task_list, _, args = task_args
+    for mid, wav_path, vad_list in task_list:
+        wav = librosa.load(wav_path, args.sr, True)[0] * 32767
+        seg_list = []
+        pos_map = []
+        offset = 0
+        for uid, st, ed in vad_list:
+            seg_list.append(wav[st: ed])
+            pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st))
+            offset = offset + ed - st
+        out = np.concatenate(seg_list, axis=0)
+        save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
+        sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
+        map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
+        with open(map_path, "wt", encoding="utf-8") as fd:
+            fd.writelines(pos_map)
+        print(mid)
+    return None
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py
new file mode 100644
index 0000000..f61b808
--- /dev/null
+++ b/egs/mars/sd/scripts/simu_chunk_with_labels.py
@@ -0,0 +1,261 @@
+import logging
+import numpy as np
+import soundfile
+import kaldiio
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+import os
+import argparse
+from collections import OrderedDict
+import random
+from typing import List, Dict
+from copy import deepcopy
+import json
+logging.basicConfig(
+    level="INFO",
+    format=f"[{os.uname()[1].split('.')[0]}]"
+           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+)
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser: argparse.ArgumentParser):
+        parser.add_argument("--label_scp", type=str, required=True)
+        parser.add_argument("--wav_scp", type=str, required=True)
+        parser.add_argument("--utt2spk", type=str, required=True)
+        parser.add_argument("--spk2meeting", type=str, required=True)
+        parser.add_argument("--utt2xvec", type=str, required=True)
+        parser.add_argument("--out_dir", type=str, required=True)
+        parser.add_argument("--chunk_size", type=float, default=16)
+        parser.add_argument("--chunk_shift", type=float, default=4)
+        parser.add_argument("--frame_shift", type=float, default=0.01)
+        parser.add_argument("--embedding_dim", type=int, default=None)
+        parser.add_argument("--average_emb_num", type=int, default=0)
+        parser.add_argument("--subset", type=int, default=0)
+        parser.add_argument("--data_json", type=str, default=None)
+        parser.add_argument("--seed", type=int, default=1234)
+        parser.add_argument("--log_interval", type=int, default=100)
+        args = parser.parse_args()
+        random.seed(args.seed)
+        np.random.seed(args.seed)
+
+        logging.info("Loading data...")
+        if not os.path.exists(args.data_json):
+            label_list = load_scp_as_list(args.label_scp)
+            wav_scp = load_scp_as_dict(args.wav_scp)
+            utt2spk = load_scp_as_dict(args.utt2spk)
+            utt2xvec = load_scp_as_dict(args.utt2xvec)
+            spk2meeting = load_scp_as_dict(args.spk2meeting)
+
+            meeting2spks = OrderedDict()
+            for spk, meeting in spk2meeting.items():
+                if meeting not in meeting2spks:
+                    meeting2spks[meeting] = []
+                meeting2spks[meeting].append(spk)
+
+            spk2utts = OrderedDict()
+            for utt, spk in utt2spk.items():
+                if spk not in spk2utts:
+                    spk2utts[spk] = []
+                spk2utts[spk].append(utt)
+
+            os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
+            logging.info("Dump data...")
+            json.dump({
+                "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
+                "spk2utts": spk2utts, "meeting2spks": meeting2spks
+            }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
+        else:
+            data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
+            label_list = data_dict["label_list"]
+            wav_scp = data_dict["wav_scp"]
+            utt2xvec = data_dict["utt2xvec"]
+            spk2utts = data_dict["spk2utts"]
+            meeting2spks = data_dict["meeting2spks"]
+
+        if not os.path.exists(args.out_dir):
+            os.makedirs(args.out_dir)
+
+        args.chunk_size = int(args.chunk_size / args.frame_shift)
+        args.chunk_shift = int(args.chunk_shift / args.frame_shift)
+
+        if args.embedding_dim is None:
+            args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
+            logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
+
+        logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
+            len(wav_scp), len(spk2utts), len(meeting2spks)
+        ))
+        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
+
+    def post(self, results_list, args):
+        logging.info("[main]: Got {} chunks.".format(sum(results_list)))
+
+
+def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
+    utt_list = spk2utts[spk]
+    wav_list = []
+    cur_length = 0
+    while cur_length < sample_length:
+        uttid = random.choice(utt_list)
+        wav, fs = soundfile.read(wav_scp[uttid], dtype='float32')
+        wav_list.append(wav)
+        cur_length += len(wav)
+    concat_wav = np.concatenate(wav_list, axis=0)
+    start = random.randint(0, len(concat_wav) - sample_length)
+    return concat_wav[start: start+sample_length]
+
+
+def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
+    # process for dummy speaker
+    if spk == "None":
+        return np.zeros((1, embedding_dim), dtype=np.float32)
+
+    # calculate averaged speaker embeddings
+    utt_list = spk2utts[spk]
+    if average_emb_num == 0 or average_emb_num > len(utt_list):
+        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
+    else:
+        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
+    xvec = np.concatenate(xvec_list, axis=0)
+    xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
+    xvec = np.mean(xvec, axis=0)
+
+    return xvec
+
+
+def simu_chunk(
+        frame_label: np.ndarray,
+        sample_label: np.ndarray,
+        wav_scp: Dict[str, str],
+        utt2xvec: Dict[str, str],
+        spk2utts: Dict[str, List[str]],
+        meeting2spks: Dict[str, List[str]],
+        all_speaker_list: List[str],
+        meeting_list: List[str],
+        embedding_dim: int,
+        average_emb_num: int,
+):
+    frame_length, max_spk_num = frame_label.shape
+    sample_length = sample_label.shape[0]
+    positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
+    pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])
+
+    # get positive speakers
+    if len(pos_speaker_list) >= positive_speaker_num:
+        pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
+    else:
+        while len(pos_speaker_list) < positive_speaker_num:
+            _spk = random.choice(all_speaker_list)
+            if _spk not in pos_speaker_list:
+                pos_speaker_list.append(_spk)
+
+    # get negative speakers
+    negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
+    neg_speaker_list = []
+    while len(neg_speaker_list) < negative_speaker_num:
+        _spk = random.choice(all_speaker_list)
+        if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
+            neg_speaker_list.append(_spk)
+    neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))
+
+    random.shuffle(pos_speaker_list)
+    random.shuffle(neg_speaker_list)
+    seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
+    this_spk_list = []
+    for idx, frame_num in enumerate(frame_label.sum(axis=0)):
+        if frame_num > 0:
+            spk = pos_speaker_list.pop(0)
+            this_spk_list.append(spk)
+            simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
+            seperated_wav[:, idx] = simu_spk_wav
+        else:
+            spk = neg_speaker_list.pop(0)
+            this_spk_list.append(spk)
+
+    # calculate mixed wav
+    mixed_wav = np.sum(seperated_wav * sample_label, axis=1)
+
+    # shuffle the order of speakers
+    shuffle_idx = list(range(max_spk_num))
+    random.shuffle(shuffle_idx)
+    this_spk_list = [this_spk_list[x] for x in shuffle_idx]
+    seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
+    frame_label = frame_label.transpose()[shuffle_idx].transpose()
+
+    # calculate profile
+    profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
+               for spk in this_spk_list]
+    profile = np.vstack(profile)
+    # pse_weights = 2 ** np.arange(max_spk_num)
+    # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
+    # pse_label = pse_label.astype(str).tolist()
+
+    return mixed_wav, seperated_wav, profile, frame_label
+
+
+def process(task_args):
+    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
+    logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
+
+    out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
+    wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+
+    # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
+    # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+
+    out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
+    profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+
+    out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
+    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
+
+    speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
+
+    labels_list = []
+    total_chunks = 0
+    for org_mid, label_path in task_list:
+        whole_label = kaldiio.load_mat(label_path)
+        # random offset to keep diversity
+        rand_shift = random.randint(0, args.chunk_shift)
+        num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
+        labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
+        total_chunks += num_chunk
+
+    idx = 0
+    simu_chunk_count = 0
+    for org_mid, whole_label, rand_shift, num_chunk in labels_list:
+        for i in range(num_chunk):
+            idx = idx + 1
+            st = i * args.chunk_shift + rand_shift
+            ed = i * args.chunk_shift + args.chunk_size + rand_shift
+            utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
+                args.subset + 1, task_idx + 1, org_mid, st, ed
+            )
+            frame_label = whole_label[st: ed, :]
+            sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
+            mix_wav, seg_wav, profile, frame_label = simu_chunk(
+                frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
+                speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
+            )
+            wav_mix_writer(utt_id, mix_wav)
+            # wav_sep_writer(utt_id, seg_wav)
+            profile_writer(utt_id, profile)
+            label_writer(utt_id, frame_label)
+
+            simu_chunk_count += 1
+            if simu_chunk_count % args.log_interval == 0:
+                logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
+                    task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
+    wav_mix_writer.close()
+    # wav_sep_writer.close()
+    profile_writer.close()
+    label_writer.close()
+    logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
+    return simu_chunk_count
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index c3e210b..7738f4f 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -127,7 +127,7 @@
 def inference_launch(mode, **kwargs):
     if mode == "sond":
         from funasr.bin.sond_inference import inference_modelscope
-        return inference_modelscope(**kwargs)
+        return inference_modelscope(mode=mode, **kwargs)
     elif mode == "sond_demo":
         from funasr.bin.sond_inference import inference_modelscope
         param_dict = {
@@ -135,11 +135,13 @@
             "sv_train_config": "sv.yaml",
             "sv_model_file": "sv.pth",
         }
-        if "param_dict" in kwargs:
-            kwargs["param_dict"].update(param_dict)
+        if "param_dict" in kwargs and kwargs["param_dict"] is not None:
+            for key in param_dict:
+                if key not in kwargs["param_dict"]:
+                    kwargs["param_dict"][key] = param_dict[key]
         else:
             kwargs["param_dict"] = param_dict
-        return inference_modelscope(**kwargs)
+        return inference_modelscope(mode=mode, **kwargs)
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
diff --git a/funasr/bin/diar_train.py b/funasr/bin/diar_train.py
new file mode 100755
index 0000000..f76d1b9
--- /dev/null
+++ b/funasr/bin/diar_train.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.diar import DiarTask
+
+
+# for ASR Training
+def parse_args():
+    parser = DiarTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # for ASR Training
+    DiarTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small":
+        if args.batch_size is not None:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
index 299de0d..ab6d26f 100755
--- a/funasr/bin/sond_inference.py
+++ b/funasr/bin/sond_inference.py
@@ -33,6 +33,8 @@
 from funasr.utils.types import str_or_none
 from scipy.ndimage import median_filter
 from funasr.utils.misc import statistic_model_parameters
+from funasr.datasets.iterable_dataset import load_bytes
+
 
 class Speech2Diarization:
     """Speech2Xvector class
@@ -229,6 +231,7 @@
         dur_threshold: int = 10,
         out_format: str = "vad",
         param_dict: Optional[dict] = None,
+        mode: str = "sond",
         **kwargs,
 ):
     assert check_argument_types()
@@ -252,11 +255,14 @@
     set_all_random_seed(seed)
 
     # 2a. Build speech2xvec [Optional]
-    if param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
+    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
         assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
         assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
         sv_train_config = param_dict["sv_train_config"]
         sv_model_file = param_dict["sv_model_file"]
+        if "model_dir" in param_dict:
+            sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
+            sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
         from funasr.bin.sv_inference import Speech2Xvector
         speech2xvector_kwargs = dict(
             sv_train_config=sv_train_config,
@@ -307,20 +313,25 @@
 
     def _forward(
             data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
-            raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None,
+            raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
             output_dir_v2: Optional[str] = None,
             param_dict: Optional[dict] = None,
     ):
         logging.info("param_dict: {}".format(param_dict))
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, (list, tuple)):
+                if not isinstance(raw_inputs[0], List):
+                    raw_inputs = [raw_inputs]
+
                 assert all([len(example) >= 2 for example in raw_inputs]), \
                     "The length of test case in raw_inputs must larger than 1 (>=2)."
 
                 def prepare_dataset():
                     for idx, example in enumerate(raw_inputs):
                         # read waveform file
-                        example = [soundfile.read(x)[0] if isinstance(example[0], str) else x
+                        example = [load_bytes(x) if isinstance(x, bytes) else x
+                                   for x in example]
+                        example = [soundfile.read(x)[0] if isinstance(x, str) else x
                                    for x in example]
                         # convert torch tensor to numpy array
                         example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
diff --git a/funasr/layers/label_aggregation.py b/funasr/layers/label_aggregation.py
index 075e19d..29a08a9 100644
--- a/funasr/layers/label_aggregation.py
+++ b/funasr/layers/label_aggregation.py
@@ -79,4 +79,4 @@
         else:
             olens = None
 
-        return output, olens
+        return output.to(input.dtype), olens
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 0d8b303..28df73f 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -8,6 +8,7 @@
 
 import torch
 from torch import nn
+from funasr.modules.nets_utils import make_pad_mask
 
 
 class LabelSmoothingLoss(nn.Module):
@@ -61,3 +62,20 @@
         kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
         denom = total if self.normalize_length else batch_size
         return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
+
+
+class SequenceBinaryCrossEntropy(nn.Module):
+    def __init__(
+            self,
+            normalize_length=False,
+            criterion=nn.BCEWithLogitsLoss(reduction="none")
+    ):
+        super().__init__()
+        self.normalize_length = normalize_length
+        self.criterion = criterion
+
+    def forward(self, pred, label, lengths):
+        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
+        loss = self.criterion(pred, label)
+        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
+        return loss.masked_fill(pad_mask, 0).sum() / denom
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index d29ffe5..419c813 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -7,7 +7,7 @@
 from itertools import permutations
 from typing import Dict
 from typing import Optional
-from typing import Tuple
+from typing import Tuple, List
 
 import numpy as np
 import torch
@@ -23,6 +23,8 @@
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
+from funasr.utils.misc import int2vec
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -44,17 +46,20 @@
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        encoder: AbsEncoder,
-        speaker_encoder: AbsEncoder,
+        encoder: torch.nn.Module,
+        speaker_encoder: Optional[torch.nn.Module],
         ci_scorer: torch.nn.Module,
-        cd_scorer: torch.nn.Module,
+        cd_scorer: Optional[torch.nn.Module],
         decoder: torch.nn.Module,
         token_list: list,
         lsm_weight: float = 0.1,
         length_normalized_loss: bool = False,
         max_spk_num: int = 16,
         label_aggregator: Optional[torch.nn.Module] = None,
-        normlize_speech_speaker: bool = False,
+        normalize_speech_speaker: bool = False,
+        ignore_id: int = -1,
+        speaker_discrimination_loss_weight: float = 1.0,
+        inter_score_loss_weight: float = 0.0
     ):
         assert check_argument_types()
 
@@ -71,7 +76,31 @@
         self.decoder = decoder
         self.token_list = token_list
         self.max_spk_num = max_spk_num
-        self.normalize_speech_speaker = normlize_speech_speaker
+        self.normalize_speech_speaker = normalize_speech_speaker
+        self.ignore_id = ignore_id
+        self.criterion_diar = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
+        self.pse_embedding = self.generate_pse_embedding()
+        # self.register_buffer("pse_embedding", pse_embedding)
+        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
+        # self.register_buffer("power_weight", power_weight)
+        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
+        # self.register_buffer("int_token_arr", int_token_arr)
+        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
+        self.inter_score_loss_weight = inter_score_loss_weight
+        self.forward_steps = 0
+
+    def generate_pse_embedding(self):
+        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
+        for idx, pse_label in enumerate(self.token_list):
+            emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float)
+            embedding[idx] = emb
+        return torch.from_numpy(embedding)
 
     def forward(
         self,
@@ -79,13 +108,13 @@
         speech_lengths: torch.Tensor = None,
         profile: torch.Tensor = None,
         profile_lengths: torch.Tensor = None,
-        spk_labels: torch.Tensor = None,
-        spk_labels_lengths: torch.Tensor = None,
+        binary_labels: torch.Tensor = None,
+        binary_labels_lengths: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
 
         Args:
-            speech: (Batch, samples)
+            speech: (Batch, samples) or (Batch, frames, input_size)
             speech_lengths: (Batch,) default None for chunk interator,
                                      because the chunk-iterator does not
                                      have the speech_lengths returned.
@@ -93,63 +122,44 @@
                                      espnet2/iterators/chunk_iter_factory.py
             profile: (Batch, N_spk, dim)
             profile_lengths: (Batch,)
-            spk_labels: (Batch, )
+            binary_labels: (Batch, frames, max_spk_num)
+            binary_labels_lengths: (Batch,)
         """
-        assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
+        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
         batch_size = speech.shape[0]
+        self.forward_steps = self.forward_steps + 1
+        # 1. Network forward
+        pred, inter_outputs = self.prediction_forward(
+            speech, speech_lengths,
+            profile, profile_lengths,
+            return_inter_outputs=True
+        )
+        (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
 
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-        if self.attractor is None:
-            # 2a. Decoder (baiscally a predction layer after encoder_out)
-            pred = self.decoder(encoder_out, encoder_out_lens)
-        else:
-            # 2b. Encoder Decoder Attractors
-            # Shuffle the chronological order of encoder_out, then calculate attractor
-            encoder_out_shuffled = encoder_out.clone()
-            for i in range(len(encoder_out_lens)):
-                encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
-                    i, torch.randperm(encoder_out_lens[i]), :
-                ]
-            attractor, att_prob = self.attractor(
-                encoder_out_shuffled,
-                encoder_out_lens,
-                to_device(
-                    self,
-                    torch.zeros(
-                        encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
-                    ),
-                ),
-            )
-            # Remove the final attractor which does not correspond to a speaker
-            # Then multiply the attractors and encoder_out
-            pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
-        # 3. Aggregate time-domain labels
+        # 2. Aggregate time-domain labels to match forward outputs
         if self.label_aggregator is not None:
-            spk_labels, spk_labels_lengths = self.label_aggregator(
-                spk_labels, spk_labels_lengths
+            binary_labels, binary_labels_lengths = self.label_aggregator(
+                binary_labels, binary_labels_lengths
             )
+        # 2. Calculate power-set encoding (PSE) labels
+        raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
+        pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
 
         # If encoder uses conv* as input_layer (i.e., subsampling),
-        # the sequence length of 'pred' might be slighly less than the
+        # the sequence length of 'pred' might be slightly less than the
         # length of 'spk_labels'. Here we force them to be equal.
         length_diff_tolerance = 2
-        length_diff = spk_labels.shape[1] - pred.shape[1]
-        if length_diff > 0 and length_diff <= length_diff_tolerance:
-            spk_labels = spk_labels[:, 0 : pred.shape[1], :]
+        length_diff = pse_labels.shape[1] - pred.shape[1]
+        if 0 < length_diff <= length_diff_tolerance:
+            pse_labels = pse_labels[:, 0: pred.shape[1]]
 
-        if self.attractor is None:
-            loss_pit, loss_att = None, None
-            loss, perm_idx, perm_list, label_perm = self.pit_loss(
-                pred, spk_labels, encoder_out_lens
-            )
-        else:
-            loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
-                pred, spk_labels, encoder_out_lens
-            )
-            loss_att = self.attractor_loss(att_prob, spk_labels)
-            loss = loss_pit + self.attractor_weight * loss_att
+        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
+        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
+        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
+        label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
+        loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
+                + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
+
         (
             correct,
             num_frames,
@@ -160,7 +170,11 @@
             speaker_miss,
             speaker_falarm,
             speaker_error,
-        ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
+        ) = self.calc_diarization_error(
+            pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding),
+            label=F.embedding(pse_labels * (~label_mask), self.pse_embedding),
+            length=binary_labels_lengths
+        )
 
         if speech_scored > 0 and num_frames > 0:
             sad_mr, sad_fr, mi, fa, cf, acc, der = (
@@ -177,8 +191,10 @@
 
         stats = dict(
             loss=loss.detach(),
-            loss_att=loss_att.detach() if loss_att is not None else None,
-            loss_pit=loss_pit.detach() if loss_pit is not None else None,
+            loss_diar=loss_diar.detach() if loss_diar is not None else None,
+            loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
+            loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
+            loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
             sad_mr=sad_mr,
             sad_fr=sad_fr,
             mi=mi,
@@ -186,17 +202,78 @@
             cf=cf,
             acc=acc,
             der=der,
+            forward_steps=self.forward_steps,
         )
 
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
+    def classification_loss(
+            self,
+            predictions: torch.Tensor,
+            labels: torch.Tensor,
+            prediction_lengths: torch.Tensor
+    ) -> torch.Tensor:
+        mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1])
+        pad_labels = labels.masked_fill(
+            mask.to(predictions.device),
+            value=self.ignore_id
+        )
+        loss = self.criterion_diar(predictions.contiguous(), pad_labels)
+
+        return loss
+
+    def speaker_discrimination_loss(
+            self,
+            profile: torch.Tensor,
+            profile_lengths: torch.Tensor
+    ) -> torch.Tensor:
+        profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()  # (B, N, 1)
+        mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
+        mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask))
+
+        eps = 1e-12
+        coding_norm = torch.linalg.norm(
+            profile * profile_mask + (1 - profile_mask) * eps,
+            dim=2, keepdim=True
+        ) * profile_mask
+        # profile: Batch, N, dim
+        cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask
+        cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
+        loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum()
+
+        return loss
+
+    def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
+        mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1])
+        padding_labels = pse_labels.masked_fill(
+            mask.to(pse_labels.device),
+            value=0
+        ).to(pse_labels)
+        multi_labels = F.embedding(padding_labels, self.pse_embedding)
+
+        return multi_labels
+
+    def internal_score_loss(
+            self,
+            cd_score: torch.Tensor,
+            ci_score: torch.Tensor,
+            pse_labels: torch.Tensor,
+            pse_labels_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
+        ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
+        cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
+        return ci_loss, cd_loss
+
     def collect_feats(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        spk_labels: torch.Tensor = None,
-        spk_labels_lengths: torch.Tensor = None,
+        profile: torch.Tensor = None,
+        profile_lengths: torch.Tensor = None,
+        binary_labels: torch.Tensor = None,
+        binary_labels_lengths: torch.Tensor = None,
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
@@ -249,7 +326,7 @@
             speaker_encoder_outputs: torch.Tensor,
             seq_len: torch.Tensor = None,
             spk_len: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
         d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
         if self.normalize_speech_speaker:
@@ -267,9 +344,8 @@
             ci_simi = self.ci_scorer(ge_in, ge_len)[0]
         else:
             ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
-        simi = torch.cat([cd_simi, ci_simi], dim=2)
 
-        return simi
+        return ci_simi, cd_simi
 
     def post_net_forward(self, simi, seq_len):
         logits = self.decoder(simi, seq_len)[0]
@@ -282,16 +358,20 @@
             speech_lengths: torch.Tensor,
             profile: torch.Tensor,
             profile_lengths: torch.Tensor,
-    ) -> torch.Tensor:
+            return_inter_outputs: bool = False,
+    ) -> [torch.Tensor, Optional[list]]:
         # speech encoding
         speech, speech_lengths = self.encode_speech(speech, speech_lengths)
         # speaker encoding
         profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
         # calculating similarity
-        similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
+        ci_simi, cd_simi = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
+        similarity = torch.cat([cd_simi, ci_simi], dim=2)
         # post net forward
         logits = self.post_net_forward(similarity, speech_lengths)
 
+        if return_inter_outputs:
+            return logits, [(speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi)]
         return logits
 
     def encode(
@@ -318,7 +398,8 @@
             # 4. Forward encoder
             # feats: (Batch, Length, Dim)
             # -> encoder_out: (Batch, Length2, Dim)
-            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+            encoder_outputs = self.encoder(feats, feats_lengths)
+            encoder_out, encoder_out_lens = encoder_outputs[:2]
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -363,9 +444,7 @@
 
         (batch_size, max_len, num_output) = label.size()
         # mask the padding part
-        mask = np.zeros((batch_size, max_len, num_output))
-        for i in range(batch_size):
-            mask[i, : length[i], :] = 1
+        mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy()
 
         # pred and label have the shape (batch_size, max_len, num_output)
         label_np = label.data.cpu().numpy().astype(int)
diff --git a/funasr/models/encoder/ecapa_tdnn_encoder.py b/funasr/models/encoder/ecapa_tdnn_encoder.py
new file mode 100644
index 0000000..878a3c0
--- /dev/null
+++ b/funasr/models/encoder/ecapa_tdnn_encoder.py
@@ -0,0 +1,686 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class _BatchNorm1d(nn.Module):
+    def __init__(
+        self,
+        input_shape=None,
+        input_size=None,
+        eps=1e-05,
+        momentum=0.1,
+        affine=True,
+        track_running_stats=True,
+        combine_batch_time=False,
+        skip_transpose=False,
+    ):
+        super().__init__()
+        self.combine_batch_time = combine_batch_time
+        self.skip_transpose = skip_transpose
+
+        if input_size is None and skip_transpose:
+            input_size = input_shape[1]
+        elif input_size is None:
+            input_size = input_shape[-1]
+
+        self.norm = nn.BatchNorm1d(
+            input_size,
+            eps=eps,
+            momentum=momentum,
+            affine=affine,
+            track_running_stats=track_running_stats,
+        )
+
+    def forward(self, x):
+        shape_or = x.shape
+        if self.combine_batch_time:
+            if x.ndim == 3:
+                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
+            else:
+                x = x.reshape(
+                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
+                )
+
+        elif not self.skip_transpose:
+            x = x.transpose(-1, 1)
+
+        x_n = self.norm(x)
+
+        if self.combine_batch_time:
+            x_n = x_n.reshape(shape_or)
+        elif not self.skip_transpose:
+            x_n = x_n.transpose(1, -1)
+
+        return x_n
+
+
+class _Conv1d(nn.Module):
+    def __init__(
+        self,
+        out_channels,
+        kernel_size,
+        input_shape=None,
+        in_channels=None,
+        stride=1,
+        dilation=1,
+        padding="same",
+        groups=1,
+        bias=True,
+        padding_mode="reflect",
+        skip_transpose=False,
+    ):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+        self.padding = padding
+        self.padding_mode = padding_mode
+        self.unsqueeze = False
+        self.skip_transpose = skip_transpose
+
+        if input_shape is None and in_channels is None:
+            raise ValueError("Must provide one of input_shape or in_channels")
+
+        if in_channels is None:
+            in_channels = self._check_input_shape(input_shape)
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
+            padding=0,
+            groups=groups,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        if not self.skip_transpose:
+            x = x.transpose(1, -1)
+
+        if self.unsqueeze:
+            x = x.unsqueeze(1)
+
+        if self.padding == "same":
+            x = self._manage_padding(
+                x, self.kernel_size, self.dilation, self.stride
+            )
+
+        elif self.padding == "causal":
+            num_pad = (self.kernel_size - 1) * self.dilation
+            x = F.pad(x, (num_pad, 0))
+
+        elif self.padding == "valid":
+            pass
+
+        else:
+            raise ValueError(
+                "Padding must be 'same', 'valid' or 'causal'. Got "
+                + self.padding
+            )
+
+        wx = self.conv(x)
+
+        if self.unsqueeze:
+            wx = wx.squeeze(1)
+
+        if not self.skip_transpose:
+            wx = wx.transpose(1, -1)
+
+        return wx
+
+    def _manage_padding(
+        self, x, kernel_size: int, dilation: int, stride: int,
+    ):
+        # Detecting input shape
+        L_in = x.shape[-1]
+
+        # Time padding
+        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
+
+        # Applying padding
+        x = F.pad(x, padding, mode=self.padding_mode)
+
+        return x
+
+    def _check_input_shape(self, shape):
+        """Checks the input shape and returns the number of input channels.
+        """
+
+        if len(shape) == 2:
+            self.unsqueeze = True
+            in_channels = 1
+        elif self.skip_transpose:
+            in_channels = shape[1]
+        elif len(shape) == 3:
+            in_channels = shape[2]
+        else:
+            raise ValueError(
+                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
+            )
+
+        # Kernel size must be odd
+        if self.kernel_size % 2 == 0:
+            raise ValueError(
+                "The field kernel size must be an odd number. Got %s."
+                % (self.kernel_size)
+            )
+        return in_channels
+
+
+def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
+    if stride > 1:
+        n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
+        L_out = stride * (n_steps - 1) + kernel_size * dilation
+        padding = [kernel_size // 2, kernel_size // 2]
+
+    else:
+        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
+
+        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
+    return padding
+
+
+# Skip transpose as much as possible for efficiency
+class Conv1d(_Conv1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(skip_transpose=True, *args, **kwargs)
+
+
+class BatchNorm1d(_BatchNorm1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(skip_transpose=True, *args, **kwargs)
+
+
+def length_to_mask(length, max_len=None, dtype=None, device=None):
+    assert len(length.shape) == 1
+
+    if max_len is None:
+        max_len = length.max().long().item()  # using arange to generate mask
+    mask = torch.arange(
+        max_len, device=length.device, dtype=length.dtype
+    ).expand(len(length), max_len) < length.unsqueeze(1)
+
+    if dtype is None:
+        dtype = length.dtype
+
+    if device is None:
+        device = length.device
+
+    mask = torch.as_tensor(mask, dtype=dtype, device=device)
+    return mask
+
+
+class TDNNBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        dilation,
+        activation=nn.ReLU,
+        groups=1,
+    ):
+        super(TDNNBlock, self).__init__()
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            groups=groups,
+        )
+        self.activation = activation()
+        self.norm = BatchNorm1d(input_size=out_channels)
+
+    def forward(self, x):
+        return self.norm(self.activation(self.conv(x)))
+
+
+class Res2NetBlock(torch.nn.Module):
+    """An implementation of Res2NetBlock w/ dilation.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of channels expected in the input.
+    out_channels : int
+        The number of output channels.
+    scale : int
+        The scale of the Res2Net block.
+    kernel_size: int
+        The kernel size of the Res2Net block.
+    dilation : int
+        The dilation of the Res2Net block.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
+    ):
+        super(Res2NetBlock, self).__init__()
+        assert in_channels % scale == 0
+        assert out_channels % scale == 0
+
+        in_channel = in_channels // scale
+        hidden_channel = out_channels // scale
+
+        self.blocks = nn.ModuleList(
+            [
+                TDNNBlock(
+                    in_channel,
+                    hidden_channel,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                )
+                for i in range(scale - 1)
+            ]
+        )
+        self.scale = scale
+
+    def forward(self, x):
+        y = []
+        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
+            if i == 0:
+                y_i = x_i
+            elif i == 1:
+                y_i = self.blocks[i - 1](x_i)
+            else:
+                y_i = self.blocks[i - 1](x_i + y_i)
+            y.append(y_i)
+        y = torch.cat(y, dim=1)
+        return y
+
+
+class SEBlock(nn.Module):
+    """An implementation of squeeze-and-excitation block.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of input channels.
+    se_channels : int
+        The number of output channels after squeeze.
+    out_channels : int
+        The number of output channels.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> se_layer = SEBlock(64, 16, 64)
+    >>> lengths = torch.rand((8,))
+    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(self, in_channels, se_channels, out_channels):
+        super(SEBlock, self).__init__()
+
+        self.conv1 = Conv1d(
+            in_channels=in_channels, out_channels=se_channels, kernel_size=1
+        )
+        self.relu = torch.nn.ReLU(inplace=True)
+        self.conv2 = Conv1d(
+            in_channels=se_channels, out_channels=out_channels, kernel_size=1
+        )
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x, lengths=None):
+        L = x.shape[-1]
+        if lengths is not None:
+            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
+            mask = mask.unsqueeze(1)
+            total = mask.sum(dim=2, keepdim=True)
+            s = (x * mask).sum(dim=2, keepdim=True) / total
+        else:
+            s = x.mean(dim=2, keepdim=True)
+
+        s = self.relu(self.conv1(s))
+        s = self.sigmoid(self.conv2(s))
+
+        return s * x
+
+
+class AttentiveStatisticsPooling(nn.Module):
+    """This class implements an attentive statistic pooling layer for each channel.
+    It returns the concatenated mean and std of the input tensor.
+
+    Arguments
+    ---------
+    channels: int
+        The number of input channels.
+    attention_channels: int
+        The number of attention channels.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> asp_layer = AttentiveStatisticsPooling(64)
+    >>> lengths = torch.rand((8,))
+    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 1, 128])
+    """
+
+    def __init__(self, channels, attention_channels=128, global_context=True):
+        super().__init__()
+
+        self.eps = 1e-12
+        self.global_context = global_context
+        if global_context:
+            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
+        else:
+            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
+        self.tanh = nn.Tanh()
+        self.conv = Conv1d(
+            in_channels=attention_channels, out_channels=channels, kernel_size=1
+        )
+
+    def forward(self, x, lengths=None):
+        """Calculates mean and std for a batch (input tensor).
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape [N, C, L].
+        """
+        L = x.shape[-1]
+
+        def _compute_statistics(x, m, dim=2, eps=self.eps):
+            mean = (m * x).sum(dim)
+            std = torch.sqrt(
+                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
+            )
+            return mean, std
+
+        if lengths is None:
+            lengths = torch.ones(x.shape[0], device=x.device)
+
+        # Make binary mask of shape [N, 1, L]
+        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
+        mask = mask.unsqueeze(1)
+
+        # Expand the temporal context of the pooling layer by allowing the
+        # self-attention to look at global properties of the utterance.
+        if self.global_context:
+            # torch.std is unstable for backward computation
+            # https://github.com/pytorch/pytorch/issues/4320
+            total = mask.sum(dim=2, keepdim=True).float()
+            mean, std = _compute_statistics(x, mask / total)
+            mean = mean.unsqueeze(2).repeat(1, 1, L)
+            std = std.unsqueeze(2).repeat(1, 1, L)
+            attn = torch.cat([x, mean, std], dim=1)
+        else:
+            attn = x
+
+        # Apply layers
+        attn = self.conv(self.tanh(self.tdnn(attn)))
+
+        # Filter out zero-paddings
+        attn = attn.masked_fill(mask == 0, float("-inf"))
+
+        attn = F.softmax(attn, dim=2)
+        mean, std = _compute_statistics(x, attn)
+        # Append mean and std of the batch
+        pooled_stats = torch.cat((mean, std), dim=1)
+        pooled_stats = pooled_stats.unsqueeze(2)
+
+        return pooled_stats
+
+
+class SERes2NetBlock(nn.Module):
+    """An implementation of building block in ECAPA-TDNN, i.e.,
+    TDNN-Res2Net-TDNN-SEBlock.
+
+    Arguments
+    ----------
+    out_channels: int
+        The number of output channels.
+    res2net_scale: int
+        The scale of the Res2Net block.
+    kernel_size: int
+        The kernel size of the TDNN blocks.
+    dilation: int
+        The dilation of the Res2Net block.
+    activation : torch class
+        A class for constructing the activation layers.
+    groups: int
+    Number of blocked connections from input channels to output channels.
+
+    Example
+    -------
+    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
+    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
+    >>> out = conv(x).transpose(1, 2)
+    >>> out.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        res2net_scale=8,
+        se_channels=128,
+        kernel_size=1,
+        dilation=1,
+        activation=torch.nn.ReLU,
+        groups=1,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.tdnn1 = TDNNBlock(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            activation=activation,
+            groups=groups,
+        )
+        self.res2net_block = Res2NetBlock(
+            out_channels, out_channels, res2net_scale, kernel_size, dilation
+        )
+        self.tdnn2 = TDNNBlock(
+            out_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            activation=activation,
+            groups=groups,
+        )
+        self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+        self.shortcut = None
+        if in_channels != out_channels:
+            self.shortcut = Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+            )
+
+    def forward(self, x, lengths=None):
+        residual = x
+        if self.shortcut:
+            residual = self.shortcut(x)
+
+        x = self.tdnn1(x)
+        x = self.res2net_block(x)
+        x = self.tdnn2(x)
+        x = self.se_block(x, lengths)
+
+        return x + residual
+
+
+class ECAPA_TDNN(torch.nn.Module):
+    """An implementation of the speaker embedding model in a paper.
+    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
+    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
+
+    Arguments
+    ---------
+    activation : torch class
+        A class for constructing the activation layers.
+    channels : list of ints
+        Output channels for TDNN/SERes2Net layer.
+    kernel_sizes : list of ints
+        List of kernel sizes for each layer.
+    dilations : list of ints
+        List of dilations for kernels in each layer.
+    lin_neurons : int
+        Number of neurons in linear layers.
+    groups : list of ints
+        List of groups for kernels in each layer.
+
+    Example
+    -------
+    >>> input_feats = torch.rand([5, 120, 80])
+    >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
+    >>> outputs = compute_embedding(input_feats)
+    >>> outputs.shape
+    torch.Size([5, 1, 192])
+    """
+
+    def __init__(
+        self,
+        input_size,
+        lin_neurons=192,
+        activation=torch.nn.ReLU,
+        channels=[512, 512, 512, 512, 1536],
+        kernel_sizes=[5, 3, 3, 3, 1],
+        dilations=[1, 2, 3, 4, 1],
+        attention_channels=128,
+        res2net_scale=8,
+        se_channels=128,
+        global_context=True,
+        groups=[1, 1, 1, 1, 1],
+        window_size=20,
+        window_shift=1,
+    ):
+
+        super().__init__()
+        assert len(channels) == len(kernel_sizes)
+        assert len(channels) == len(dilations)
+        self.channels = channels
+        self.blocks = nn.ModuleList()
+        self.window_size = window_size
+        self.window_shift = window_shift
+
+        # The initial TDNN layer
+        self.blocks.append(
+            TDNNBlock(
+                input_size,
+                channels[0],
+                kernel_sizes[0],
+                dilations[0],
+                activation,
+                groups[0],
+            )
+        )
+
+        # SE-Res2Net layers
+        for i in range(1, len(channels) - 1):
+            self.blocks.append(
+                SERes2NetBlock(
+                    channels[i - 1],
+                    channels[i],
+                    res2net_scale=res2net_scale,
+                    se_channels=se_channels,
+                    kernel_size=kernel_sizes[i],
+                    dilation=dilations[i],
+                    activation=activation,
+                    groups=groups[i],
+                )
+            )
+
+        # Multi-layer feature aggregation
+        self.mfa = TDNNBlock(
+            channels[-1],
+            channels[-1],
+            kernel_sizes[-1],
+            dilations[-1],
+            activation,
+            groups=groups[-1],
+        )
+
+        # Attentive Statistical Pooling
+        self.asp = AttentiveStatisticsPooling(
+            channels[-1],
+            attention_channels=attention_channels,
+            global_context=global_context,
+        )
+        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
+
+        # Final linear transformation
+        self.fc = Conv1d(
+            in_channels=channels[-1] * 2,
+            out_channels=lin_neurons,
+            kernel_size=1,
+        )
+
+    def windowed_pooling(self, x, lengths=None):
+        # x: Batch, Channel, Time
+        tt = x.shape[2]
+        num_chunk = int(math.ceil(tt / self.window_shift))
+        pad = self.window_size // 2
+        x = F.pad(x, (pad, pad, 0, 0), "reflect")
+        stat_list = []
+
+        for i in range(num_chunk):
+            # B x C
+            st, ed = i * self.window_shift, i * self.window_shift + self.window_size
+            x = self.asp(x[:, :, st: ed],
+                         lengths=torch.clamp(lengths - i, 0, self.window_size)
+                         if lengths is not None else None)
+            x = self.asp_bn(x)
+            x = self.fc(x)
+            stat_list.append(x)
+
+        return torch.cat(stat_list, dim=2)
+
+    def forward(self, x, lengths=None):
+        """Returns the embedding vector.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape (batch, time, channel).
+        lengths: torch.Tensor
+            Tensor of shape (batch, )
+        """
+        # Minimize transpose for efficiency
+        x = x.transpose(1, 2)
+
+        xl = []
+        for layer in self.blocks:
+            try:
+                x = layer(x, lengths=lengths)
+            except TypeError:
+                x = layer(x)
+            xl.append(x)
+
+        # Multi-layer feature aggregation
+        x = torch.cat(xl[1:], dim=1)
+        x = self.mfa(x)
+
+        if self.window_size is None:
+            # Attentive Statistical Pooling
+            x = self.asp(x, lengths=lengths)
+            x = self.asp_bn(x)
+            # Final linear transformation
+            x = self.fc(x)
+            # x = x.transpose(1, 2)
+            x = x.squeeze(2)  # -> B, C
+        else:
+            x = self.windowed_pooling(x, lengths)
+            x = x.transpose(1, 2)  # -> B, T, C
+        return x
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index f3212f1..73c51e3 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -24,6 +24,7 @@
 from funasr.layers.label_aggregation import LabelAggregate
 from funasr.models.ctc import CTC
 from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
 from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
 from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
 from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
@@ -123,8 +124,9 @@
         resnet34=ResNet34Diar,
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
+        ecapa_tdnn=ECAPA_TDNN,
     ),
-    type_check=AbsEncoder,
+    type_check=torch.nn.Module,
     default="resnet34",
 )
 speaker_encoder_choices = ClassChoices(
@@ -187,6 +189,8 @@
         specaug_choices,
         # --normalize and --normalize_conf
         normalize_choices,
+        # --label_aggregator and --label_aggregator_conf
+        label_aggregator_choices,
         # --model and --model_conf
         model_choices,
         # --encoder and --encoder_conf
@@ -368,7 +372,7 @@
             cls, train: bool = True, inference: bool = False
     ) -> Tuple[str, ...]:
         if not inference:
-            retval = ("speech", "profile", "label")
+            retval = ("speech", "profile", "binary_labels")
         else:
             # Recognition mode
             retval = ("speech", "profile")

--
Gitblit v1.9.1