From a73123bcfc14370b74b17084bc124f00c48613e4 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: 星期六, 06 五月 2023 16:17:48 +0800
Subject: [PATCH] add speaker-attributed ASR task for alimeeting

---
 egs/alimeeting/sa-asr/utils/validate_data_dir.sh                      |  404 ++
 egs/alimeeting/sa-asr/asr_local.sh                                    | 1562 +++++++++
 egs/alimeeting/sa-asr/local/process_text_spk_merge.py                 |   55 
 egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py               |  243 +
 egs/alimeeting/sa-asr/utils/data/split_data.sh                        |  160 
 egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh                      |  143 
 setup.py                                                              |    7 
 egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh         |  129 
 egs/alimeeting/sa-asr/utils/combine_data.sh                           |  146 
 egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh                   |  162 
 egs/alimeeting/sa-asr/path.sh                                         |    6 
 egs/alimeeting/sa-asr/utils/validate_text.pl                          |  136 
 funasr/tasks/sa_asr.py                                                |  623 +++
 egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl                     |   38 
 funasr/losses/nll_loss.py                                             |   47 
 egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl                     |   27 
 egs/alimeeting/sa-asr/utils/split_scp.pl                              |  246 +
 egs/alimeeting/sa-asr/utils/parse_options.sh                          |   97 
 funasr/bin/sa_asr_train.py                                            |   55 
 egs/alimeeting/sa-asr/local/process_text_id.py                        |   24 
 egs/alimeeting/sa-asr/local/gen_oracle_embedding.py                   |   70 
 egs/alimeeting/sa-asr/local/compute_wer.py                            |  157 
 funasr/models/pooling/statistic_pooling.py                            |    4 
 egs/alimeeting/sa-asr/local/proce_text.py                             |   32 
 egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py              |  167 +
 egs/alimeeting/sa-asr/utils/copy_data_dir.sh                          |  145 
 egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py          |   86 
 funasr/tasks/abs_task.py                                              |   16 
 egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh             |   29 
 egs/alimeeting/sa-asr/utils/filter_scp.pl                             |   87 
 funasr/utils/postprocess_utils.py                                     |    7 
 funasr/bin/asr_inference.py                                           |   34 
 egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py            |  158 
 funasr/fileio/sound_scp.py                                            |    6 
 funasr/bin/asr_train.py                                               |   15 
 funasr/models/e2e_sa_asr.py                                           |  521 +++
 egs/alimeeting/sa-asr/asr_local_infer.sh                              |  590 +++
 funasr/bin/sa_asr_inference.py                                        |  674 ++++
 egs/alimeeting/sa-asr/run_m2met_2023_infer.sh                         |   50 
 egs/alimeeting/sa-asr/local/download_xvector_model.py                 |    6 
 egs/alimeeting/sa-asr/local/compute_cpcer.py                          |   91 
 egs/alimeeting/sa-asr/local/text_normalize.pl                         |   38 
 funasr/modules/attention.py                                           |   37 
 egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py                |   22 
 egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py       |  235 +
 egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py             |   68 
 funasr/models/decoder/decoder_layer_sa_asr.py                         |  169 +
 egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py |  127 
 egs/alimeeting/sa-asr/utils/fix_data_dir.sh                           |  215 +
 egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml                   |   88 
 egs/alimeeting/sa-asr/run_m2met_2023.sh                               |   51 
 egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh                       |  135 
 funasr/modules/beam_search/beam_search_sa_asr.py                      |  525 +++
 egs/alimeeting/sa-asr/local/text_format.pl                            |   14 
 egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml                  |   29 
 egs/alimeeting/sa-asr/utils/apply_map.pl                              |   97 
 egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml                |  116 
 egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py           |   59 
 egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh                 |  142 
 egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml                        |    6 
 funasr/models/decoder/transformer_decoder_sa_asr.py                   |  291 +
 funasr/models/frontend/default.py                                     |   11 
 egs/alimeeting/sa-asr/pyscripts/utils/print_args.py                   |   45 
 funasr/bin/asr_inference_launch.py                                    |    5 
 egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh         |  116 
 65 files changed, 9,859 insertions(+), 37 deletions(-)

diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
new file mode 100755
index 0000000..c0359eb
--- /dev/null
+++ b/egs/alimeeting/sa-asr/asr_local.sh
@@ -0,0 +1,1562 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+    local fname=${BASH_SOURCE[1]##*/}
+    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+min() {
+  local a b
+  a=$1
+  for b in "$@"; do
+      if [ "${b}" -le "${a}" ]; then
+          a="${b}"
+      fi
+  done
+  echo "${a}"
+}
+SECONDS=0
+
+# General configuration
+stage=1              # Processes starts from the specified stage.
+stop_stage=10000     # Processes is stopped at the specified stage.
+skip_data_prep=false # Skip data preparation stages.
+skip_train=false     # Skip training stages.
+skip_eval=false      # Skip decoding and evaluation stages.
+skip_upload=true     # Skip packing and uploading stages.
+ngpu=1               # The number of gpus ("0" uses cpu, otherwise use gpu).
+num_nodes=1          # The number of nodes.
+nj=16                # The number of parallel jobs.
+inference_nj=16      # The number of parallel jobs in decoding.
+gpu_inference=false  # Whether to perform gpu decoding.
+njob_infer=4
+dumpdir=dump2         # Directory to dump features.
+expdir=exp           # Directory to save experiments.
+python=python3       # Specify python to execute espnet commands.
+device=0
+
+# Data preparation related
+local_data_opts= # The options given to local/data.sh.
+
+# Speed perturbation related
+speed_perturb_factors=  # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
+
+# Feature extraction related
+feats_type=raw       # Feature type (raw or fbank_pitch).
+audio_format=flac    # Audio format: wav, flac, wav.ark, flac.ark  (only in feats_type=raw).
+fs=16000             # Sampling rate.
+min_wav_duration=0.1 # Minimum duration in second.
+max_wav_duration=20  # Maximum duration in second.
+
+# Tokenization related
+token_type=bpe      # Tokenization type (char or bpe).
+nbpe=30             # The number of BPE vocabulary.
+bpemode=unigram     # Mode of BPE (unigram or bpe).
+oov="<unk>"         # Out of vocabulary symbol.
+blank="<blank>"     # CTC blank symbol
+sos_eos="<sos/eos>" # sos and eos symbole
+bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
+bpe_nlsyms=         # non-linguistic symbols list, separated by a comma, for BPE
+bpe_char_cover=1.0  # character coverage when modeling BPE
+
+# Language model related
+use_lm=true       # Use language model for ASR decoding.
+lm_tag=           # Suffix to the result dir for language model training.
+lm_exp=           # Specify the direcotry path for LM experiment.
+                  # If this option is specified, lm_tag is ignored.
+lm_stats_dir=     # Specify the direcotry path for LM statistics.
+lm_config=        # Config for language model training.
+lm_args=          # Arguments for language model training, e.g., "--max_epoch 10".
+                  # Note that it will overwrite args in lm config.
+use_word_lm=false # Whether to use word language model.
+num_splits_lm=1   # Number of splitting for lm corpus.
+# shellcheck disable=SC2034
+word_vocab_size=10000 # Size of word vocabulary.
+
+# ASR model related
+asr_tag=       # Suffix to the result dir for asr model training.
+asr_exp=       # Specify the direcotry path for ASR experiment.
+               # If this option is specified, asr_tag is ignored.
+sa_asr_exp=
+asr_stats_dir= # Specify the direcotry path for ASR statistics.
+asr_config=    # Config for asr model training.
+sa_asr_config=
+asr_args=      # Arguments for asr model training, e.g., "--max_epoch 10".
+               # Note that it will overwrite args in asr config.
+feats_normalize=global_mvn # Normalizaton layer type.
+num_splits_asr=1           # Number of splitting for lm corpus.
+
+# Decoding related
+inference_tag=    # Suffix to the result dir for decoding.
+inference_config= # Config for decoding.
+inference_args=   # Arguments for decoding, e.g., "--lm_weight 0.1".
+                  # Note that it will overwrite args in inference config.
+sa_asr_inference_tag=
+sa_asr_inference_args=
+
+inference_lm=valid.loss.ave.pb        # Language modle path for decoding.
+inference_asr_model=valid.acc.ave.pb  # ASR model path for decoding.
+                                      # e.g.
+                                      # inference_asr_model=train.loss.best.pth
+                                      # inference_asr_model=3epoch.pth
+                                      # inference_asr_model=valid.acc.best.pth
+                                      # inference_asr_model=valid.loss.ave.pth
+inference_sa_asr_model=valid.acc_spk.ave.pb
+download_model= # Download a model from Model Zoo and use it for decoding.
+
+# [Task dependent] Set the datadir name created by local/data.sh
+train_set=       # Name of training set.
+valid_set=       # Name of validation set used for monitoring/tuning network training.
+test_sets=       # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
+bpe_train_text=  # Text file path of bpe training set.
+lm_train_text=   # Text file path of language model training set.
+lm_dev_text=     # Text file path of language model development set.
+lm_test_text=    # Text file path of language model evaluation set.
+nlsyms_txt=none  # Non-linguistic symbol list if existing.
+cleaner=none     # Text cleaner.
+g2p=none         # g2p method (needed if token_type=phn).
+lang=zh      # The language type of corpus.
+score_opts=                # The options given to sclite scoring
+local_score_opts=          # The options given to local/score.sh.
+
+
+help_message=$(cat << EOF
+Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
+
+Options:
+    # General configuration
+    --stage          # Processes starts from the specified stage (default="${stage}").
+    --stop_stage     # Processes is stopped at the specified stage (default="${stop_stage}").
+    --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
+    --skip_train     # Skip training stages (default="${skip_train}").
+    --skip_eval      # Skip decoding and evaluation stages (default="${skip_eval}").
+    --skip_upload    # Skip packing and uploading stages (default="${skip_upload}").
+    --ngpu           # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
+    --num_nodes      # The number of nodes (default="${num_nodes}").
+    --nj             # The number of parallel jobs (default="${nj}").
+    --inference_nj   # The number of parallel jobs in decoding (default="${inference_nj}").
+    --gpu_inference  # Whether to perform gpu decoding (default="${gpu_inference}").
+    --dumpdir        # Directory to dump features (default="${dumpdir}").
+    --expdir         # Directory to save experiments (default="${expdir}").
+    --python         # Specify python to execute espnet commands (default="${python}").
+    --device         # Which GPUs are use for local training (defalut="${device}").
+
+    # Data preparation related
+    --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
+
+    # Speed perturbation related
+    --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
+
+    # Feature extraction related
+    --feats_type       # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
+    --audio_format     # Audio format: wav, flac, wav.ark, flac.ark  (only in feats_type=raw, default="${audio_format}").
+    --fs               # Sampling rate (default="${fs}").
+    --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
+    --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
+
+    # Tokenization related
+    --token_type              # Tokenization type (char or bpe, default="${token_type}").
+    --nbpe                    # The number of BPE vocabulary (default="${nbpe}").
+    --bpemode                 # Mode of BPE (unigram or bpe, default="${bpemode}").
+    --oov                     # Out of vocabulary symbol (default="${oov}").
+    --blank                   # CTC blank symbol (default="${blank}").
+    --sos_eos                 # sos and eos symbole (default="${sos_eos}").
+    --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
+    --bpe_nlsyms              # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
+    --bpe_char_cover          # Character coverage when modeling BPE (default="${bpe_char_cover}").
+
+    # Language model related
+    --lm_tag          # Suffix to the result dir for language model training (default="${lm_tag}").
+    --lm_exp          # Specify the direcotry path for LM experiment.
+                      # If this option is specified, lm_tag is ignored (default="${lm_exp}").
+    --lm_stats_dir    # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
+    --lm_config       # Config for language model training (default="${lm_config}").
+    --lm_args         # Arguments for language model training (default="${lm_args}").
+                      # e.g., --lm_args "--max_epoch 10"
+                      # Note that it will overwrite args in lm config.
+    --use_word_lm     # Whether to use word language model (default="${use_word_lm}").
+    --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
+    --num_splits_lm   # Number of splitting for lm corpus (default="${num_splits_lm}").
+
+    # ASR model related
+    --asr_tag          # Suffix to the result dir for asr model training (default="${asr_tag}").
+    --asr_exp          # Specify the direcotry path for ASR experiment.
+                       # If this option is specified, asr_tag is ignored (default="${asr_exp}").
+    --asr_stats_dir    # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
+    --asr_config       # Config for asr model training (default="${asr_config}").
+    --asr_args         # Arguments for asr model training (default="${asr_args}").
+                       # e.g., --asr_args "--max_epoch 10"
+                       # Note that it will overwrite args in asr config.
+    --feats_normalize  # Normalizaton layer type (default="${feats_normalize}").
+    --num_splits_asr   # Number of splitting for lm corpus  (default="${num_splits_asr}").
+
+    # Decoding related
+    --inference_tag       # Suffix to the result dir for decoding (default="${inference_tag}").
+    --inference_config    # Config for decoding (default="${inference_config}").
+    --inference_args      # Arguments for decoding (default="${inference_args}").
+                          # e.g., --inference_args "--lm_weight 0.1"
+                          # Note that it will overwrite args in inference config.
+    --inference_lm        # Language modle path for decoding (default="${inference_lm}").
+    --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
+    --download_model      # Download a model from Model Zoo and use it for decoding (default="${download_model}").
+
+    # [Task dependent] Set the datadir name created by local/data.sh
+    --train_set     # Name of training set (required).
+    --valid_set     # Name of validation set used for monitoring/tuning network training (required).
+    --test_sets     # Names of test sets.
+                    # Multiple items (e.g., both dev and eval sets) can be specified (required).
+    --bpe_train_text # Text file path of bpe training set.
+    --lm_train_text  # Text file path of language model training set.
+    --lm_dev_text   # Text file path of language model development set (default="${lm_dev_text}").
+    --lm_test_text  # Text file path of language model evaluation set (default="${lm_test_text}").
+    --nlsyms_txt    # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
+    --cleaner       # Text cleaner (default="${cleaner}").
+    --g2p           # g2p method (default="${g2p}").
+    --lang          # The language type of corpus (default=${lang}).
+    --score_opts             # The options given to sclite scoring (default="{score_opts}").
+    --local_score_opts       # The options given to local/score.sh (default="{local_score_opts}").
+EOF
+)
+
+log "$0 $*"
+# Save command line args for logging (they will be lost after utils/parse_options.sh)
+run_args=$(python -m funasr.utils.cli_utils $0 "$@")
+. utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+    log "${help_message}"
+    log "Error: No positional arguments are required."
+    exit 2
+fi
+
+. ./path.sh
+
+
+# Check required arguments
+[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
+[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
+[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
+
+# Check feature type
+if [ "${feats_type}" = raw ]; then
+    data_feats=${dumpdir}/raw
+elif [ "${feats_type}" = fbank_pitch ]; then
+    data_feats=${dumpdir}/fbank_pitch
+elif [ "${feats_type}" = fbank ]; then
+    data_feats=${dumpdir}/fbank
+elif [ "${feats_type}" == extracted ]; then
+    data_feats=${dumpdir}/extracted
+else
+    log "${help_message}"
+    log "Error: not supported: --feats_type ${feats_type}"
+    exit 2
+fi
+
+# Use the same text as ASR for bpe training if not specified.
+[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
+# Use the text of the 1st evaldir if lm_test is not specified
+[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
+
+# Check tokenization type
+if [ "${lang}" != noinfo ]; then
+    token_listdir=data/${lang}_token_list
+else
+    token_listdir=data/token_list
+fi
+bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
+bpeprefix="${bpedir}"/bpe
+bpemodel="${bpeprefix}".model
+bpetoken_list="${bpedir}"/tokens.txt
+chartoken_list="${token_listdir}"/char/tokens.txt
+# NOTE: keep for future development.
+# shellcheck disable=SC2034
+wordtoken_list="${token_listdir}"/word/tokens.txt
+
+if [ "${token_type}" = bpe ]; then
+    token_list="${bpetoken_list}"
+elif [ "${token_type}" = char ]; then
+    token_list="${chartoken_list}"
+    bpemodel=none
+elif [ "${token_type}" = word ]; then
+    token_list="${wordtoken_list}"
+    bpemodel=none
+else
+    log "Error: not supported --token_type '${token_type}'"
+    exit 2
+fi
+if ${use_word_lm}; then
+    log "Error: Word LM is not supported yet"
+    exit 2
+
+    lm_token_list="${wordtoken_list}"
+    lm_token_type=word
+else
+    lm_token_list="${token_list}"
+    lm_token_type="${token_type}"
+fi
+
+
+# Set tag for naming of model directory
+if [ -z "${asr_tag}" ]; then
+    if [ -n "${asr_config}" ]; then
+        asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
+    else
+        asr_tag="train_${feats_type}"
+    fi
+    if [ "${lang}" != noinfo ]; then
+        asr_tag+="_${lang}_${token_type}"
+    else
+        asr_tag+="_${token_type}"
+    fi
+    if [ "${token_type}" = bpe ]; then
+        asr_tag+="${nbpe}"
+    fi
+    # Add overwritten arg's info
+    if [ -n "${asr_args}" ]; then
+        asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+    fi
+    if [ -n "${speed_perturb_factors}" ]; then
+        asr_tag+="_sp"
+    fi
+fi
+if [ -z "${lm_tag}" ]; then
+    if [ -n "${lm_config}" ]; then
+        lm_tag="$(basename "${lm_config}" .yaml)"
+    else
+        lm_tag="train"
+    fi
+    if [ "${lang}" != noinfo ]; then
+        lm_tag+="_${lang}_${lm_token_type}"
+    else
+        lm_tag+="_${lm_token_type}"
+    fi
+    if [ "${lm_token_type}" = bpe ]; then
+        lm_tag+="${nbpe}"
+    fi
+    # Add overwritten arg's info
+    if [ -n "${lm_args}" ]; then
+        lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+    fi
+fi
+
+# The directory used for collect-stats mode
+if [ -z "${asr_stats_dir}" ]; then
+    if [ "${lang}" != noinfo ]; then
+        asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
+    else
+        asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
+    fi
+    if [ "${token_type}" = bpe ]; then
+        asr_stats_dir+="${nbpe}"
+    fi
+    if [ -n "${speed_perturb_factors}" ]; then
+        asr_stats_dir+="_sp"
+    fi
+fi
+if [ -z "${lm_stats_dir}" ]; then
+    if [ "${lang}" != noinfo ]; then
+        lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
+    else
+        lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
+    fi
+    if [ "${lm_token_type}" = bpe ]; then
+        lm_stats_dir+="${nbpe}"
+    fi
+fi
+# The directory used for training commands
+if [ -z "${asr_exp}" ]; then
+    asr_exp="${expdir}/asr_${asr_tag}"
+fi
+if [ -z "${lm_exp}" ]; then
+    lm_exp="${expdir}/lm_${lm_tag}"
+fi
+
+
+if [ -z "${inference_tag}" ]; then
+    if [ -n "${inference_config}" ]; then
+        inference_tag="$(basename "${inference_config}" .yaml)"
+    else
+        inference_tag=inference
+    fi
+    # Add overwritten arg's info
+    if [ -n "${inference_args}" ]; then
+        inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+    fi
+    if "${use_lm}"; then
+        inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+    fi
+    inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+if [ -z "${sa_asr_inference_tag}" ]; then
+    if [ -n "${inference_config}" ]; then
+        sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
+    else
+        sa_asr_inference_tag=sa_asr_inference
+    fi
+    # Add overwritten arg's info
+    if [ -n "${sa_asr_inference_args}" ]; then
+        sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+    fi
+    if "${use_lm}"; then
+        sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+    fi
+    sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+train_cmd="run.pl"
+cuda_cmd="run.pl"
+decode_cmd="run.pl"
+
+# ========================== Main stages start from here. ==========================
+
+if ! "${skip_data_prep}"; then
+
+    if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+        log "Stage 1: Data preparation for data/${train_set}, data/${valid_set}, etc."
+
+        ./local/alimeeting_data_prep.sh --tgt Test
+        ./local/alimeeting_data_prep.sh --tgt Eval
+        ./local/alimeeting_data_prep.sh --tgt Train
+    fi
+
+    if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+        if [ -n "${speed_perturb_factors}" ]; then
+           log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
+           for factor in ${speed_perturb_factors}; do
+               if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
+                   scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
+                   _dirs+="data/${train_set}_sp${factor} "
+               else
+                   # If speed factor is 1, same as the original
+                   _dirs+="data/${train_set} "
+               fi
+           done
+           utils/combine_data.sh "data/${train_set}_sp" ${_dirs}
+        else
+           log "Skip stage 2: Speed perturbation"
+        fi
+    fi
+
+    if [ -n "${speed_perturb_factors}" ]; then
+        train_set="${train_set}_sp"
+    fi
+
+    if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+        if [ "${feats_type}" = raw ]; then
+            log "Stage 3: Format wav.scp: data/ -> ${data_feats}"
+
+            # ====== Recreating "wav.scp" ======
+            # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
+            # shouldn't be used in training process.
+            # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
+            # and it can also change the audio-format and sampling rate.
+            # If nothing is need, then format_wav_scp.sh does nothing:
+            # i.e. the input file format and rate is same as the output.
+
+            for dset in "${train_set}" "${valid_set}" "${test_sets}" ; do
+                if [ "${dset}" = "${train_set}" ] || [ "${dset}" = "${valid_set}" ]; then
+                    _suf="/org"
+                else
+                    if [ "${dset}" = "${test_sets}" ] && [ "${test_sets}" = "Test_Ali_far" ]; then
+                        _suf="/org"
+                    else
+                        _suf=""
+                    fi
+                fi
+                utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+                
+                cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+
+                rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
+                _opts=
+                if [ -e data/"${dset}"/segments ]; then
+                    # "segments" is used for splitting wav files which are written in "wav".scp
+                    # into utterances. The file format of segments:
+                    #   <segment_id> <record_id> <start_time> <end_time>
+                    #   "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
+                    # Where the time is written in seconds.
+                    _opts+="--segments data/${dset}/segments "
+                fi
+                # shellcheck disable=SC2086
+                scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+                    --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
+                    "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
+
+                echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
+            done
+
+        else
+            log "Error: not supported: --feats_type ${feats_type}"
+            exit 2
+        fi
+    fi
+
+
+    if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+        log "Stage 4: Remove long/short data: ${data_feats}/org -> ${data_feats}"
+
+        # NOTE(kamo): Not applying to test_sets to keep original data
+        if [ "${test_sets}" = "Test_Ali_far" ]; then
+            rm_dset="${train_set} ${valid_set} ${test_sets}"
+        else
+            rm_dset="${train_set} ${valid_set}"
+        fi
+
+        for dset in $rm_dset; do
+
+            # Copy data dir
+            utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
+            cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
+
+            # Remove short utterances
+            _feats_type="$(<${data_feats}/${dset}/feats_type)"
+            if [ "${_feats_type}" = raw ]; then
+                _fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))")
+                _min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))")
+                _max_length=$(python3 -c "print(int(${max_wav_duration} * ${_fs}))")
+
+                # utt2num_samples is created by format_wav_scp.sh
+                <"${data_feats}/org/${dset}/utt2num_samples" \
+                    awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
+                        '{ if ($2 > min_length && $2 < max_length ) print $0; }' \
+                        >"${data_feats}/${dset}/utt2num_samples"
+                <"${data_feats}/org/${dset}/wav.scp" \
+                    utils/filter_scp.pl "${data_feats}/${dset}/utt2num_samples"  \
+                    >"${data_feats}/${dset}/wav.scp"
+            else
+                # Get frame shift in ms from conf/fbank.conf
+                _frame_shift=
+                if [ -f conf/fbank.conf ] && [ "$(<conf/fbank.conf grep -c frame-shift)" -gt 0 ]; then
+                    # Assume using conf/fbank.conf for feature extraction
+                    _frame_shift="$(<conf/fbank.conf grep frame-shift | sed -e 's/[-a-z =]*\([0-9]*\)/\1/g')"
+                fi
+                if [ -z "${_frame_shift}" ]; then
+                    # If not existing, use the default number in Kaldi (=10ms).
+                    # If you are using different number, you have to change the following value manually.
+                    _frame_shift=10
+                fi
+
+                _min_length=$(python3 -c "print(int(${min_wav_duration} / ${_frame_shift} * 1000))")
+                _max_length=$(python3 -c "print(int(${max_wav_duration} / ${_frame_shift} * 1000))")
+
+                cp "${data_feats}/org/${dset}/feats_dim" "${data_feats}/${dset}/feats_dim"
+                <"${data_feats}/org/${dset}/feats_shape" awk -F, ' { print $1 } ' \
+                    | awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
+                        '{ if ($2 > min_length && $2 < max_length) print $0; }' \
+                        >"${data_feats}/${dset}/feats_shape"
+                <"${data_feats}/org/${dset}/feats.scp" \
+                    utils/filter_scp.pl "${data_feats}/${dset}/feats_shape"  \
+                    >"${data_feats}/${dset}/feats.scp"
+            fi
+
+            # Remove empty text
+            <"${data_feats}/org/${dset}/text" \
+                awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
+
+            # fix_data_dir.sh leaves only utts which exist in all files
+            utils/fix_data_dir.sh "${data_feats}/${dset}"
+
+            # generate uttid
+            cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
+            # filter utt2spk_all_fifo
+            python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+        done
+
+        # shellcheck disable=SC2002
+        cat ${lm_train_text} | awk ' { if( NF != 1 ) print $0; } ' > "${data_feats}/lm_train.txt"
+    fi
+
+
+    if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+        log "Stage 5: Dictionary Preparation"
+        mkdir -p data/${lang}_token_list/char/
+    
+        echo "make a dictionary"
+        echo "<blank>" > ${token_list}
+        echo "<s>" >> ${token_list}
+        echo "</s>" >> ${token_list}
+        local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
+            | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+        num_token=$(cat ${token_list} | wc -l)
+        echo "<unk>" >> ${token_list}
+        vocab_size=$(cat ${token_list} | wc -l)
+    fi
+
+    if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+        log "Stage 6: Generate speaker settings"
+        mkdir -p "profile_log"
+        for dset in "${train_set}" "${valid_set}" "${test_sets}"; do
+            # generate text_id spk2id
+            python local/process_sot_fifo_textchar2spk.py --path ${data_feats}/${dset}
+            log "Successfully generate ${data_feats}/${dset}/text_id ${data_feats}/${dset}/spk2id"
+            # generate text_id_train for sot
+            python local/process_text_id.py ${data_feats}/${dset}
+            log "Successfully generate ${data_feats}/${dset}/text_id_train"
+            # generate oracle_embedding from single-speaker audio segment
+            python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
+            log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
+            # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
+            if [ "${dset}" = "${train_set}" ]; then
+                python local/gen_oracle_profile_padding.py ${data_feats}/${dset}
+                log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_padding.scp)"
+            else
+                python local/gen_oracle_profile_nopadding.py ${data_feats}/${dset}
+                log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_nopadding.scp)"
+            fi
+            # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+            if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
+                python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+                log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
+            fi
+
+            done
+    fi
+
+else
+    log "Skip the stages for data preparation"
+fi
+
+
+# ========================== Data preparation is done here. ==========================
+
+
+if ! "${skip_train}"; then
+    if "${use_lm}"; then
+        if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+            log "Stage 7: LM collect stats: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
+
+            _opts=
+            if [ -n "${lm_config}" ]; then
+                # To generate the config file: e.g.
+                #   % python3 -m espnet2.bin.lm_train --print_config --optim adam
+                _opts+="--config ${lm_config} "
+            fi
+
+            # 1. Split the key file
+            _logdir="${lm_stats_dir}/logdir"
+            mkdir -p "${_logdir}"
+            # Get the minimum number among ${nj} and the number lines of input files
+            _nj=$(min "${nj}" "$(<${data_feats}/lm_train.txt wc -l)" "$(<${lm_dev_text} wc -l)")
+
+            key_file="${data_feats}/lm_train.txt"
+            split_scps=""
+            for n in $(seq ${_nj}); do
+                split_scps+=" ${_logdir}/train.${n}.scp"
+            done
+            # shellcheck disable=SC2086
+            utils/split_scp.pl "${key_file}" ${split_scps}
+
+            key_file="${lm_dev_text}"
+            split_scps=""
+            for n in $(seq ${_nj}); do
+                split_scps+=" ${_logdir}/dev.${n}.scp"
+            done
+            # shellcheck disable=SC2086
+            utils/split_scp.pl "${key_file}" ${split_scps}
+
+            # 2. Generate run.sh
+            log "Generate '${lm_stats_dir}/run.sh'. You can resume the process from stage 6 using this script"
+            mkdir -p "${lm_stats_dir}"; echo "${run_args} --stage 6 \"\$@\"; exit \$?" > "${lm_stats_dir}/run.sh"; chmod +x "${lm_stats_dir}/run.sh"
+
+            # 3. Submit jobs
+            log "LM collect-stats started... log: '${_logdir}/stats.*.log'"
+            # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
+            #       but it's used only for deciding the sample ids.
+            # shellcheck disable=SC2086
+            ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
+                ${python} -m funasr.bin.lm_train \
+                    --collect_stats true \
+                    --use_preprocessor true \
+                    --bpemodel "${bpemodel}" \
+                    --token_type "${lm_token_type}"\
+                    --token_list "${lm_token_list}" \
+                    --non_linguistic_symbols "${nlsyms_txt}" \
+                    --cleaner "${cleaner}" \
+                    --g2p "${g2p}" \
+                    --train_data_path_and_name_and_type "${data_feats}/lm_train.txt,text,text" \
+                    --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
+                    --train_shape_file "${_logdir}/train.JOB.scp" \
+                    --valid_shape_file "${_logdir}/dev.JOB.scp" \
+                    --output_dir "${_logdir}/stats.JOB" \
+                    ${_opts} ${lm_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
+
+            # 4. Aggregate shape files
+            _opts=
+            for i in $(seq "${_nj}"); do
+                _opts+="--input_dir ${_logdir}/stats.${i} "
+            done
+            # shellcheck disable=SC2086
+            ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${lm_stats_dir}"
+
+            # Append the num-tokens at the last dimensions. This is used for batch-bins count
+            <"${lm_stats_dir}/train/text_shape" \
+                awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
+                >"${lm_stats_dir}/train/text_shape.${lm_token_type}"
+
+            <"${lm_stats_dir}/valid/text_shape" \
+                awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
+                >"${lm_stats_dir}/valid/text_shape.${lm_token_type}"
+        fi
+
+
+        if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+            log "Stage 8: LM Training: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
+
+            _opts=
+            if [ -n "${lm_config}" ]; then
+                # To generate the config file: e.g.
+                #   % python3 -m espnet2.bin.lm_train --print_config --optim adam
+                _opts+="--config ${lm_config} "
+            fi
+
+            if [ "${num_splits_lm}" -gt 1 ]; then
+                # If you met a memory error when parsing text files, this option may help you.
+                # The corpus is split into subsets and each subset is used for training one by one in order,
+                # so the memory footprint can be limited to the memory required for each dataset.
+
+                _split_dir="${lm_stats_dir}/splits${num_splits_lm}"
+                if [ ! -f "${_split_dir}/.done" ]; then
+                    rm -f "${_split_dir}/.done"
+                    ${python} -m espnet2.bin.split_scps \
+                      --scps "${data_feats}/lm_train.txt" "${lm_stats_dir}/train/text_shape.${lm_token_type}" \
+                      --num_splits "${num_splits_lm}" \
+                      --output_dir "${_split_dir}"
+                    touch "${_split_dir}/.done"
+                else
+                    log "${_split_dir}/.done exists. Spliting is skipped"
+                fi
+
+                _opts+="--train_data_path_and_name_and_type ${_split_dir}/lm_train.txt,text,text "
+                _opts+="--train_shape_file ${_split_dir}/text_shape.${lm_token_type} "
+                _opts+="--multiple_iterator true "
+
+            else
+                _opts+="--train_data_path_and_name_and_type ${data_feats}/lm_train.txt,text,text "
+                _opts+="--train_shape_file ${lm_stats_dir}/train/text_shape.${lm_token_type} "
+            fi
+
+            # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+
+            log "Generate '${lm_exp}/run.sh'. You can resume the process from stage 8 using this script"
+            mkdir -p "${lm_exp}"; echo "${run_args} --stage 8 \"\$@\"; exit \$?" > "${lm_exp}/run.sh"; chmod +x "${lm_exp}/run.sh"
+
+            log "LM training started... log: '${lm_exp}/train.log'"
+            if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+                # SGE can't include "/" in a job name
+                jobname="$(basename ${lm_exp})"
+            else
+                jobname="${lm_exp}/train.log"
+            fi
+
+            mkdir -p ${lm_exp}
+            mkdir -p ${lm_exp}/log
+            INIT_FILE=${lm_exp}/ddp_init
+            if [ -f $INIT_FILE ];then
+                rm -f $INIT_FILE
+            fi 
+            init_method=file://$(readlink -f $INIT_FILE)
+            echo "$0: init method is $init_method"
+            for ((i = 0; i < $ngpu; ++i)); do
+                {
+                    # i=0
+                    rank=$i
+                    local_rank=$i
+                    gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+                    lm_train.py \
+                        --gpu_id $gpu_id \
+                        --use_preprocessor true \
+                        --bpemodel ${bpemodel} \
+                        --token_type ${token_type} \
+                        --token_list ${token_list} \
+                        --non_linguistic_symbols ${nlsyms_txt} \
+                        --cleaner ${cleaner} \
+                        --g2p ${g2p} \
+                        --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
+                        --valid_shape_file "${lm_stats_dir}/valid/text_shape.${lm_token_type}" \
+                        --resume true \
+                        --output_dir ${lm_exp} \
+                        --config $lm_config \
+                        --ngpu $ngpu \
+                        --num_worker_count 1 \
+                        --multiprocessing_distributed true \
+                        --dist_init_method $init_method \
+                        --dist_world_size $ngpu \
+                        --dist_rank $rank \
+                        --local_rank $local_rank \
+                        ${_opts} 1> ${lm_exp}/log/train.log.$i 2>&1
+                } &
+                done
+                wait
+
+        fi
+
+
+        if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
+            log "Stage 9: Calc perplexity: ${lm_test_text}"
+            _opts=
+            # TODO(kamo): Parallelize?
+            log "Perplexity calculation started... log: '${lm_exp}/perplexity_test/lm_calc_perplexity.log'"
+            # shellcheck disable=SC2086
+            CUDA_VISIBLE_DEVICES=${device}\
+            ${cuda_cmd} --gpu "${ngpu}" "${lm_exp}"/perplexity_test/lm_calc_perplexity.log \
+                ${python} -m funasr.bin.lm_calc_perplexity \
+                    --ngpu "${ngpu}" \
+                    --data_path_and_name_and_type "${lm_test_text},text,text" \
+                    --train_config "${lm_exp}"/config.yaml \
+                    --model_file "${lm_exp}/${inference_lm}" \
+                    --output_dir "${lm_exp}/perplexity_test" \
+                    ${_opts}
+            log "PPL: ${lm_test_text}: $(cat ${lm_exp}/perplexity_test/ppl)"
+
+        fi
+
+    else
+        log "Stage 7-9: Skip lm-related stages: use_lm=${use_lm}"
+    fi
+
+
+    if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
+        _asr_train_dir="${data_feats}/${train_set}"
+        _asr_valid_dir="${data_feats}/${valid_set}"
+        log "Stage 10: ASR collect stats: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+        _opts=
+        if [ -n "${asr_config}" ]; then
+            # To generate the config file: e.g.
+            #   % python3 -m espnet2.bin.asr_train --print_config --optim adam
+            _opts+="--config ${asr_config} "
+        fi
+
+        _feats_type="$(<${_asr_train_dir}/feats_type)"
+        if [ "${_feats_type}" = raw ]; then
+            _scp=wav.scp
+            if [[ "${audio_format}" == *ark* ]]; then
+                _type=kaldi_ark
+            else
+                # "sound" supports "wav", "flac", etc.
+                _type=sound
+            fi
+            _opts+="--frontend_conf fs=${fs} "
+        else
+            _scp=feats.scp
+            _type=kaldi_ark
+            _input_size="$(<${_asr_train_dir}/feats_dim)"
+            _opts+="--input_size=${_input_size} "
+        fi
+
+        # 1. Split the key file
+        _logdir="${asr_stats_dir}/logdir"
+        mkdir -p "${_logdir}"
+
+        # Get the minimum number among ${nj} and the number lines of input files
+        _nj=$(min "${nj}" "$(<${_asr_train_dir}/${_scp} wc -l)" "$(<${_asr_valid_dir}/${_scp} wc -l)")
+
+        key_file="${_asr_train_dir}/${_scp}"
+        split_scps=""
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/train.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+
+        key_file="${_asr_valid_dir}/${_scp}"
+        split_scps=""
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/valid.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+
+        # 2. Generate run.sh
+        log "Generate '${asr_stats_dir}/run.sh'. You can resume the process from stage 9 using this script"
+        mkdir -p "${asr_stats_dir}"; echo "${run_args} --stage 9 \"\$@\"; exit \$?" > "${asr_stats_dir}/run.sh"; chmod +x "${asr_stats_dir}/run.sh"
+
+        # 3. Submit jobs
+        log "ASR collect-stats started... log: '${_logdir}/stats.*.log'"
+
+        # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
+        #       but it's used only for deciding the sample ids.
+
+        # shellcheck disable=SC2086
+        ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
+            ${python} -m funasr.bin.asr_train \
+                --collect_stats true \
+                --mc true   \
+                --use_preprocessor true \
+                --bpemodel "${bpemodel}" \
+                --token_type "${token_type}" \
+                --token_list "${token_list}" \
+                --split_with_space false    \
+                --non_linguistic_symbols "${nlsyms_txt}" \
+                --cleaner "${cleaner}" \
+                --g2p "${g2p}" \
+                --train_data_path_and_name_and_type "${_asr_train_dir}/${_scp},speech,${_type}" \
+                --train_data_path_and_name_and_type "${_asr_train_dir}/text,text,text" \
+                --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
+                --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
+                --train_shape_file "${_logdir}/train.JOB.scp" \
+                --valid_shape_file "${_logdir}/valid.JOB.scp" \
+                --output_dir "${_logdir}/stats.JOB" \
+                ${_opts} ${asr_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
+
+        # 4. Aggregate shape files
+        _opts=
+        for i in $(seq "${_nj}"); do
+            _opts+="--input_dir ${_logdir}/stats.${i} "
+        done
+        # shellcheck disable=SC2086
+        ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${asr_stats_dir}"
+
+        # Append the num-tokens at the last dimensions. This is used for batch-bins count
+        <"${asr_stats_dir}/train/text_shape" \
+            awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
+            >"${asr_stats_dir}/train/text_shape.${token_type}"
+
+        <"${asr_stats_dir}/valid/text_shape" \
+            awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
+            >"${asr_stats_dir}/valid/text_shape.${token_type}"
+    fi
+
+
+    if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
+        _asr_train_dir="${data_feats}/${train_set}"
+        _asr_valid_dir="${data_feats}/${valid_set}"
+        log "Stage 11: ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+        _opts=
+        if [ -n "${asr_config}" ]; then
+            # To generate the config file: e.g.
+            #   % python3 -m espnet2.bin.asr_train --print_config --optim adam
+            _opts+="--config ${asr_config} "
+        fi
+
+        _feats_type="$(<${_asr_train_dir}/feats_type)"
+        if [ "${_feats_type}" = raw ]; then
+            _scp=wav.scp
+            # "sound" supports "wav", "flac", etc.
+            if [[ "${audio_format}" == *ark* ]]; then
+                _type=kaldi_ark
+            else
+                _type=sound
+            fi
+            _opts+="--frontend_conf fs=${fs} "
+        else
+            _scp=feats.scp
+            _type=kaldi_ark
+            _input_size="$(<${_asr_train_dir}/feats_dim)"
+            _opts+="--input_size=${_input_size} "
+
+        fi
+        if [ "${feats_normalize}" = global_mvn ]; then
+            # Default normalization is utterance_mvn and changes to global_mvn
+            _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
+        fi
+
+        if [ "${num_splits_asr}" -gt 1 ]; then
+            # If you met a memory error when parsing text files, this option may help you.
+            # The corpus is split into subsets and each subset is used for training one by one in order,
+            # so the memory footprint can be limited to the memory required for each dataset.
+
+            _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
+            if [ ! -f "${_split_dir}/.done" ]; then
+                rm -f "${_split_dir}/.done"
+                ${python} -m espnet2.bin.split_scps \
+                  --scps \
+                      "${_asr_train_dir}/${_scp}" \
+                      "${_asr_train_dir}/text" \
+                      "${asr_stats_dir}/train/speech_shape" \
+                      "${asr_stats_dir}/train/text_shape.${token_type}" \
+                  --num_splits "${num_splits_asr}" \
+                  --output_dir "${_split_dir}"
+                touch "${_split_dir}/.done"
+            else
+                log "${_split_dir}/.done exists. Spliting is skipped"
+            fi
+
+            _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
+            _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
+            _opts+="--train_shape_file ${_split_dir}/speech_shape "
+            _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
+            _opts+="--multiple_iterator true "
+
+        else
+            _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
+            _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
+            _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
+            _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
+        fi
+
+        # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
+        # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
+
+        # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+        log "ASR training started... log: '${asr_exp}/log/train.log'"
+        # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+        #     # SGE can't include "/" in a job name
+        #     jobname="$(basename ${asr_exp})"
+        # else
+        #     jobname="${asr_exp}/train.log"
+        # fi
+
+        mkdir -p ${asr_exp}
+        mkdir -p ${asr_exp}/log
+        INIT_FILE=${asr_exp}/ddp_init
+        if [ -f $INIT_FILE ];then
+            rm -f $INIT_FILE
+        fi 
+        init_method=file://$(readlink -f $INIT_FILE)
+        echo "$0: init method is $init_method"
+        for ((i = 0; i < $ngpu; ++i)); do
+            {
+                # i=0
+                rank=$i
+                local_rank=$i
+                gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+                asr_train.py \
+                    --mc true   \
+                    --gpu_id $gpu_id \
+                    --use_preprocessor true \
+                    --bpemodel ${bpemodel} \
+                    --token_type ${token_type} \
+                    --token_list ${token_list} \
+                    --split_with_space false    \
+                    --non_linguistic_symbols ${nlsyms_txt} \
+                    --cleaner ${cleaner} \
+                    --g2p ${g2p} \
+                    --valid_data_path_and_name_and_type ${_asr_valid_dir}/${_scp},speech,${_type} \
+                    --valid_data_path_and_name_and_type ${_asr_valid_dir}/text,text,text \
+                    --valid_shape_file ${asr_stats_dir}/valid/speech_shape \
+                    --valid_shape_file ${asr_stats_dir}/valid/text_shape.${token_type} \
+                    --resume true \
+                    --output_dir ${asr_exp} \
+                    --config $asr_config \
+                    --ngpu $ngpu \
+                    --num_worker_count 1 \
+                    --multiprocessing_distributed true \
+                    --dist_init_method $init_method \
+                    --dist_world_size $ngpu \
+                    --dist_rank $rank \
+                    --local_rank $local_rank \
+                    ${_opts} 1> ${asr_exp}/log/train.log.$i 2>&1
+            } &
+            done
+            wait
+
+    fi
+
+    if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
+        _asr_train_dir="${data_feats}/${train_set}"
+        _asr_valid_dir="${data_feats}/${valid_set}"
+        log "Stage 12: SA-ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+        _opts=
+        if [ -n "${sa_asr_config}" ]; then
+            # To generate the config file: e.g.
+            #   % python3 -m espnet2.bin.asr_train --print_config --optim adam
+            _opts+="--config ${sa_asr_config} "
+        fi
+
+        _feats_type="$(<${_asr_train_dir}/feats_type)"
+        if [ "${_feats_type}" = raw ]; then
+            _scp=wav.scp
+            # "sound" supports "wav", "flac", etc.
+            if [[ "${audio_format}" == *ark* ]]; then
+                _type=kaldi_ark
+            else
+                _type=sound
+            fi
+            _opts+="--frontend_conf fs=${fs} "
+        else
+            _scp=feats.scp
+            _type=kaldi_ark
+            _input_size="$(<${_asr_train_dir}/feats_dim)"
+            _opts+="--input_size=${_input_size} "
+
+        fi
+        if [ "${feats_normalize}" = global_mvn ]; then
+            # Default normalization is utterance_mvn and changes to global_mvn
+            _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
+        fi
+
+        if [ "${num_splits_asr}" -gt 1 ]; then
+            # If you met a memory error when parsing text files, this option may help you.
+            # The corpus is split into subsets and each subset is used for training one by one in order,
+            # so the memory footprint can be limited to the memory required for each dataset.
+
+            _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
+            if [ ! -f "${_split_dir}/.done" ]; then
+                rm -f "${_split_dir}/.done"
+                ${python} -m espnet2.bin.split_scps \
+                  --scps \
+                      "${_asr_train_dir}/${_scp}" \
+                      "${_asr_train_dir}/text" \
+                      "${asr_stats_dir}/train/speech_shape" \
+                      "${asr_stats_dir}/train/text_shape.${token_type}" \
+                  --num_splits "${num_splits_asr}" \
+                  --output_dir "${_split_dir}"
+                touch "${_split_dir}/.done"
+            else
+                log "${_split_dir}/.done exists. Spliting is skipped"
+            fi
+
+            _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
+            _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
+            _opts+="--train_data_path_and_name_and_type ${_split_dir}/text_id_train,text_id,text_int "
+            _opts+="--train_data_path_and_name_and_type ${_split_dir}/oracle_profile_padding.scp,profile,npy "
+            _opts+="--train_shape_file ${_split_dir}/speech_shape "
+            _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
+            _opts+="--multiple_iterator true "
+
+        else
+            _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
+            _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
+            _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/oracle_profile_padding.scp,profile,npy "
+            _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text_id_train,text_id,text_int "
+            _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
+            _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
+        fi
+
+        # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
+        # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
+
+        # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+        log "SA-ASR training started... log: '${sa_asr_exp}/log/train.log'"
+        # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+        #     # SGE can't include "/" in a job name
+        #     jobname="$(basename ${asr_exp})"
+        # else
+        #     jobname="${asr_exp}/train.log"
+        # fi
+
+        mkdir -p ${sa_asr_exp}
+        mkdir -p ${sa_asr_exp}/log
+        INIT_FILE=${sa_asr_exp}/ddp_init
+        
+        if [ ! -f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+            # download xvector extractor model file
+            python local/download_xvector_model.py exp
+            log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+        fi
+        
+        if [ -f $INIT_FILE ];then
+            rm -f $INIT_FILE
+        fi 
+        init_method=file://$(readlink -f $INIT_FILE)
+        echo "$0: init method is $init_method"
+        for ((i = 0; i < $ngpu; ++i)); do
+            {
+                # i=0
+                rank=$i
+                local_rank=$i
+                gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+                sa_asr_train.py \
+                    --gpu_id $gpu_id \
+                    --use_preprocessor true \
+                    --unused_parameters true \
+                    --bpemodel ${bpemodel} \
+                    --token_type ${token_type} \
+                    --token_list ${token_list} \
+                    --max_spk_num 4 \
+                    --split_with_space false    \
+                    --non_linguistic_symbols ${nlsyms_txt} \
+                    --cleaner ${cleaner} \
+                    --g2p ${g2p} \
+                    --allow_variable_data_keys true \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder"   \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc"   \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
+                    --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
+                    --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder"   \
+                    --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense"   \
+                    --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
+                    --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
+                    --valid_data_path_and_name_and_type "${_asr_valid_dir}/oracle_profile_nopadding.scp,profile,npy" \
+                    --valid_data_path_and_name_and_type "${_asr_valid_dir}/text_id_train,text_id,text_int" \
+                    --valid_shape_file "${asr_stats_dir}/valid/speech_shape" \
+                    --valid_shape_file "${asr_stats_dir}/valid/text_shape.${token_type}" \
+                    --resume true \
+                    --output_dir ${sa_asr_exp} \
+                    --config $sa_asr_config \
+                    --ngpu $ngpu \
+                    --num_worker_count 1 \
+                    --multiprocessing_distributed true \
+                    --dist_init_method $init_method \
+                    --dist_world_size $ngpu \
+                    --dist_rank $rank \
+                    --local_rank $local_rank \
+                    ${_opts} 1> ${sa_asr_exp}/log/train.log.$i 2>&1
+            } &
+            done
+            wait
+
+    fi
+
+else
+    log "Skip the training stages"
+fi
+
+
+if ! "${skip_eval}"; then
+    if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
+        log "Stage 13: Decoding multi-talker ASR: training_dir=${asr_exp}"
+
+        if ${gpu_inference}; then
+            _cmd="${cuda_cmd}"
+            inference_nj=$[${ngpu}*${njob_infer}]
+            _ngpu=1
+
+        else
+            _cmd="${decode_cmd}"
+            inference_nj=$inference_nj
+            _ngpu=0
+        fi
+
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        if "${use_lm}"; then
+            if "${use_word_lm}"; then
+                _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+            else
+                _opts+="--lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--lm_file ${lm_exp}/${inference_lm} "
+            fi
+        fi
+
+        # 2. Generate run.sh
+        log "Generate '${asr_exp}/${inference_tag}/run.sh'. You can resume the process from stage 13 using this script"
+        mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 13 \"\$@\"; exit \$?" > "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh"
+
+        for dset in ${test_sets}; do
+            _data="${data_feats}/${dset}"
+            _dir="${asr_exp}/${inference_tag}/${dset}"
+            _logdir="${_dir}/logdir"
+            mkdir -p "${_logdir}"
+
+            _feats_type="$(<${_data}/feats_type)"
+            if [ "${_feats_type}" = raw ]; then
+                _scp=wav.scp
+                if [[ "${audio_format}" == *ark* ]]; then
+                    _type=kaldi_ark
+                else
+                    _type=sound
+                fi
+            else
+                _scp=feats.scp
+                _type=kaldi_ark
+            fi
+
+            # 1. Split the key file
+            key_file=${_data}/${_scp}
+            split_scps=""
+            _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+            echo $_nj
+            for n in $(seq "${_nj}"); do
+                split_scps+=" ${_logdir}/keys.${n}.scp"
+            done
+            # shellcheck disable=SC2086
+            utils/split_scp.pl "${key_file}" ${split_scps}
+
+            # 2. Submit decoding jobs
+            log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
+            
+            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+                python -m funasr.bin.asr_inference_launch \
+                    --batch_size 1 \
+                    --nbest 1   \
+                    --ngpu "${_ngpu}" \
+                    --njob ${njob_infer} \
+                    --gpuid_list ${device} \
+                    --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+                    --key_file "${_logdir}"/keys.JOB.scp \
+                    --asr_train_config "${asr_exp}"/config.yaml \
+                    --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+                    --output_dir "${_logdir}"/output.JOB \
+                    --mode asr \
+                    ${_opts}
+
+            # 3. Concatenates the output files from each jobs
+            for f in token token_int score text; do
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+            done
+        done
+    fi
+
+
+    if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
+        log "Stage 14: Scoring multi-talker ASR"
+
+        for dset in ${test_sets}; do
+            _data="${data_feats}/${dset}"
+            _dir="${asr_exp}/${inference_tag}/${dset}"
+
+            python local/proce_text.py ${_data}/text ${_data}/text.proc
+            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+            cat ${_dir}/text.cer.txt
+            
+        done
+
+    fi
+
+    if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
+        log "Stage 15: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}"
+
+        if ${gpu_inference}; then
+            _cmd="${cuda_cmd}"
+            inference_nj=$[${ngpu}*${njob_infer}]
+            _ngpu=1
+
+        else
+            _cmd="${decode_cmd}"
+            inference_nj=$inference_nj
+            _ngpu=0
+        fi
+
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        if "${use_lm}"; then
+            if "${use_word_lm}"; then
+                _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+            else
+                _opts+="--lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--lm_file ${lm_exp}/${inference_lm} "
+            fi
+        fi
+
+        # 2. Generate run.sh
+        log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh'. You can resume the process from stage 15 using this script"
+        mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.oracle"; echo "${run_args} --stage 15 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"
+
+        for dset in ${test_sets}; do
+            _data="${data_feats}/${dset}"
+            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
+            _logdir="${_dir}/logdir"
+            mkdir -p "${_logdir}"
+
+            _feats_type="$(<${_data}/feats_type)"
+            if [ "${_feats_type}" = raw ]; then
+                _scp=wav.scp
+                if [[ "${audio_format}" == *ark* ]]; then
+                    _type=kaldi_ark
+                else
+                    _type=sound
+                fi
+            else
+                _scp=feats.scp
+                _type=kaldi_ark
+            fi
+
+            # 1. Split the key file
+            key_file=${_data}/${_scp}
+            split_scps=""
+            _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+            for n in $(seq "${_nj}"); do
+                split_scps+=" ${_logdir}/keys.${n}.scp"
+            done
+            # shellcheck disable=SC2086
+            utils/split_scp.pl "${key_file}" ${split_scps}
+
+            # 2. Submit decoding jobs
+            log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
+            # shellcheck disable=SC2086
+            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+                python -m funasr.bin.asr_inference_launch \
+                    --batch_size 1 \
+                    --nbest 1   \
+                    --ngpu "${_ngpu}" \
+                    --njob ${njob_infer} \
+                    --gpuid_list ${device} \
+                    --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+                    --data_path_and_name_and_type "${_data}/oracle_profile_nopadding.scp,profile,npy" \
+                    --key_file "${_logdir}"/keys.JOB.scp \
+                    --allow_variable_data_keys true \
+                    --asr_train_config "${sa_asr_exp}"/config.yaml \
+                    --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+                    --output_dir "${_logdir}"/output.JOB \
+                    --mode sa_asr \
+                    ${_opts}
+
+
+            # 3. Concatenates the output files from each jobs
+            for f in token token_int score text text_id; do
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+            done
+        done
+    fi
+
+    if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
+        log "Stage 16: Scoring SA-ASR (oracle profile)"
+
+        for dset in ${test_sets}; do
+            _data="${data_feats}/${dset}"
+            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
+
+            python local/proce_text.py ${_data}/text ${_data}/text.proc
+            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+            cat ${_dir}/text.cer.txt
+
+            python local/process_text_spk_merge.py ${_dir}
+            python local/process_text_spk_merge.py ${_data}
+            
+            python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+            tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+            cat ${_dir}/text.cpcer.txt
+            
+        done
+
+    fi
+
+    if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then
+        log "Stage 17: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
+
+        if ${gpu_inference}; then
+            _cmd="${cuda_cmd}"
+            inference_nj=$[${ngpu}*${njob_infer}]
+            _ngpu=1
+
+        else
+            _cmd="${decode_cmd}"
+            inference_nj=$inference_nj
+            _ngpu=0
+        fi
+
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        if "${use_lm}"; then
+            if "${use_word_lm}"; then
+                _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+            else
+                _opts+="--lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--lm_file ${lm_exp}/${inference_lm} "
+            fi
+        fi
+
+        # 2. Generate run.sh
+        log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
+        mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
+
+        for dset in ${test_sets}; do
+            _data="${data_feats}/${dset}"
+            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+            _logdir="${_dir}/logdir"
+            mkdir -p "${_logdir}"
+
+            _feats_type="$(<${_data}/feats_type)"
+            if [ "${_feats_type}" = raw ]; then
+                _scp=wav.scp
+                if [[ "${audio_format}" == *ark* ]]; then
+                    _type=kaldi_ark
+                else
+                    _type=sound
+                fi
+            else
+                _scp=feats.scp
+                _type=kaldi_ark
+            fi
+
+            # 1. Split the key file
+            key_file=${_data}/${_scp}
+            split_scps=""
+            _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+            for n in $(seq "${_nj}"); do
+                split_scps+=" ${_logdir}/keys.${n}.scp"
+            done
+            # shellcheck disable=SC2086
+            utils/split_scp.pl "${key_file}" ${split_scps}
+
+            # 2. Submit decoding jobs
+            log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
+            # shellcheck disable=SC2086
+            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+                python -m funasr.bin.asr_inference_launch \
+                    --batch_size 1 \
+                    --nbest 1   \
+                    --ngpu "${_ngpu}" \
+                    --njob ${njob_infer} \
+                    --gpuid_list ${device} \
+                    --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+                    --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+                    --key_file "${_logdir}"/keys.JOB.scp \
+                    --allow_variable_data_keys true \
+                    --asr_train_config "${sa_asr_exp}"/config.yaml \
+                    --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+                    --output_dir "${_logdir}"/output.JOB \
+                    --mode sa_asr \
+                    ${_opts}
+
+            # 3. Concatenates the output files from each jobs
+            for f in token token_int score text text_id; do
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+            done
+        done
+    fi
+
+    if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then
+        log "Stage 18: Scoring SA-ASR (cluster profile)"
+
+        for dset in ${test_sets}; do
+            _data="${data_feats}/${dset}"
+            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+
+            python local/proce_text.py ${_data}/text ${_data}/text.proc
+            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+            cat ${_dir}/text.cer.txt
+
+            python local/process_text_spk_merge.py ${_dir}
+            python local/process_text_spk_merge.py ${_data}
+            
+            python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+            tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+            cat ${_dir}/text.cpcer.txt
+            
+        done
+
+    fi
+
+else
+    log "Skip the evaluation stages"
+fi
+
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_infer.sh
new file mode 100755
index 0000000..8e8148f
--- /dev/null
+++ b/egs/alimeeting/sa-asr/asr_local_infer.sh
@@ -0,0 +1,590 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+    local fname=${BASH_SOURCE[1]##*/}
+    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+min() {
+  local a b
+  a=$1
+  for b in "$@"; do
+      if [ "${b}" -le "${a}" ]; then
+          a="${b}"
+      fi
+  done
+  echo "${a}"
+}
+SECONDS=0
+
+# General configuration
+stage=1              # Processes starts from the specified stage.
+stop_stage=10000     # Processes is stopped at the specified stage.
+skip_data_prep=false # Skip data preparation stages.
+skip_train=false     # Skip training stages.
+skip_eval=false      # Skip decoding and evaluation stages.
+skip_upload=true     # Skip packing and uploading stages.
+ngpu=1               # The number of gpus ("0" uses cpu, otherwise use gpu).
+num_nodes=1          # The number of nodes.
+nj=16                # The number of parallel jobs.
+inference_nj=16      # The number of parallel jobs in decoding.
+gpu_inference=false  # Whether to perform gpu decoding.
+njob_infer=4
+dumpdir=dump2         # Directory to dump features.
+expdir=exp           # Directory to save experiments.
+python=python3       # Specify python to execute espnet commands.
+device=0
+
+# Data preparation related
+local_data_opts= # The options given to local/data.sh.
+
+# Speed perturbation related
+speed_perturb_factors=  # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
+
+# Feature extraction related
+feats_type=raw       # Feature type (raw or fbank_pitch).
+audio_format=flac    # Audio format: wav, flac, wav.ark, flac.ark  (only in feats_type=raw).
+fs=16000             # Sampling rate.
+min_wav_duration=0.1 # Minimum duration in second.
+max_wav_duration=20  # Maximum duration in second.
+
+# Tokenization related
+token_type=bpe      # Tokenization type (char or bpe).
+nbpe=30             # The number of BPE vocabulary.
+bpemode=unigram     # Mode of BPE (unigram or bpe).
+oov="<unk>"         # Out of vocabulary symbol.
+blank="<blank>"     # CTC blank symbol
+sos_eos="<sos/eos>" # sos and eos symbole
+bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
+bpe_nlsyms=         # non-linguistic symbols list, separated by a comma, for BPE
+bpe_char_cover=1.0  # character coverage when modeling BPE
+
+# Language model related
+use_lm=true       # Use language model for ASR decoding.
+lm_tag=           # Suffix to the result dir for language model training.
+lm_exp=           # Specify the direcotry path for LM experiment.
+                  # If this option is specified, lm_tag is ignored.
+lm_stats_dir=     # Specify the direcotry path for LM statistics.
+lm_config=        # Config for language model training.
+lm_args=          # Arguments for language model training, e.g., "--max_epoch 10".
+                  # Note that it will overwrite args in lm config.
+use_word_lm=false # Whether to use word language model.
+num_splits_lm=1   # Number of splitting for lm corpus.
+# shellcheck disable=SC2034
+word_vocab_size=10000 # Size of word vocabulary.
+
+# ASR model related
+asr_tag=       # Suffix to the result dir for asr model training.
+asr_exp=       # Specify the direcotry path for ASR experiment.
+               # If this option is specified, asr_tag is ignored.
+sa_asr_exp=
+asr_stats_dir= # Specify the direcotry path for ASR statistics.
+asr_config=    # Config for asr model training.
+sa_asr_config=
+asr_args=      # Arguments for asr model training, e.g., "--max_epoch 10".
+               # Note that it will overwrite args in asr config.
+feats_normalize=global_mvn # Normalizaton layer type.
+num_splits_asr=1           # Number of splitting for lm corpus.
+
+# Decoding related
+inference_tag=    # Suffix to the result dir for decoding.
+inference_config= # Config for decoding.
+inference_args=   # Arguments for decoding, e.g., "--lm_weight 0.1".
+                  # Note that it will overwrite args in inference config.
+sa_asr_inference_tag=
+sa_asr_inference_args=
+
+inference_lm=valid.loss.ave.pb        # Language modle path for decoding.
+inference_asr_model=valid.acc.ave.pb  # ASR model path for decoding.
+                                      # e.g.
+                                      # inference_asr_model=train.loss.best.pth
+                                      # inference_asr_model=3epoch.pth
+                                      # inference_asr_model=valid.acc.best.pth
+                                      # inference_asr_model=valid.loss.ave.pth
+inference_sa_asr_model=valid.acc_spk.ave.pb
+download_model= # Download a model from Model Zoo and use it for decoding.
+
+# [Task dependent] Set the datadir name created by local/data.sh
+train_set=       # Name of training set.
+valid_set=       # Name of validation set used for monitoring/tuning network training.
+test_sets=       # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
+bpe_train_text=  # Text file path of bpe training set.
+lm_train_text=   # Text file path of language model training set.
+lm_dev_text=     # Text file path of language model development set.
+lm_test_text=    # Text file path of language model evaluation set.
+nlsyms_txt=none  # Non-linguistic symbol list if existing.
+cleaner=none     # Text cleaner.
+g2p=none         # g2p method (needed if token_type=phn).
+lang=zh      # The language type of corpus.
+score_opts=                # The options given to sclite scoring
+local_score_opts=          # The options given to local/score.sh.
+
+help_message=$(cat << EOF
+Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
+
+Options:
+    # General configuration
+    --stage          # Processes starts from the specified stage (default="${stage}").
+    --stop_stage     # Processes is stopped at the specified stage (default="${stop_stage}").
+    --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
+    --skip_train     # Skip training stages (default="${skip_train}").
+    --skip_eval      # Skip decoding and evaluation stages (default="${skip_eval}").
+    --skip_upload    # Skip packing and uploading stages (default="${skip_upload}").
+    --ngpu           # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
+    --num_nodes      # The number of nodes (default="${num_nodes}").
+    --nj             # The number of parallel jobs (default="${nj}").
+    --inference_nj   # The number of parallel jobs in decoding (default="${inference_nj}").
+    --gpu_inference  # Whether to perform gpu decoding (default="${gpu_inference}").
+    --dumpdir        # Directory to dump features (default="${dumpdir}").
+    --expdir         # Directory to save experiments (default="${expdir}").
+    --python         # Specify python to execute espnet commands (default="${python}").
+    --device         # Which GPUs are use for local training (defalut="${device}").
+
+    # Data preparation related
+    --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
+
+    # Speed perturbation related
+    --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
+
+    # Feature extraction related
+    --feats_type       # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
+    --audio_format     # Audio format: wav, flac, wav.ark, flac.ark  (only in feats_type=raw, default="${audio_format}").
+    --fs               # Sampling rate (default="${fs}").
+    --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
+    --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
+
+    # Tokenization related
+    --token_type              # Tokenization type (char or bpe, default="${token_type}").
+    --nbpe                    # The number of BPE vocabulary (default="${nbpe}").
+    --bpemode                 # Mode of BPE (unigram or bpe, default="${bpemode}").
+    --oov                     # Out of vocabulary symbol (default="${oov}").
+    --blank                   # CTC blank symbol (default="${blank}").
+    --sos_eos                 # sos and eos symbole (default="${sos_eos}").
+    --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
+    --bpe_nlsyms              # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
+    --bpe_char_cover          # Character coverage when modeling BPE (default="${bpe_char_cover}").
+
+    # Language model related
+    --lm_tag          # Suffix to the result dir for language model training (default="${lm_tag}").
+    --lm_exp          # Specify the direcotry path for LM experiment.
+                      # If this option is specified, lm_tag is ignored (default="${lm_exp}").
+    --lm_stats_dir    # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
+    --lm_config       # Config for language model training (default="${lm_config}").
+    --lm_args         # Arguments for language model training (default="${lm_args}").
+                      # e.g., --lm_args "--max_epoch 10"
+                      # Note that it will overwrite args in lm config.
+    --use_word_lm     # Whether to use word language model (default="${use_word_lm}").
+    --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
+    --num_splits_lm   # Number of splitting for lm corpus (default="${num_splits_lm}").
+
+    # ASR model related
+    --asr_tag          # Suffix to the result dir for asr model training (default="${asr_tag}").
+    --asr_exp          # Specify the direcotry path for ASR experiment.
+                       # If this option is specified, asr_tag is ignored (default="${asr_exp}").
+    --asr_stats_dir    # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
+    --asr_config       # Config for asr model training (default="${asr_config}").
+    --asr_args         # Arguments for asr model training (default="${asr_args}").
+                       # e.g., --asr_args "--max_epoch 10"
+                       # Note that it will overwrite args in asr config.
+    --feats_normalize  # Normalizaton layer type (default="${feats_normalize}").
+    --num_splits_asr   # Number of splitting for lm corpus  (default="${num_splits_asr}").
+
+    # Decoding related
+    --inference_tag       # Suffix to the result dir for decoding (default="${inference_tag}").
+    --inference_config    # Config for decoding (default="${inference_config}").
+    --inference_args      # Arguments for decoding (default="${inference_args}").
+                          # e.g., --inference_args "--lm_weight 0.1"
+                          # Note that it will overwrite args in inference config.
+    --inference_lm        # Language modle path for decoding (default="${inference_lm}").
+    --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
+    --download_model      # Download a model from Model Zoo and use it for decoding (default="${download_model}").
+
+    # [Task dependent] Set the datadir name created by local/data.sh
+    --train_set     # Name of training set (required).
+    --valid_set     # Name of validation set used for monitoring/tuning network training (required).
+    --test_sets     # Names of test sets.
+                    # Multiple items (e.g., both dev and eval sets) can be specified (required).
+    --bpe_train_text # Text file path of bpe training set.
+    --lm_train_text  # Text file path of language model training set.
+    --lm_dev_text   # Text file path of language model development set (default="${lm_dev_text}").
+    --lm_test_text  # Text file path of language model evaluation set (default="${lm_test_text}").
+    --nlsyms_txt    # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
+    --cleaner       # Text cleaner (default="${cleaner}").
+    --g2p           # g2p method (default="${g2p}").
+    --lang          # The language type of corpus (default=${lang}).
+    --score_opts             # The options given to sclite scoring (default="{score_opts}").
+    --local_score_opts       # The options given to local/score.sh (default="{local_score_opts}").
+EOF
+)
+
+log "$0 $*"
+# Save command line args for logging (they will be lost after utils/parse_options.sh)
+run_args=$(python -m funasr.utils.cli_utils $0 "$@")
+. utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+    log "${help_message}"
+    log "Error: No positional arguments are required."
+    exit 2
+fi
+
+. ./path.sh
+
+
+# Check required arguments
+[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
+[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
+[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
+
+# Check feature type
+if [ "${feats_type}" = raw ]; then
+    data_feats=${dumpdir}/raw
+elif [ "${feats_type}" = fbank_pitch ]; then
+    data_feats=${dumpdir}/fbank_pitch
+elif [ "${feats_type}" = fbank ]; then
+    data_feats=${dumpdir}/fbank
+elif [ "${feats_type}" == extracted ]; then
+    data_feats=${dumpdir}/extracted
+else
+    log "${help_message}"
+    log "Error: not supported: --feats_type ${feats_type}"
+    exit 2
+fi
+
+# Use the same text as ASR for bpe training if not specified.
+[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
+# Use the text of the 1st evaldir if lm_test is not specified
+[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
+
+# Check tokenization type
+if [ "${lang}" != noinfo ]; then
+    token_listdir=data/${lang}_token_list
+else
+    token_listdir=data/token_list
+fi
+bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
+bpeprefix="${bpedir}"/bpe
+bpemodel="${bpeprefix}".model
+bpetoken_list="${bpedir}"/tokens.txt
+chartoken_list="${token_listdir}"/char/tokens.txt
+# NOTE: keep for future development.
+# shellcheck disable=SC2034
+wordtoken_list="${token_listdir}"/word/tokens.txt
+
+if [ "${token_type}" = bpe ]; then
+    token_list="${bpetoken_list}"
+elif [ "${token_type}" = char ]; then
+    token_list="${chartoken_list}"
+    bpemodel=none
+elif [ "${token_type}" = word ]; then
+    token_list="${wordtoken_list}"
+    bpemodel=none
+else
+    log "Error: not supported --token_type '${token_type}'"
+    exit 2
+fi
+if ${use_word_lm}; then
+    log "Error: Word LM is not supported yet"
+    exit 2
+
+    lm_token_list="${wordtoken_list}"
+    lm_token_type=word
+else
+    lm_token_list="${token_list}"
+    lm_token_type="${token_type}"
+fi
+
+
+# Set tag for naming of model directory
+if [ -z "${asr_tag}" ]; then
+    if [ -n "${asr_config}" ]; then
+        asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
+    else
+        asr_tag="train_${feats_type}"
+    fi
+    if [ "${lang}" != noinfo ]; then
+        asr_tag+="_${lang}_${token_type}"
+    else
+        asr_tag+="_${token_type}"
+    fi
+    if [ "${token_type}" = bpe ]; then
+        asr_tag+="${nbpe}"
+    fi
+    # Add overwritten arg's info
+    if [ -n "${asr_args}" ]; then
+        asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+    fi
+    if [ -n "${speed_perturb_factors}" ]; then
+        asr_tag+="_sp"
+    fi
+fi
+if [ -z "${lm_tag}" ]; then
+    if [ -n "${lm_config}" ]; then
+        lm_tag="$(basename "${lm_config}" .yaml)"
+    else
+        lm_tag="train"
+    fi
+    if [ "${lang}" != noinfo ]; then
+        lm_tag+="_${lang}_${lm_token_type}"
+    else
+        lm_tag+="_${lm_token_type}"
+    fi
+    if [ "${lm_token_type}" = bpe ]; then
+        lm_tag+="${nbpe}"
+    fi
+    # Add overwritten arg's info
+    if [ -n "${lm_args}" ]; then
+        lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+    fi
+fi
+
+# The directory used for collect-stats mode
+if [ -z "${asr_stats_dir}" ]; then
+    if [ "${lang}" != noinfo ]; then
+        asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
+    else
+        asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
+    fi
+    if [ "${token_type}" = bpe ]; then
+        asr_stats_dir+="${nbpe}"
+    fi
+    if [ -n "${speed_perturb_factors}" ]; then
+        asr_stats_dir+="_sp"
+    fi
+fi
+if [ -z "${lm_stats_dir}" ]; then
+    if [ "${lang}" != noinfo ]; then
+        lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
+    else
+        lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
+    fi
+    if [ "${lm_token_type}" = bpe ]; then
+        lm_stats_dir+="${nbpe}"
+    fi
+fi
+# The directory used for training commands
+if [ -z "${asr_exp}" ]; then
+    asr_exp="${expdir}/asr_${asr_tag}"
+fi
+if [ -z "${lm_exp}" ]; then
+    lm_exp="${expdir}/lm_${lm_tag}"
+fi
+
+
+if [ -z "${inference_tag}" ]; then
+    if [ -n "${inference_config}" ]; then
+        inference_tag="$(basename "${inference_config}" .yaml)"
+    else
+        inference_tag=inference
+    fi
+    # Add overwritten arg's info
+    if [ -n "${inference_args}" ]; then
+        inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+    fi
+    if "${use_lm}"; then
+        inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+    fi
+    inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+if [ -z "${sa_asr_inference_tag}" ]; then
+    if [ -n "${inference_config}" ]; then
+        sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
+    else
+        sa_asr_inference_tag=sa_asr_inference
+    fi
+    # Add overwritten arg's info
+    if [ -n "${sa_asr_inference_args}" ]; then
+        sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+    fi
+    if "${use_lm}"; then
+        sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+    fi
+    sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+train_cmd="run.pl"
+cuda_cmd="run.pl"
+decode_cmd="run.pl"
+
+# ========================== Main stages start from here. ==========================
+
+if ! "${skip_data_prep}"; then
+
+    if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+        if [ "${feats_type}" = raw ]; then
+            log "Stage 1: Format wav.scp: data/ -> ${data_feats}"
+
+            # ====== Recreating "wav.scp" ======
+            # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
+            # shouldn't be used in training process.
+            # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
+            # and it can also change the audio-format and sampling rate.
+            # If nothing is need, then format_wav_scp.sh does nothing:
+            # i.e. the input file format and rate is same as the output.
+
+            for dset in "${test_sets}" ; do
+            
+                _suf=""
+
+                utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+                
+                rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
+                _opts=
+                if [ -e data/"${dset}"/segments ]; then
+                    # "segments" is used for splitting wav files which are written in "wav".scp
+                    # into utterances. The file format of segments:
+                    #   <segment_id> <record_id> <start_time> <end_time>
+                    #   "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
+                    # Where the time is written in seconds.
+                    _opts+="--segments data/${dset}/segments "
+                fi
+                # shellcheck disable=SC2086
+                scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+                    --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
+                    "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
+
+                echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
+            done
+
+        else
+            log "Error: not supported: --feats_type ${feats_type}"
+            exit 2
+        fi
+    fi
+
+    if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+        log "Stage 2: Generate speaker profile by spectral-cluster"
+        mkdir -p "profile_log"
+        for dset in "${test_sets}"; do
+            # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+            python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+            log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
+            done
+    fi
+
+else
+    log "Skip the stages for data preparation"
+fi
+
+
+# ========================== Data preparation is done here. ==========================
+
+if ! "${skip_eval}"; then
+
+    if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+        log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
+
+        if ${gpu_inference}; then
+            _cmd="${cuda_cmd}"
+            inference_nj=$[${ngpu}*${njob_infer}]
+            _ngpu=1
+
+        else
+            _cmd="${decode_cmd}"
+            inference_nj=$njob_infer
+            _ngpu=0
+        fi
+
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        if "${use_lm}"; then
+            if "${use_word_lm}"; then
+                _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+            else
+                _opts+="--lm_train_config ${lm_exp}/config.yaml "
+                _opts+="--lm_file ${lm_exp}/${inference_lm} "
+            fi
+        fi
+
+        # 2. Generate run.sh
+        log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
+        mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
+
+        for dset in ${test_sets}; do
+            _data="${data_feats}/${dset}"
+            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+            _logdir="${_dir}/logdir"
+            mkdir -p "${_logdir}"
+
+            _feats_type="$(<${_data}/feats_type)"
+            if [ "${_feats_type}" = raw ]; then
+                _scp=wav.scp
+                if [[ "${audio_format}" == *ark* ]]; then
+                    _type=kaldi_ark
+                else
+                    _type=sound
+                fi
+            else
+                _scp=feats.scp
+                _type=kaldi_ark
+            fi
+
+            # 1. Split the key file
+            key_file=${_data}/${_scp}
+            split_scps=""
+            _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+            for n in $(seq "${_nj}"); do
+                split_scps+=" ${_logdir}/keys.${n}.scp"
+            done
+            # shellcheck disable=SC2086
+            utils/split_scp.pl "${key_file}" ${split_scps}
+
+            # 2. Submit decoding jobs
+            log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
+            # shellcheck disable=SC2086
+            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+                python -m funasr.bin.asr_inference_launch \
+                    --batch_size 1 \
+                    --nbest 1   \
+                    --ngpu "${_ngpu}" \
+                    --njob ${njob_infer} \
+                    --gpuid_list ${device} \
+                    --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+                    --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+                    --key_file "${_logdir}"/keys.JOB.scp \
+                    --allow_variable_data_keys true \
+                    --asr_train_config "${sa_asr_exp}"/config.yaml \
+                    --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+                    --output_dir "${_logdir}"/output.JOB \
+                    --mode sa_asr \
+                    ${_opts}
+
+            # 3. Concatenates the output files from each jobs
+            for f in token token_int score text text_id; do
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+            done
+        done
+    fi
+
+    if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+        log "Stage 4: Generate SA-ASR results (cluster profile)"
+
+        for dset in ${test_sets}; do
+            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+
+            python local/process_text_spk_merge.py ${_dir}
+        done
+
+    fi
+
+else
+    log "Skip the evaluation stages"
+fi
+
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
new file mode 100644
index 0000000..88fdbc2
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
@@ -0,0 +1,6 @@
+beam_size: 20
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.6
+lm_weight: 0.3
diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
new file mode 100644
index 0000000..a8c9968
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
@@ -0,0 +1,88 @@
+# network architecture
+frontend: default
+frontend_conf:
+    n_fft: 400
+    win_length: 400
+    hop_length: 160
+    use_channel: 0
+    
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    rel_pos_type: latest
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+# decoder related
+decoder: transformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+# ctc related
+ctc_conf:
+    ignore_nan_grad: true
+
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+# minibatch related
+batch_type: numel
+batch_bins: 10000000  # reduce/increase this number according to your GPU memory
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+    - valid
+    - acc
+best_model_criterion:
+-   - valid
+    - acc
+    - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
new file mode 100644
index 0000000..68520ae
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
@@ -0,0 +1,29 @@
+lm: transformer
+lm_conf:
+    pos_enc: null
+    embed_unit: 128
+    att_unit: 512
+    head: 8
+    unit: 2048
+    layer: 16
+    dropout_rate: 0.1
+
+# optimization related
+grad_clip: 5.0
+batch_type: numel
+batch_bins: 500000 # 4gpus * 500000
+accum_grad: 1
+max_epoch: 15  # 15epoch is enougth
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 25000
+
+best_model_criterion:
+-   - valid
+    - loss
+    - min
+keep_nbest_models: 10  # 10 is good.
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
new file mode 100644
index 0000000..e91db18
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
@@ -0,0 +1,116 @@
+# network architecture
+frontend: default
+frontend_conf:
+    n_fft: 400
+    win_length: 400
+    hop_length: 160
+    use_channel: 0
+
+# encoder related
+asr_encoder: conformer
+asr_encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+spk_encoder: resnet34_diar
+spk_encoder_conf:
+  use_head_conv: true
+  batchnorm_momentum: 0.5
+  use_head_maxpool: false
+  num_nodes_pooling_layer: 256
+  layers_in_block:
+    - 3
+    - 4
+    - 6
+    - 3
+  filters_in_block:
+    - 32
+    - 64
+    - 128
+    - 256
+  pooling_type: statistic
+  num_nodes_resnet1: 256
+  num_nodes_last_layer: 256
+  batchnorm_momentum: 0.5
+
+# decoder related
+decoder: sa_decoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    asr_num_blocks: 6
+    spk_num_blocks: 3
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+    spk_weight: 0.5
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+ctc_conf:
+    ignore_nan_grad: true
+
+# minibatch related
+batch_type: numel
+batch_bins: 10000000
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 60
+val_scheduler_criterion:
+    - valid
+    - loss
+best_model_criterion:
+-   - valid
+    - acc
+    - max
+-   - valid
+    - acc_spk
+    - max
+-   - valid
+    - loss
+    - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 8000
+
+specaug: specaug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
new file mode 100755
index 0000000..8151bae
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
@@ -0,0 +1,162 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+    local fname=${BASH_SOURCE[1]##*/}
+    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+help_messge=$(cat << EOF
+Usage: $0
+
+Options:
+    --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
+    --tgt (string): Which set to process, test or train.
+EOF
+)
+
+SECONDS=0
+tgt=Train #Train or Eval
+
+
+log "$0 $*"
+echo $tgt
+. ./utils/parse_options.sh
+
+. ./path.sh
+
+AliMeeting="${PWD}/dataset"
+
+if [ $# -gt 2 ]; then
+    log "${help_message}"
+    exit 2
+fi
+
+
+if [ ! -d "${AliMeeting}" ]; then
+  log "Error: ${AliMeeting} is empty."
+  exit 2
+fi
+
+# To absolute path
+AliMeeting=$(cd ${AliMeeting}; pwd)
+echo $AliMeeting
+far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
+near_raw_dir=${AliMeeting}/${tgt}_Ali_near/
+
+far_dir=data/local/${tgt}_Ali_far
+near_dir=data/local/${tgt}_Ali_near
+far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
+mkdir -p $far_single_speaker_dir
+
+stage=1
+stop_stage=4
+mkdir -p $far_dir
+mkdir -p $near_dir
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 
+    log "stage 1:process alimeeting near dir"
+    
+    find -L $near_raw_dir/audio_dir -iname "*.wav" >  $near_dir/wavlist
+    awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid   
+    find -L $near_raw_dir/textgrid_dir  -iname "*.TextGrid" > $near_dir/textgrid.flist
+    n1_wav=$(wc -l < $near_dir/wavlist)
+    n2_text=$(wc -l < $near_dir/textgrid.flist)
+    log  near file found $n1_wav wav and $n2_text text.
+
+    paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
+
+    # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -c 1 -t wav  - |\n", $1, $2)}'  > $near_dir/wav.scp
+    cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -t wav  - |\n", $1, $2)}'  > $near_dir/wav.scp
+    
+    python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
+    cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
+    utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
+    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $near_dir/utt2spk_old >$near_dir/tmp1
+    #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
+    utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
+    utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
+    sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
+    sed -e 's/锛�//g' $near_dir/tmp1> $near_dir/tmp2
+    sed -e 's/锛�//g' $near_dir/tmp2> $near_dir/text
+
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    log "stage 2:process alimeeting far dir"
+    
+    find -L $far_raw_dir/audio_dir -iname "*.wav" >  $far_dir/wavlist
+    awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid   
+    find -L $far_raw_dir/textgrid_dir  -iname "*.TextGrid" > $far_dir/textgrid.flist
+    n1_wav=$(wc -l < $far_dir/wavlist)
+    n2_text=$(wc -l < $far_dir/textgrid.flist)
+    log  far file found $n1_wav wav and $n2_text text.
+
+    paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
+
+    cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -t wav  - |\n", $1, $2)}'  > $far_dir/wav.scp
+
+    python local/alimeeting_process_overlap_force.py  --path $far_dir \
+        --no-overlap false --mars True \
+        --overlap_length 0.8 --max_length 7
+
+    cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
+    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
+    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $far_dir/utt2spk_old >$far_dir/utt2spk
+    
+    utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
+    sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
+    sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
+    sed -e 's/锛�//g' $far_dir/tmp2> $far_dir/tmp3
+    sed -e 's/锛�//g' $far_dir/tmp3> $far_dir/text
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    log "stage 3: finali data process"
+
+    utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
+    utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+
+    sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
+    sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+
+    # remove space in text
+    for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
+        cp data/${x}/text data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+        > data/${x}/text
+        rm data/${x}/text.org
+    done
+
+    log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
+    cp -r $far_dir/* $far_single_speaker_dir 
+    mv $far_single_speaker_dir/textgrid.flist  $far_single_speaker_dir/textgrid_oldpath
+    paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
+    python local/process_textgrid_to_single_speaker_wav.py  --path $far_single_speaker_dir
+    
+    cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text    
+    utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+
+    ./utils/fix_data_dir.sh $far_single_speaker_dir 
+    utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+
+    # remove space in text
+    for x in ${tgt}_Ali_far_single_speaker; do
+        cp data/${x}/text data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+        > data/${x}/text
+        rm data/${x}/text.org
+    done
+    log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
new file mode 100755
index 0000000..382a056
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
@@ -0,0 +1,129 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+    local fname=${BASH_SOURCE[1]##*/}
+    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+help_messge=$(cat << EOF
+Usage: $0
+
+Options:
+    --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
+    --tgt (string): Which set to process, test or train.
+EOF
+)
+
+SECONDS=0
+tgt=Train #Train or Eval
+
+
+log "$0 $*"
+echo $tgt
+. ./utils/parse_options.sh
+
+. ./path.sh
+
+AliMeeting="${PWD}/dataset"
+
+if [ $# -gt 2 ]; then
+    log "${help_message}"
+    exit 2
+fi
+
+
+if [ ! -d "${AliMeeting}" ]; then
+  log "Error: ${AliMeeting} is empty."
+  exit 2
+fi
+
+# To absolute path
+AliMeeting=$(cd ${AliMeeting}; pwd)
+echo $AliMeeting
+far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
+
+far_dir=data/local/${tgt}_Ali_far
+far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
+mkdir -p $far_single_speaker_dir
+
+stage=1
+stop_stage=3
+mkdir -p $far_dir
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    log "stage 1:process alimeeting far dir"
+    
+    find -L $far_raw_dir/audio_dir -iname "*.wav" >  $far_dir/wavlist
+    awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid   
+    find -L $far_raw_dir/textgrid_dir  -iname "*.TextGrid" > $far_dir/textgrid.flist
+    n1_wav=$(wc -l < $far_dir/wavlist)
+    n2_text=$(wc -l < $far_dir/textgrid.flist)
+    log  far file found $n1_wav wav and $n2_text text.
+
+    paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
+
+    cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -t wav  - |\n", $1, $2)}'  > $far_dir/wav.scp
+
+    python local/alimeeting_process_overlap_force.py  --path $far_dir \
+        --no-overlap false --mars True \
+        --overlap_length 0.8 --max_length 7
+
+    cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
+    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
+    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $far_dir/utt2spk_old >$far_dir/utt2spk
+    
+    utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
+    sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
+    sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
+    sed -e 's/锛�//g' $far_dir/tmp2> $far_dir/tmp3
+    sed -e 's/锛�//g' $far_dir/tmp3> $far_dir/text
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    log "stage 2: finali data process"
+
+    utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+
+    sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
+    sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+
+    # remove space in text
+    for x in ${tgt}_Ali_far; do
+        cp data/${x}/text data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+        > data/${x}/text
+        rm data/${x}/text.org
+    done
+
+    log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    log "stage 3:process alimeeting far dir (single speaker by oracal time strap)"
+    cp -r $far_dir/* $far_single_speaker_dir 
+    mv $far_single_speaker_dir/textgrid.flist  $far_single_speaker_dir/textgrid_oldpath
+    paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
+    python local/process_textgrid_to_single_speaker_wav.py  --path $far_single_speaker_dir
+    
+    cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text    
+    utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+
+    ./utils/fix_data_dir.sh $far_single_speaker_dir 
+    utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+
+    # remove space in text
+    for x in ${tgt}_Ali_far_single_speaker; do
+        cp data/${x}/text data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+        > data/${x}/text
+        rm data/${x}/text.org
+    done
+    log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
new file mode 100755
index 0000000..8ece757
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+class Segment(object):
+    def __init__(self, uttid, spkr, stime, etime, text):
+        self.uttid = uttid
+        self.spkr = spkr
+        self.spkr_all = uttid+"-"+spkr
+        self.stime = round(stime, 2)
+        self.etime = round(etime, 2)
+        self.text = text
+        self.spk_text = {uttid+"-"+spkr: text}
+
+    def change_stime(self, time):
+        self.stime = time
+
+    def change_etime(self, time):
+        self.etime = time
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="process the textgrid files")
+    parser.add_argument("--path", type=str, required=True, help="Data path")
+    parser.add_argument(
+        "--no-overlap",
+        type=strtobool,
+        default=False,
+        help="Whether to ignore the overlapping utterances.",
+    )
+    parser.add_argument(
+        "--max_length",
+        default=100000,
+        type=float,
+        help="overlap speech max time,if longger than max length should cut",
+    )
+    parser.add_argument(
+        "--overlap_length",
+        default=1,
+        type=float,
+        help="if length longer than max length, speech overlength shorter, is cut",
+    )
+    parser.add_argument(
+        "--mars",
+        type=strtobool,
+        default=False,
+        help="Whether to process mars data set.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def preposs_overlap(segments,max_length,overlap_length):
+    new_segments = []
+    # init a helper list to store all overlap segments
+    tmp_segments = segments[0]
+    min_stime = segments[0].stime
+    max_etime = segments[0].etime
+    overlap_length_big = 1.5
+    max_length_big = 15
+    for i in range(1, len(segments)):
+        if segments[i].stime >= max_etime:
+            # doesn't overlap with preivous segments
+            new_segments.append(tmp_segments)
+            tmp_segments = segments[i]
+            min_stime = segments[i].stime
+            max_etime = segments[i].etime
+        else:
+            # overlap with previous segments
+            dur_time = max_etime - min_stime
+            if dur_time < max_length:
+                if min_stime > segments[i].stime:
+                    min_stime = segments[i].stime
+                if max_etime < segments[i].etime:
+                    max_etime = segments[i].etime
+                tmp_segments.stime = min_stime
+                tmp_segments.etime = max_etime
+                tmp_segments.text = tmp_segments.text + "src" + segments[i].text
+                spk_name =segments[i].uttid +"-" + segments[i].spkr
+                if spk_name in tmp_segments.spk_text:
+                    tmp_segments.spk_text[spk_name] += segments[i].text 
+                else:
+                    tmp_segments.spk_text[spk_name] = segments[i].text
+                tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
+            else:
+                overlap_time = max_etime - segments[i].stime 
+                if dur_time < max_length_big:
+                    overlap_length_option = overlap_length
+                else:
+                    overlap_length_option = overlap_length_big
+                if overlap_time > overlap_length_option:
+                    if min_stime > segments[i].stime:
+                        min_stime = segments[i].stime
+                    if max_etime < segments[i].etime:
+                        max_etime = segments[i].etime
+                    tmp_segments.stime = min_stime
+                    tmp_segments.etime = max_etime
+                    tmp_segments.text = tmp_segments.text + "src" + segments[i].text
+                    spk_name =segments[i].uttid +"-" + segments[i].spkr
+                    if spk_name in tmp_segments.spk_text:
+                        tmp_segments.spk_text[spk_name] += segments[i].text 
+                    else:
+                        tmp_segments.spk_text[spk_name] = segments[i].text
+                    tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
+                else:
+                    new_segments.append(tmp_segments)
+                    tmp_segments = segments[i]
+                    min_stime = segments[i].stime
+                    max_etime = segments[i].etime
+                    
+    return new_segments
+
+def filter_overlap(segments):
+    new_segments = []
+    # init a helper list to store all overlap segments
+    tmp_segments = [segments[0]]
+    min_stime = segments[0].stime
+    max_etime = segments[0].etime
+
+    for i in range(1, len(segments)):
+        if segments[i].stime >= max_etime:
+            # doesn't overlap with preivous segments
+            if len(tmp_segments) == 1:
+                new_segments.append(tmp_segments[0])
+            # TODO: for multi-spkr asr, we can reset the stime/etime to
+            # min_stime/max_etime for generating a max length mixutre speech
+            tmp_segments = [segments[i]]
+            min_stime = segments[i].stime
+            max_etime = segments[i].etime
+        else:
+            # overlap with previous segments
+            tmp_segments.append(segments[i])
+            if min_stime > segments[i].stime:
+                min_stime = segments[i].stime
+            if max_etime < segments[i].etime:
+                max_etime = segments[i].etime
+
+    return new_segments
+
+
+def main(args):
+    wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
+    textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+
+    # get the path of textgrid file for each utterance
+    utt2textgrid = {}
+    for line in textgrid_flist:
+        path = Path(line.strip())
+        uttid = path.stem
+        utt2textgrid[uttid] = path
+
+    # parse the textgrid file for each utterance
+    all_segments = []
+    for line in wav_scp:
+        uttid = line.strip().split(" ")[0]
+        uttid_part=uttid
+        if args.mars == True:
+            uttid_list = uttid.split("_")
+            uttid_part= uttid_list[0]+"_"+uttid_list[1]
+        if uttid_part not in utt2textgrid:
+            print("%s doesn't have transcription" % uttid)
+            continue
+
+        segments = []
+        tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
+        for i in range(tg.__len__()):
+            for j in range(tg[i].__len__()):
+                if tg[i][j].mark:
+                    segments.append(
+                        Segment(
+                            uttid,
+                            tg[i].name,
+                            tg[i][j].minTime,
+                            tg[i][j].maxTime,
+                            tg[i][j].mark.strip(),
+                        )
+                    )
+
+        segments = sorted(segments, key=lambda x: x.stime)
+
+        if args.no_overlap:
+            segments = filter_overlap(segments)
+        else:
+            segments = preposs_overlap(segments,args.max_length,args.overlap_length)
+        all_segments += segments
+
+    wav_scp.close()
+    textgrid_flist.close()
+
+    segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
+    utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
+    text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
+    utt2spk_file_fifo = codecs.open(Path(args.path) / "utt2spk_all_fifo", "w", "utf-8")
+
+    for i in range(len(all_segments)):
+        utt_name = "%s-%s-%07d-%07d" % (
+            all_segments[i].uttid,
+            all_segments[i].spkr,
+            all_segments[i].stime * 100,
+            all_segments[i].etime * 100,
+        )
+
+        segments_file.write(
+            "%s %s %.2f %.2f\n"
+            % (
+                utt_name,
+                all_segments[i].uttid,
+                all_segments[i].stime,
+                all_segments[i].etime,
+            )
+        )
+        utt2spk_file.write(
+            "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
+        )
+        utt2spk_file_fifo.write(
+            "%s %s\n" % (utt_name,  all_segments[i].spkr_all)
+        )
+        text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
+
+    segments_file.close()
+    utt2spk_file.close()
+    text_file.close()
+    utt2spk_file_fifo.close()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
new file mode 100755
index 0000000..81c1965
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+class Segment(object):
+    def __init__(self, uttid, spkr, stime, etime, text):
+        self.uttid = uttid
+        self.spkr = spkr
+        self.stime = round(stime, 2)
+        self.etime = round(etime, 2)
+        self.text = text
+
+    def change_stime(self, time):
+        self.stime = time
+
+    def change_etime(self, time):
+        self.etime = time
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="process the textgrid files")
+    parser.add_argument("--path", type=str, required=True, help="Data path")
+    parser.add_argument(
+        "--no-overlap",
+        type=strtobool,
+        default=False,
+        help="Whether to ignore the overlapping utterances.",
+    )
+    parser.add_argument(
+        "--mars",
+        type=strtobool,
+        default=False,
+        help="Whether to process mars data set.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def filter_overlap(segments):
+    new_segments = []
+    # init a helper list to store all overlap segments
+    tmp_segments = [segments[0]]
+    min_stime = segments[0].stime
+    max_etime = segments[0].etime
+
+    for i in range(1, len(segments)):
+        if segments[i].stime >= max_etime:
+            # doesn't overlap with preivous segments
+            if len(tmp_segments) == 1:
+                new_segments.append(tmp_segments[0])
+            # TODO: for multi-spkr asr, we can reset the stime/etime to
+            # min_stime/max_etime for generating a max length mixutre speech
+            tmp_segments = [segments[i]]
+            min_stime = segments[i].stime
+            max_etime = segments[i].etime
+        else:
+            # overlap with previous segments
+            tmp_segments.append(segments[i])
+            if min_stime > segments[i].stime:
+                min_stime = segments[i].stime
+            if max_etime < segments[i].etime:
+                max_etime = segments[i].etime
+
+    return new_segments
+
+
+def main(args):
+    wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
+    textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+
+    # get the path of textgrid file for each utterance
+    utt2textgrid = {}
+    for line in textgrid_flist:
+        path = Path(line.strip())
+        uttid = path.stem
+        utt2textgrid[uttid] = path
+
+    # parse the textgrid file for each utterance
+    all_segments = []
+    for line in wav_scp:
+        uttid = line.strip().split(" ")[0]
+        uttid_part=uttid
+        if args.mars == True:
+            uttid_list = uttid.split("_")
+            uttid_part= uttid_list[0]+"_"+uttid_list[1]
+        if uttid_part not in utt2textgrid:
+            print("%s doesn't have transcription" % uttid)
+            continue
+        #pdb.set_trace()
+        segments = []
+        try:
+            tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
+        except:
+            pdb.set_trace()
+        for i in range(tg.__len__()):
+            for j in range(tg[i].__len__()):
+                if tg[i][j].mark:
+                    segments.append(
+                        Segment(
+                            uttid,
+                            tg[i].name,
+                            tg[i][j].minTime,
+                            tg[i][j].maxTime,
+                            tg[i][j].mark.strip(),
+                        )
+                    )
+
+        segments = sorted(segments, key=lambda x: x.stime)
+
+        if args.no_overlap:
+            segments = filter_overlap(segments)
+
+        all_segments += segments
+
+    wav_scp.close()
+    textgrid_flist.close()
+
+    segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
+    utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
+    text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
+
+    for i in range(len(all_segments)):
+        utt_name = "%s-%s-%07d-%07d" % (
+            all_segments[i].uttid,
+            all_segments[i].spkr,
+            all_segments[i].stime * 100,
+            all_segments[i].etime * 100,
+        )
+
+        segments_file.write(
+            "%s %s %.2f %.2f\n"
+            % (
+                utt_name,
+                all_segments[i].uttid,
+                all_segments[i].stime,
+                all_segments[i].etime,
+            )
+        )
+        utt2spk_file.write(
+            "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
+        )
+        text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
+
+    segments_file.close()
+    utt2spk_file.close()
+    text_file.close()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa-asr/local/compute_cpcer.py
new file mode 100644
index 0000000..f4d4a79
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/compute_cpcer.py
@@ -0,0 +1,91 @@
+import editdistance
+import sys
+import os
+from itertools import permutations
+
+
+def load_transcripts(file_path):
+    trans_list = []
+    for one_line in open(file_path, "rt"):
+        meeting_id, trans = one_line.strip().split(" ")
+        trans_list.append((meeting_id.strip(), trans.strip()))
+
+    return trans_list
+
+def calc_spk_trans(trans):
+    spk_trans_ = [x.strip() for x in trans.split("$")]
+    spk_trans = []
+    for i in range(len(spk_trans_)):
+        spk_trans.append((str(i), spk_trans_[i]))
+    return spk_trans
+
+def calc_cer(ref_trans, hyp_trans):
+    ref_spk_trans = calc_spk_trans(ref_trans)
+    hyp_spk_trans = calc_spk_trans(hyp_trans)
+    ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
+    num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
+    ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
+    hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
+
+    errors, counts, permutes = [], [], []
+    min_error = 0
+    cost_dict = {}
+    for perm in permutations(range(num_spk)):
+        flag = True
+        p_err, p_count = 0, 0
+        for idx, p in enumerate(perm):
+            if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
+                flag = False
+                break
+            cost_key = "{}-{}".format(idx, p)
+            if cost_key in cost_dict:
+                _e = cost_dict[cost_key]
+            else:
+                _e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
+                cost_dict[cost_key] = _e
+            if _e > min_error > 0:
+                flag = False
+                break
+            p_err += _e
+            p_count += len(ref_spk_trans[idx][1])
+
+        if flag:
+            if p_err < min_error or min_error == 0:
+                min_error = p_err
+
+            errors.append(p_err)
+            counts.append(p_count)
+            permutes.append(perm)
+
+    sd_cer = [(err, cnt, err/cnt, permute)
+              for err, cnt, permute in zip(errors, counts, permutes)]
+    # import ipdb;ipdb.set_trace()
+    best_rst = min(sd_cer, key=lambda x: x[2])
+
+    return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
+
+
+def main():
+    ref=sys.argv[1]
+    hyp=sys.argv[2]
+    result_path=sys.argv[3]
+    ref_list = load_transcripts(ref)
+    hyp_list = load_transcripts(hyp)
+    result_file = open(result_path,'w')
+    error, count = 0, 0
+    for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
+        assert ref_id == hyp_id
+        mid = ref_id
+        dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
+        error, count = error + dist, count + length
+        result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
+
+        # print("{} {:.2f} {} {}".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
+
+    result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
+    result_file.close()
+    # print("Sum/Avg: {:.2f}".format(error / count * 100.0))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/egs/alimeeting/sa-asr/local/compute_wer.py b/egs/alimeeting/sa-asr/local/compute_wer.py
new file mode 100755
index 0000000..349a3f6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/compute_wer.py
@@ -0,0 +1,157 @@
+import os
+import numpy as np
+import sys
+
+def compute_wer(ref_file,
+                hyp_file,
+                cer_detail_file):
+    rst = {
+        'Wrd': 0,
+        'Corr': 0,
+        'Ins': 0,
+        'Del': 0,
+        'Sub': 0,
+        'Snt': 0,
+        'Err': 0.0,
+        'S.Err': 0.0,
+        'wrong_words': 0,
+        'wrong_sentences': 0
+    }
+
+    hyp_dict = {}
+    ref_dict = {}
+    with open(hyp_file, 'r') as hyp_reader:
+        for line in hyp_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            hyp_dict[key] = value
+    with open(ref_file, 'r') as ref_reader:
+        for line in ref_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            ref_dict[key] = value
+
+    cer_detail_writer = open(cer_detail_file, 'w')
+    for hyp_key in hyp_dict:
+        if hyp_key in ref_dict:
+           out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
+           rst['Wrd'] += out_item['nwords']
+           rst['Corr'] += out_item['cor']
+           rst['wrong_words'] += out_item['wrong']
+           rst['Ins'] += out_item['ins']
+           rst['Del'] += out_item['del']
+           rst['Sub'] += out_item['sub']
+           rst['Snt'] += 1
+           if out_item['wrong'] > 0:
+               rst['wrong_sentences'] += 1
+           cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
+           cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
+           cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
+
+    if rst['Wrd'] > 0:
+        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
+    if rst['Snt'] > 0:
+        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
+
+    cer_detail_writer.write('\n')
+    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
+                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
+    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
+    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
+
+     
+def compute_wer_by_line(hyp,
+                        ref):
+    hyp = list(map(lambda x: x.lower(), hyp))
+    ref = list(map(lambda x: x.lower(), ref))
+
+    len_hyp = len(hyp)
+    len_ref = len(ref)
+
+    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+    for i in range(len_hyp + 1):
+        cost_matrix[i][0] = i
+    for j in range(len_ref + 1):
+        cost_matrix[0][j] = j
+
+    for i in range(1, len_hyp + 1):
+        for j in range(1, len_ref + 1):
+            if hyp[i - 1] == ref[j - 1]:
+                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+            else:
+                substitution = cost_matrix[i - 1][j - 1] + 1
+                insertion = cost_matrix[i - 1][j] + 1
+                deletion = cost_matrix[i][j - 1] + 1
+
+                compare_val = [substitution, insertion, deletion]
+
+                min_val = min(compare_val)
+                operation_idx = compare_val.index(min_val) + 1
+                cost_matrix[i][j] = min_val
+                ops_matrix[i][j] = operation_idx
+
+    match_idx = []
+    i = len_hyp
+    j = len_ref
+    rst = {
+        'nwords': len_ref,
+        'cor': 0,
+        'wrong': 0,
+        'ins': 0,
+        'del': 0,
+        'sub': 0
+    }
+    while i >= 0 or j >= 0:
+        i_idx = max(0, i)
+        j_idx = max(0, j)
+
+        if ops_matrix[i_idx][j_idx] == 0:  # correct
+            if i - 1 >= 0 and j - 1 >= 0:
+                match_idx.append((j - 1, i - 1))
+                rst['cor'] += 1
+
+            i -= 1
+            j -= 1
+
+        elif ops_matrix[i_idx][j_idx] == 2:  # insert
+            i -= 1
+            rst['ins'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 3:  # delete
+            j -= 1
+            rst['del'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
+            i -= 1
+            j -= 1
+            rst['sub'] += 1
+
+        if i < 0 and j >= 0:
+            rst['del'] += 1
+        elif j < 0 and i >= 0:
+            rst['ins'] += 1
+
+    match_idx.reverse()
+    wrong_cnt = cost_matrix[len_hyp][len_ref]
+    rst['wrong'] = wrong_cnt
+
+    return rst
+
+def print_cer_detail(rst):
+    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+            + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
+
+if __name__ == '__main__':
+    if len(sys.argv) != 4:
+        print("usage : python compute-wer.py test.ref test.hyp test.wer")
+        sys.exit(0)
+
+    ref_file = sys.argv[1]
+    hyp_file = sys.argv[2]
+    cer_detail_file = sys.argv[3]
+    compute_wer(ref_file, hyp_file, cer_detail_file)
diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa-asr/local/download_xvector_model.py
new file mode 100644
index 0000000..7da6559
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/download_xvector_model.py
@@ -0,0 +1,6 @@
+from modelscope.hub.snapshot_download import snapshot_download
+import sys
+
+
+cache_dir = sys.argv[1]
+model_dir = snapshot_download('damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch', cache_dir=cache_dir)
diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
new file mode 100644
index 0000000..e606162
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
@@ -0,0 +1,22 @@
+import sys
+if __name__=="__main__":
+    uttid_path=sys.argv[1]
+    src_path=sys.argv[2]
+    tgt_path=sys.argv[3]
+    uttid_file=open(uttid_path,'r')
+    uttid_line=uttid_file.readlines()
+    uttid_file.close()
+    ori_utt2spk_all_fifo_file=open(src_path+'/utt2spk_all_fifo','r')
+    ori_utt2spk_all_fifo_line=ori_utt2spk_all_fifo_file.readlines()
+    ori_utt2spk_all_fifo_file.close()
+    new_utt2spk_all_fifo_file=open(tgt_path+'/utt2spk_all_fifo','w')
+
+    uttid_list=[]
+    for line in uttid_line:
+        uttid_list.append(line.strip())
+    
+    for line in ori_utt2spk_all_fifo_line:
+        if line.strip().split(' ')[0] in uttid_list:
+            new_utt2spk_all_fifo_file.write(line)
+    
+    new_utt2spk_all_fifo_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
new file mode 100644
index 0000000..c37abf9
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
@@ -0,0 +1,167 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+import sys
+import os
+import soundfile
+from itertools import permutations
+from sklearn.metrics.pairwise import cosine_similarity
+from sklearn import cluster
+
+
+def custom_spectral_clustering(affinity, min_n_clusters=2, max_n_clusters=4, refine=True,
+    threshold=0.995, laplacian_type="graph_cut"):
+    if refine:
+        # Symmetrization
+        affinity = np.maximum(affinity, np.transpose(affinity))
+        # Diffusion
+        affinity = np.matmul(affinity, np.transpose(affinity))
+        # Row-wise max normalization
+        row_max = affinity.max(axis=1, keepdims=True)
+        affinity = affinity / row_max
+
+    # a) Construct S and set diagonal elements to 0
+    affinity = affinity - np.diag(np.diag(affinity))
+    # b) Compute Laplacian matrix L and perform normalization:
+    degree = np.diag(np.sum(affinity, axis=1))
+    laplacian = degree - affinity
+    if laplacian_type == "random_walk":
+        degree_norm = np.diag(1 / (np.diag(degree) + 1e-10))
+        laplacian_norm = degree_norm.dot(laplacian)
+    else:
+        degree_half = np.diag(degree) ** 0.5 + 1e-15
+        laplacian_norm = laplacian / degree_half[:, np.newaxis] / degree_half
+
+    # c) Compute eigenvalues and eigenvectors of L_norm
+    eigenvalues, eigenvectors = np.linalg.eig(laplacian_norm)
+    eigenvalues = eigenvalues.real
+    eigenvectors = eigenvectors.real
+    index_array = np.argsort(eigenvalues)
+    eigenvalues = eigenvalues[index_array]
+    eigenvectors = eigenvectors[:, index_array]
+
+    # d) Compute the number of clusters k
+    k = min_n_clusters
+    for k in range(min_n_clusters, max_n_clusters + 1):
+        if eigenvalues[k] > threshold:
+            break
+    k = max(k, min_n_clusters)
+    spectral_embeddings = eigenvectors[:, :k]
+    # print(mid, k, eigenvalues[:10])
+
+    spectral_embeddings = spectral_embeddings / np.linalg.norm(spectral_embeddings, axis=1, ord=2, keepdims=True)
+    solver = cluster.KMeans(n_clusters=k, max_iter=1000, random_state=42)
+    solver.fit(spectral_embeddings)
+    return solver.labels_
+
+
+if __name__ == "__main__":
+    path = sys.argv[1] # dump2/raw/Eval_Ali_far
+    raw_path = sys.argv[2] # data/local/Eval_Ali_far
+    threshold = float(sys.argv[3]) # 0.996
+    sv_threshold = float(sys.argv[4]) # 0.815
+    wav_scp_file = open(path+'/wav.scp', 'r')
+    wav_scp = wav_scp_file.readlines()
+    wav_scp_file.close()
+    raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+    raw_meeting_scp = raw_meeting_scp_file.readlines()
+    raw_meeting_scp_file.close()
+    segments_scp_file = open(raw_path + '/segments', 'r')
+    segments_scp = segments_scp_file.readlines()
+    segments_scp_file.close()
+
+    segments_map = {}
+    for line in segments_scp:
+        line_list = line.strip().split(' ')
+        meeting = line_list[1]
+        seg = (float(line_list[-2]), float(line_list[-1]))
+        if meeting not in segments_map.keys():
+            segments_map[meeting] = [seg]
+        else:
+            segments_map[meeting].append(seg)
+    
+    inference_sv_pipline = pipeline(
+        task=Tasks.speaker_verification,
+        model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+    )
+
+    chunk_len = int(1.5*16000) # 1.5 seconds
+    hop_len = int(0.75*16000) # 0.75 seconds
+
+    os.system("mkdir -p " + path + "/cluster_profile_infer")
+    cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
+    meeting_map = {}
+    for line in raw_meeting_scp:
+        meeting = line.strip().split('\t')[0]
+        wav_path = line.strip().split('\t')[1]
+        wav = soundfile.read(wav_path)[0]
+        # take the first channel
+        if wav.ndim == 2:
+            wav=wav[:, 0]
+        # gen_seg_embedding
+        segments_list = segments_map[meeting]
+
+        # import ipdb;ipdb.set_trace()
+        all_seg_embedding_list = []
+        for seg in segments_list:
+            wav_seg = wav[int(seg[0] * 16000): int(seg[1] * 16000)]
+            wav_seg_len = wav_seg.shape[0]
+            i = 0
+            while i < wav_seg_len:
+                if i + chunk_len < wav_seg_len:
+                    cur_wav_chunk = wav_seg[i: i+chunk_len]
+                else:
+                    cur_wav_chunk=wav_seg[i: ]
+                # chunks under 0.2s are ignored
+                if cur_wav_chunk.shape[0] >= 0.2 * 16000:
+                    cur_chunk_embedding = inference_sv_pipline(audio_in=cur_wav_chunk)["spk_embedding"]
+                    all_seg_embedding_list.append(cur_chunk_embedding)
+                i += hop_len
+        all_seg_embedding = np.vstack(all_seg_embedding_list)
+        # all_seg_embedding (n, dim)
+
+        # compute affinity
+        affinity=cosine_similarity(all_seg_embedding)
+
+        affinity = np.maximum(affinity - sv_threshold, 0.0001) / (affinity.max() - sv_threshold)
+
+        # clustering
+        labels = custom_spectral_clustering(
+            affinity=affinity,
+            min_n_clusters=2,
+            max_n_clusters=4,
+            refine=True,
+            threshold=threshold,
+            laplacian_type="graph_cut")
+       
+
+        cluster_dict={}
+        for j in range(labels.shape[0]):
+            if labels[j] not in cluster_dict.keys():
+                cluster_dict[labels[j]] = np.atleast_2d(all_seg_embedding[j])
+            else:
+                cluster_dict[labels[j]] = np.concatenate((cluster_dict[labels[j]], np.atleast_2d(all_seg_embedding[j])))
+        
+        emb_list = []
+        # get cluster center
+        for k in cluster_dict.keys():
+            cluster_dict[k] = np.mean(cluster_dict[k], axis=0)
+            emb_list.append(cluster_dict[k])
+
+        spk_num = len(emb_list)
+        profile_for_infer = np.vstack(emb_list)
+        # save profile for each meeting
+        np.save(path + '/cluster_profile_infer/' + meeting + '.npy', profile_for_infer)
+        meeting_map[meeting] = (path + '/cluster_profile_infer/' + meeting + '.npy', spk_num)
+        cluster_spk_num_file.write(meeting + ' ' + str(spk_num) + '\n')
+        cluster_spk_num_file.flush()
+    
+    cluster_spk_num_file.close()
+
+    profile_scp = open(path + "/cluster_profile_infer.scp", 'w')
+    for line in wav_scp:
+        uttid = line.strip().split(' ')[0]
+        meeting = uttid.split('-')[0]
+        profile_scp.write(uttid + ' ' + meeting_map[meeting][0] + '\n')
+        profile_scp.flush()
+    profile_scp.close()
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
new file mode 100644
index 0000000..18286b4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
@@ -0,0 +1,70 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+import sys
+import os
+import soundfile
+
+
+if __name__=="__main__":
+    path = sys.argv[1] # dump2/raw/Eval_Ali_far
+    raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
+    raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+    raw_meeting_scp = raw_meeting_scp_file.readlines()
+    raw_meeting_scp_file.close()
+    segments_scp_file = open(raw_path + '/segments', 'r')
+    segments_scp = segments_scp_file.readlines()
+    segments_scp_file.close()
+
+    oracle_emb_dir = path + '/oracle_embedding/'
+    os.system("mkdir -p " + oracle_emb_dir)
+    oracle_emb_scp_file = open(path+'/oracle_embedding.scp', 'w')
+
+    raw_wav_map = {}
+    for line in raw_meeting_scp:
+        meeting = line.strip().split('\t')[0]
+        wav_path = line.strip().split('\t')[1]
+        raw_wav_map[meeting] = wav_path
+    
+    spk_map = {}
+    for line in segments_scp:
+        line_list = line.strip().split(' ')
+        meeting = line_list[1]
+        spk_id = line_list[0].split('_')[3]
+        spk = meeting + '_' + spk_id
+        time_start = float(line_list[-2])
+        time_end = float(line_list[-1])
+        if time_end - time_start > 0.5:
+            if spk not in spk_map.keys():
+                spk_map[spk] = [(int(time_start * 16000), int(time_end * 16000))]
+            else:
+                spk_map[spk].append((int(time_start * 16000), int(time_end * 16000)))
+    
+    inference_sv_pipline = pipeline(
+        task=Tasks.speaker_verification,
+        model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+    )
+
+    for spk in spk_map.keys():
+        meeting = spk.split('_SPK')[0]
+        wav_path = raw_wav_map[meeting]
+        wav = soundfile.read(wav_path)[0]
+        # take the first channel
+        if wav.ndim == 2:
+            wav = wav[:, 0]
+        all_seg_embedding_list=[]
+        # import ipdb;ipdb.set_trace()
+        for seg_time in spk_map[spk]:
+            if seg_time[0] < wav.shape[0] - 0.5 * 16000:
+                if seg_time[1] > wav.shape[0]:
+                    cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: ])["spk_embedding"]
+                else:
+                    cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: seg_time[1]])["spk_embedding"]
+                all_seg_embedding_list.append(cur_seg_embedding)
+        all_seg_embedding = np.vstack(all_seg_embedding_list)
+        spk_embedding = np.mean(all_seg_embedding, axis=0)
+        np.save(oracle_emb_dir + spk + '.npy', spk_embedding)
+        oracle_emb_scp_file.write(spk + ' ' + oracle_emb_dir + spk + '.npy' + '\n')
+        oracle_emb_scp_file.flush()
+    
+    oracle_emb_scp_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
new file mode 100644
index 0000000..f44fcd4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
@@ -0,0 +1,59 @@
+import random
+import numpy as np
+import os
+import sys
+
+
+if __name__=="__main__":
+    path = sys.argv[1] # dump2/raw/Eval_Ali_far
+    wav_scp_file = open(path+"/wav.scp", 'r')
+    wav_scp = wav_scp_file.readlines()
+    wav_scp_file.close()
+    spk2id_file = open(path + "/spk2id", 'r')
+    spk2id = spk2id_file.readlines()
+    spk2id_file.close()
+    embedding_scp_file = open(path + "/oracle_embedding.scp", 'r')
+    embedding_scp = embedding_scp_file.readlines()
+    embedding_scp_file.close()
+
+    embedding_map = {}
+    for line in embedding_scp:
+        spk = line.strip().split(' ')[0]
+        if spk not in embedding_map.keys():
+            emb=np.load(line.strip().split(' ')[1])
+            embedding_map[spk] = emb
+    
+    meeting_map_tmp = {}
+    global_spk_list = []
+    for line in spk2id:
+        line_list = line.strip().split(' ')
+        meeting = line_list[0].split('-')[0]
+        spk_id = line_list[0].split('-')[-1].split('_')[-1]
+        spk = meeting + '_' + spk_id
+        global_spk_list.append(spk)
+        if meeting in meeting_map_tmp.keys():
+            meeting_map_tmp[meeting].append(spk)
+        else:
+            meeting_map_tmp[meeting] = [spk]
+    
+    meeting_map = {}
+    os.system('mkdir -p ' + path + '/oracle_profile_nopadding')
+    for meeting in meeting_map_tmp.keys():
+        emb_list = []
+        for i in range(len(meeting_map_tmp[meeting])):
+            spk = meeting_map_tmp[meeting][i]
+            emb_list.append(embedding_map[spk])
+        profile = np.vstack(emb_list)
+        np.save(path + '/oracle_profile_nopadding/' + meeting + '.npy', profile)
+        meeting_map[meeting] = path + '/oracle_profile_nopadding/' + meeting + '.npy'
+    
+    profile_scp = open(path + '/oracle_profile_nopadding.scp', 'w')
+    profile_map_scp = open(path + '/oracle_profile_nopadding_spk_list', 'w')
+
+    for line in wav_scp:
+        uttid = line.strip().split(' ')[0]
+        meeting = uttid.split('-')[0]
+        profile_scp.write(uttid + ' ' + meeting_map[meeting] + '\n')
+        profile_map_scp.write(uttid + ' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
+    profile_scp.close()
+    profile_map_scp.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
new file mode 100644
index 0000000..b70a32a
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
@@ -0,0 +1,68 @@
+import random
+import numpy as np
+import os
+import sys
+
+
+if __name__=="__main__":
+    path = sys.argv[1] # dump2/raw/Train_Ali_far
+    wav_scp_file = open(path+"/wav.scp", 'r')
+    wav_scp = wav_scp_file.readlines()
+    wav_scp_file.close()
+    spk2id_file = open(path+"/spk2id", 'r')
+    spk2id = spk2id_file.readlines()
+    spk2id_file.close()
+    embedding_scp_file = open(path + "/oracle_embedding.scp", 'r')
+    embedding_scp = embedding_scp_file.readlines()
+    embedding_scp_file.close()
+
+    embedding_map = {}
+    for line in embedding_scp:
+        spk = line.strip().split(' ')[0]
+        if spk not in embedding_map.keys():
+            emb = np.load(line.strip().split(' ')[1])
+            embedding_map[spk] = emb
+    
+    meeting_map_tmp = {}
+    global_spk_list = []
+    for line in spk2id:
+        line_list = line.strip().split(' ')
+        meeting = line_list[0].split('-')[0]
+        spk_id = line_list[0].split('-')[-1].split('_')[-1]
+        spk = meeting+'_' + spk_id
+        global_spk_list.append(spk)
+        if meeting in meeting_map_tmp.keys():
+            meeting_map_tmp[meeting].append(spk)
+        else:
+            meeting_map_tmp[meeting] = [spk]
+    
+    for meeting in meeting_map_tmp.keys():
+        num = len(meeting_map_tmp[meeting])
+        if num < 4:
+            global_spk_list_tmp = global_spk_list[: ]
+            for spk in meeting_map_tmp[meeting]:
+                global_spk_list_tmp.remove(spk)
+                padding_spk = random.sample(global_spk_list_tmp, 4 - num)
+                meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk
+    
+    meeting_map = {}
+    os.system('mkdir -p ' + path + '/oracle_profile_padding')
+    for meeting in meeting_map_tmp.keys():
+        emb_list = []
+        for i in range(len(meeting_map_tmp[meeting])):
+            spk = meeting_map_tmp[meeting][i]
+            emb_list.append(embedding_map[spk])
+        profile = np.vstack(emb_list)
+        np.save(path + '/oracle_profile_padding/' + meeting + '.npy',profile)
+        meeting_map[meeting] = path + '/oracle_profile_padding/' + meeting + '.npy'
+    
+    profile_scp = open(path + '/oracle_profile_padding.scp', 'w')
+    profile_map_scp = open(path + '/oracle_profile_padding_spk_list', 'w')
+
+    for line in wav_scp:
+        uttid = line.strip().split(' ')[0]
+        meeting = uttid.split('-')[0]
+        profile_scp.write(uttid+' ' + meeting_map[meeting] + '\n')
+        profile_map_scp.write(uttid+' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
+    profile_scp.close()
+    profile_map_scp.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/proce_text.py b/egs/alimeeting/sa-asr/local/proce_text.py
new file mode 100755
index 0000000..e56cc0f
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/proce_text.py
@@ -0,0 +1,32 @@
+
+import sys
+import re
+
+in_f = sys.argv[1]
+out_f = sys.argv[2]
+
+
+with open(in_f, "r", encoding="utf-8") as f:
+  lines = f.readlines()
+
+with open(out_f, "w", encoding="utf-8") as f:
+  for line in lines:
+    outs = line.strip().split(" ", 1)
+    if len(outs) == 2:
+      idx, text = outs
+      text = re.sub("</s>", "", text)
+      text = re.sub("<s>", "", text)
+      text = re.sub("@@", "", text)
+      text = re.sub("@", "", text)
+      text = re.sub("<unk>", "", text)
+      text = re.sub(" ", "", text)
+      text = re.sub("\$", "", text)
+      text = text.lower()
+    else:
+      idx = outs[0]
+      text = " "
+
+    text = [x for x in text]
+    text = " ".join(text)
+    out = "{} {}\n".format(idx, text)
+    f.write(out)
diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
new file mode 100755
index 0000000..d900bb1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+def get_args():
+    parser = argparse.ArgumentParser(description="process the textgrid files")
+    parser.add_argument("--path", type=str, required=True, help="Data path")
+    args = parser.parse_args()
+    return args
+
+class Segment(object):
+    def __init__(self, uttid, text):
+        self.uttid = uttid
+        self.text = text
+
+def main(args):
+    text = codecs.open(Path(args.path) / "text", "r", "utf-8")
+    spk2utt = codecs.open(Path(args.path) / "spk2utt", "r", "utf-8")
+    utt2spk = codecs.open(Path(args.path) / "utt2spk_all_fifo", "r", "utf-8")   
+    spk2id = codecs.open(Path(args.path) / "spk2id", "w", "utf-8")
+    
+    spkid_map = {}
+    meetingid_map = {}
+    for line in spk2utt:
+        spkid = line.strip().split(" ")[0]
+        meeting_id_list = spkid.split("_")[:3]
+        meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
+        if meeting_id not in meetingid_map:
+            meetingid_map[meeting_id] = 1     
+        else:
+            meetingid_map[meeting_id] += 1
+        spkid_map[spkid] = meetingid_map[meeting_id]
+        spk2id.write("%s %s\n" % (spkid, meetingid_map[meeting_id]))
+    
+    utt2spklist = {}
+    for line in utt2spk:
+        uttid = line.strip().split(" ")[0]
+        spkid = line.strip().split(" ")[1]
+        spklist = spkid.split("$")
+        tmp = []
+        for index in range(len(spklist)):
+            tmp.append(spkid_map[spklist[index]])
+        utt2spklist[uttid] = tmp
+    # parse the textgrid file for each utterance
+    all_segments = []
+    for line in text:
+        uttid = line.strip().split(" ")[0]
+        context = line.strip().split(" ")[1]
+        spklist = utt2spklist[uttid]
+        length_text = len(context)
+        cnt = 0
+        tmp_text = ""
+        for index in range(length_text):
+            if context[index] != "$":
+                tmp_text += str(spklist[cnt])
+            else:
+                tmp_text += "$"
+                cnt += 1
+        tmp_seg = Segment(uttid,tmp_text)
+        all_segments.append(tmp_seg)
+
+    text.close()
+    utt2spk.close()
+    spk2utt.close()
+    spk2id.close()
+
+    text_id = codecs.open(Path(args.path) / "text_id", "w", "utf-8")
+
+    for i in range(len(all_segments)):
+        uttid_tmp = all_segments[i].uttid
+        text_tmp = all_segments[i].text
+        
+        text_id.write("%s %s\n" % (uttid_tmp, text_tmp))
+
+    text_id.close()
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa-asr/local/process_text_id.py
new file mode 100644
index 0000000..0a9506e
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_text_id.py
@@ -0,0 +1,24 @@
+import sys
+if __name__=="__main__":
+    path=sys.argv[1]
+
+    text_id_old_file=open(path+"/text_id",'r')
+    text_id_old=text_id_old_file.readlines()
+    text_id_old_file.close()
+    
+    text_id=open(path+"/text_id_train",'w')
+    for line in text_id_old:
+        uttid=line.strip().split(' ')[0]
+        old_id=line.strip().split(' ')[1]
+        pre_id='0'
+        new_id_list=[]
+        for i in old_id:
+            if i == '$':
+                new_id_list.append(pre_id)
+            else:
+                new_id_list.append(str(int(i)-1))
+                pre_id=str(int(i)-1)
+        new_id_list.append(pre_id)
+        new_id=' '.join(new_id_list)
+        text_id.write(uttid+' '+new_id+'\n')
+    text_id.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py
new file mode 100644
index 0000000..f15d509
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py
@@ -0,0 +1,55 @@
+import sys
+
+
+if __name__ == "__main__":
+    path=sys.argv[1]
+    text_scp_file = open(path + '/text', 'r')
+    text_scp = text_scp_file.readlines()
+    text_scp_file.close()
+    text_id_scp_file = open(path + '/text_id', 'r')
+    text_id_scp = text_id_scp_file.readlines()
+    text_id_scp_file.close()
+    text_spk_merge_file = open(path + '/text_spk_merge', 'w')
+    assert len(text_scp) == len(text_id_scp)
+
+    meeting_map = {} # {meeting_id: [(start_time, text, text_id), (start_time, text, text_id), ...]}
+    for i in range(len(text_scp)):
+        text_line = text_scp[i].strip().split(' ')
+        text_id_line = text_id_scp[i].strip().split(' ')
+        assert text_line[0] == text_id_line[0]
+        if len(text_line) > 1:
+            uttid = text_line[0]
+            text = text_line[1]
+            text_id = text_id_line[1]
+            meeting_id = uttid.split('-')[0]
+            start_time = int(uttid.split('-')[-2])
+            if meeting_id not in meeting_map:
+                meeting_map[meeting_id] = [(start_time,text,text_id)]
+            else:
+                meeting_map[meeting_id].append((start_time,text,text_id))
+            
+    for meeting_id in sorted(meeting_map.keys()):
+        cur_meeting_list = sorted(meeting_map[meeting_id], key=lambda x: x[0])
+        text_spk_merge_map = {} #{1: text1, 2: text2, ...}
+        for cur_utt in cur_meeting_list:
+            cur_text = cur_utt[1]
+            cur_text_id = cur_utt[2]
+            assert len(cur_text)==len(cur_text_id)
+            if len(cur_text) != 0:
+                cur_text_split = cur_text.split('$')
+                cur_text_id_split = cur_text_id.split('$')
+                assert len(cur_text_split) == len(cur_text_id_split)
+                for i in range(len(cur_text_split)):
+                    if len(cur_text_split[i]) != 0:
+                        spk_id = int(cur_text_id_split[i][0])
+                        if spk_id not in text_spk_merge_map.keys():
+                            text_spk_merge_map[spk_id] = cur_text_split[i]
+                        else:
+                            text_spk_merge_map[spk_id] += cur_text_split[i]
+        text_spk_merge_list = []
+        for spk_id in sorted(text_spk_merge_map.keys()):
+            text_spk_merge_list.append(text_spk_merge_map[spk_id])
+        text_spk_merge_file.write(meeting_id + ' ' + '$'.join(text_spk_merge_list) + '\n')
+        text_spk_merge_file.flush()
+    
+    text_spk_merge_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
new file mode 100755
index 0000000..fdf2460
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+import numpy as np
+import sys
+import math
+
+
+class Segment(object):
+    def __init__(self, uttid, spkr, stime, etime, text):
+        self.uttid = uttid
+        self.spkr = spkr
+        self.stime = round(stime, 2)
+        self.etime = round(etime, 2)
+        self.text = text
+
+    def change_stime(self, time):
+        self.stime = time
+
+    def change_etime(self, time):
+        self.etime = time
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="process the textgrid files")
+    parser.add_argument("--path", type=str, required=True, help="Data path")
+    args = parser.parse_args()
+    return args
+
+
+
+def main(args):
+    textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+    segment_file = codecs.open(Path(args.path)/"segments", "w", "utf-8")
+    utt2spk = codecs.open(Path(args.path)/"utt2spk", "w", "utf-8")
+
+    # get the path of textgrid file for each utterance
+    for line in textgrid_flist:
+        line_array = line.strip().split(" ")
+        path = Path(line_array[1])
+        uttid = line_array[0]
+
+        try:
+            tg = textgrid.TextGrid.fromFile(path)
+        except:
+            pdb.set_trace()
+        num_spk = tg.__len__()
+        spk2textgrid = {}
+        spk2weight = {}
+        weight2spk = {}
+        cnt = 2
+        xmax = 0
+        for i in range(tg.__len__()):
+            spk_name = tg[i].name
+            if spk_name not in spk2weight:
+                spk2weight[spk_name] = cnt
+                weight2spk[cnt] = spk_name
+                cnt = cnt * 2
+            segments = []
+            for j in range(tg[i].__len__()):
+                if tg[i][j].mark:
+                    if xmax < tg[i][j].maxTime:
+                        xmax = tg[i][j].maxTime
+                    segments.append(
+                        Segment(
+                            uttid,
+                            tg[i].name,
+                            tg[i][j].minTime,
+                            tg[i][j].maxTime,
+                            tg[i][j].mark.strip(),
+                        )
+                    )
+            segments = sorted(segments, key=lambda x: x.stime)
+            spk2textgrid[spk_name] = segments
+        olp_label = np.zeros((num_spk, int(xmax/0.01)), dtype=np.int32)
+        for spkid in spk2weight.keys():
+            weight = spk2weight[spkid]
+            segments = spk2textgrid[spkid]
+            idx = int(math.log2(weight) )- 1
+            for i in range(len(segments)):
+                stime = segments[i].stime
+                etime = segments[i].etime
+                olp_label[idx, int(stime/0.01): int(etime/0.01)] = weight
+        sum_label = olp_label.sum(axis=0)
+        stime = 0
+        pre_value = 0
+        for pos in range(sum_label.shape[0]):
+            if sum_label[pos] in weight2spk:
+                if pre_value in weight2spk:
+                    if sum_label[pos] != pre_value:    
+                        spkids = weight2spk[pre_value]
+                        spkid_array = spkids.split("_")
+                        spkid = spkid_array[-1]
+                        #spkid = uttid+spkid 
+                        if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
+                            segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
+                            utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
+                        stime = pos
+                        pre_value = sum_label[pos]
+                else:
+                    stime = pos
+                    pre_value = sum_label[pos]
+            else:
+                if pre_value in weight2spk:
+                    spkids = weight2spk[pre_value]
+                    spkid_array = spkids.split("_")
+                    spkid = spkid_array[-1]
+                    #spkid = uttid+spkid 
+                    if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
+                        segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
+                        utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
+                    stime = pos
+                    pre_value = sum_label[pos]
+    textgrid_flist.close()
+    segment_file.close()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa-asr/local/text_format.pl
new file mode 100755
index 0000000..45f1f64
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/text_format.pl
@@ -0,0 +1,14 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright Chao Weng 
+
+# normalizations for hkust trascript
+# see the docs/trans-guidelines.pdf for details
+
+while (<STDIN>) {
+  @A = split(" ", $_);
+  if (@A == 1) {
+    next;
+  }
+  print $_
+}
diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa-asr/local/text_normalize.pl
new file mode 100755
index 0000000..ac301d4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/text_normalize.pl
@@ -0,0 +1,38 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright Chao Weng 
+
+# normalizations for hkust trascript
+# see the docs/trans-guidelines.pdf for details
+
+while (<STDIN>) {
+  @A = split(" ", $_);
+  print "$A[0] ";
+  for ($n = 1; $n < @A; $n++) { 
+    $tmp = $A[$n];
+    if ($tmp =~ /<sil>/) {$tmp =~ s:<sil>::g;}
+    if ($tmp =~ /<%>/) {$tmp =~ s:<%>::g;}
+    if ($tmp =~ /<->/) {$tmp =~ s:<->::g;}
+    if ($tmp =~ /<\$>/) {$tmp =~ s:<\$>::g;}
+    if ($tmp =~ /<#>/) {$tmp =~ s:<#>::g;}
+    if ($tmp =~ /<_>/) {$tmp =~ s:<_>::g;}
+    if ($tmp =~ /<space>/) {$tmp =~ s:<space>::g;}
+    if ($tmp =~ /`/) {$tmp =~ s:`::g;}
+    if ($tmp =~ /&/) {$tmp =~ s:&::g;}
+    if ($tmp =~ /,/) {$tmp =~ s:,::g;}
+    if ($tmp =~ /[a-zA-Z]/) {$tmp=uc($tmp);} 
+    if ($tmp =~ /锛�/) {$tmp =~ s:锛�:A:g;}
+    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:A:g;}
+    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:B:g;}
+    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:C:g;}
+    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:K:g;}
+    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:T:g;}
+    if ($tmp =~ /锛�/) {$tmp =~ s:锛�::g;}
+    if ($tmp =~ /涓�/) {$tmp =~ s:涓�::g;}
+    if ($tmp =~ /銆�/) {$tmp =~ s:銆�::g;}
+    if ($tmp =~ /銆�/) {$tmp =~ s:銆�::g;}
+    if ($tmp =~ /锛�/) {$tmp =~ s:锛�::g;}
+    print "$tmp "; 
+  }
+  print "\n"; 
+}
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh
new file mode 100755
index 0000000..3aa13d0
--- /dev/null
+++ b/egs/alimeeting/sa-asr/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
+export PATH=$PWD/utils/:$PATH
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
new file mode 100755
index 0000000..1fd63d6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from io import BytesIO
+from pathlib import Path
+from typing import Tuple, Optional
+
+import kaldiio
+import humanfriendly
+import numpy as np
+import resampy
+import soundfile
+from tqdm import tqdm
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.read_text import read_2column_text
+from funasr.fileio.sound_scp import SoundScpWriter
+
+
+def humanfriendly_or_none(value: str):
+    if value in ("none", "None", "NONE"):
+        return None
+    return humanfriendly.parse_size(value)
+
+
+def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
+    """
+
+    >>> str2int_tuple('3,4,5')
+    (3, 4, 5)
+
+    """
+    assert check_argument_types()
+    if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
+        return None
+    return tuple(map(int, integers.strip().split(",")))
+
+
+def main():
+    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+    logging.basicConfig(level=logging.INFO, format=logfmt)
+    logging.info(get_commandline_args())
+
+    parser = argparse.ArgumentParser(
+        description='Create waves list from "wav.scp"',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("scp")
+    parser.add_argument("outdir")
+    parser.add_argument(
+        "--name",
+        default="wav",
+        help="Specify the prefix word of output file name " 'such as "wav.scp"',
+    )
+    parser.add_argument("--segments", default=None)
+    parser.add_argument(
+        "--fs",
+        type=humanfriendly_or_none,
+        default=None,
+        help="If the sampling rate specified, " "Change the sampling rate.",
+    )
+    parser.add_argument("--audio-format", default="wav")
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
+    group.add_argument("--utt2ref-channels", default=None, type=str)
+    args = parser.parse_args()
+
+    out_num_samples = Path(args.outdir) / f"utt2num_samples"
+
+    if args.ref_channels is not None:
+
+        def utt2ref_channels(x) -> Tuple[int, ...]:
+            return args.ref_channels
+
+    elif args.utt2ref_channels is not None:
+        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
+
+        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
+            chs_str = d[x]
+            return tuple(map(int, chs_str.split()))
+
+    else:
+        utt2ref_channels = None
+
+    Path(args.outdir).mkdir(parents=True, exist_ok=True)
+    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
+    if args.segments is not None:
+        # Note: kaldiio supports only wav-pcm-int16le file.
+        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
+        if args.audio_format.endswith("ark"):
+            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+            fscp = out_wavscp.open("w")
+        else:
+            writer = SoundScpWriter(
+                args.outdir,
+                out_wavscp,
+                format=args.audio_format,
+            )
+
+        with out_num_samples.open("w") as fnum_samples:
+            for uttid, (rate, wave) in tqdm(loader):
+                # wave: (Time,) or (Time, Nmic)
+                if wave.ndim == 2 and utt2ref_channels is not None:
+                    wave = wave[:, utt2ref_channels(uttid)]
+
+                if args.fs is not None and args.fs != rate:
+                    # FIXME(kamo): To use sox?
+                    wave = resampy.resample(
+                        wave.astype(np.float64), rate, args.fs, axis=0
+                    )
+                    wave = wave.astype(np.int16)
+                    rate = args.fs
+                if args.audio_format.endswith("ark"):
+                    if "flac" in args.audio_format:
+                        suf = "flac"
+                    elif "wav" in args.audio_format:
+                        suf = "wav"
+                    else:
+                        raise RuntimeError("wav.ark or flac")
+
+                    # NOTE(kamo): Using extended ark format style here.
+                    # This format is incompatible with Kaldi
+                    kaldiio.save_ark(
+                        fark,
+                        {uttid: (wave, rate)},
+                        scp=fscp,
+                        append=True,
+                        write_function=f"soundfile_{suf}",
+                    )
+
+                else:
+                    writer[uttid] = rate, wave
+                fnum_samples.write(f"{uttid} {len(wave)}\n")
+    else:
+        if args.audio_format.endswith("ark"):
+            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+        else:
+            wavdir = Path(args.outdir) / f"data_{args.name}"
+            wavdir.mkdir(parents=True, exist_ok=True)
+
+        with Path(args.scp).open("r") as fscp, out_wavscp.open(
+            "w"
+        ) as fout, out_num_samples.open("w") as fnum_samples:
+            for line in tqdm(fscp):
+                uttid, wavpath = line.strip().split(None, 1)
+
+                if wavpath.endswith("|"):
+                    # Streaming input e.g. cat a.wav |
+                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
+                        with BytesIO(f.read()) as g:
+                            wave, rate = soundfile.read(g, dtype=np.int16)
+                            if wave.ndim == 2 and utt2ref_channels is not None:
+                                wave = wave[:, utt2ref_channels(uttid)]
+
+                        if args.fs is not None and args.fs != rate:
+                            # FIXME(kamo): To use sox?
+                            wave = resampy.resample(
+                                wave.astype(np.float64), rate, args.fs, axis=0
+                            )
+                            wave = wave.astype(np.int16)
+                            rate = args.fs
+
+                        if args.audio_format.endswith("ark"):
+                            if "flac" in args.audio_format:
+                                suf = "flac"
+                            elif "wav" in args.audio_format:
+                                suf = "wav"
+                            else:
+                                raise RuntimeError("wav.ark or flac")
+
+                            # NOTE(kamo): Using extended ark format style here.
+                            # This format is incompatible with Kaldi
+                            kaldiio.save_ark(
+                                fark,
+                                {uttid: (wave, rate)},
+                                scp=fout,
+                                append=True,
+                                write_function=f"soundfile_{suf}",
+                            )
+                        else:
+                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+                            soundfile.write(owavpath, wave, rate)
+                            fout.write(f"{uttid} {owavpath}\n")
+                else:
+                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
+                    if wave.ndim == 2 and utt2ref_channels is not None:
+                        wave = wave[:, utt2ref_channels(uttid)]
+                        save_asis = False
+
+                    elif args.audio_format.endswith("ark"):
+                        save_asis = False
+
+                    elif Path(wavpath).suffix == "." + args.audio_format and (
+                        args.fs is None or args.fs == rate
+                    ):
+                        save_asis = True
+
+                    else:
+                        save_asis = False
+
+                    if save_asis:
+                        # Neither --segments nor --fs are specified and
+                        # the line doesn't end with "|",
+                        # i.e. not using unix-pipe,
+                        # only in this case,
+                        # just using the original file as is.
+                        fout.write(f"{uttid} {wavpath}\n")
+                    else:
+                        if args.fs is not None and args.fs != rate:
+                            # FIXME(kamo): To use sox?
+                            wave = resampy.resample(
+                                wave.astype(np.float64), rate, args.fs, axis=0
+                            )
+                            wave = wave.astype(np.int16)
+                            rate = args.fs
+
+                        if args.audio_format.endswith("ark"):
+                            if "flac" in args.audio_format:
+                                suf = "flac"
+                            elif "wav" in args.audio_format:
+                                suf = "wav"
+                            else:
+                                raise RuntimeError("wav.ark or flac")
+
+                            # NOTE(kamo): Using extended ark format style here.
+                            # This format is not supported in Kaldi.
+                            kaldiio.save_ark(
+                                fark,
+                                {uttid: (wave, rate)},
+                                scp=fout,
+                                append=True,
+                                write_function=f"soundfile_{suf}",
+                            )
+                        else:
+                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+                            soundfile.write(owavpath, wave, rate)
+                            fout.write(f"{uttid} {owavpath}\n")
+                fnum_samples.write(f"{uttid} {len(wave)}\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
new file mode 100755
index 0000000..b0c61e5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+
+
+def get_commandline_args(no_executable=True):
+    extra_chars = [
+        " ",
+        ";",
+        "&",
+        "|",
+        "<",
+        ">",
+        "?",
+        "*",
+        "~",
+        "`",
+        '"',
+        "'",
+        "\\",
+        "{",
+        "}",
+        "(",
+        ")",
+    ]
+
+    # Escape the extra characters for shell
+    argv = [
+        arg.replace("'", "'\\''")
+        if all(char not in arg for char in extra_chars)
+        else "'" + arg.replace("'", "'\\''") + "'"
+        for arg in sys.argv
+    ]
+
+    if no_executable:
+        return " ".join(argv[1:])
+    else:
+        return sys.executable + " " + " ".join(argv)
+
+
+def main():
+    print(get_commandline_args())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023.sh b/egs/alimeeting/sa-asr/run_m2met_2023.sh
new file mode 100755
index 0000000..807e499
--- /dev/null
+++ b/egs/alimeeting/sa-asr/run_m2met_2023.sh
@@ -0,0 +1,51 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+ngpu=4
+device="0,1,2,3"
+
+#stage 1 creat both near and far
+stage=1
+stop_stage=18
+
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_Ali_far"
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+inference_config=conf/decode_asr_rnn.yaml
+
+lm_config=conf/train_lm_transformer.yaml
+use_lm=false
+use_wordlm=false
+./asr_local.sh                                         \
+    --device ${device}                                 \
+    --ngpu ${ngpu}                                     \
+    --stage ${stage}                                   \
+    --stop_stage ${stop_stage}                         \
+    --gpu_inference true    \
+    --njob_infer 4    \
+    --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
+    --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
+    --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
+    --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
+    --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
+    --lang zh                                          \
+    --audio_format wav                                 \
+    --feats_type raw                                   \
+    --token_type char                                  \
+    --use_lm ${use_lm}                                 \
+    --use_word_lm ${use_wordlm}                        \
+    --lm_config "${lm_config}"                         \
+    --asr_config "${asr_config}"                       \
+    --sa_asr_config "${sa_asr_config}"                 \
+    --inference_config "${inference_config}"           \
+    --train_set "${train_set}"                         \
+    --valid_set "${valid_set}"                         \
+    --test_sets "${test_sets}"                         \
+    --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
new file mode 100755
index 0000000..d35e6a6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
@@ -0,0 +1,50 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+ngpu=4
+device="0,1,2,3"
+
+stage=1
+stop_stage=4
+
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_2023_Ali_far"
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+inference_config=conf/decode_asr_rnn.yaml
+
+lm_config=conf/train_lm_transformer.yaml
+use_lm=false
+use_wordlm=false
+./asr_local_infer.sh                                         \
+    --device ${device}                                 \
+    --ngpu ${ngpu}                                     \
+    --stage ${stage}                                   \
+    --stop_stage ${stop_stage}                         \
+    --gpu_inference true    \
+    --njob_infer 4    \
+    --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
+    --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
+    --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
+    --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
+    --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
+    --lang zh                                          \
+    --audio_format wav                                 \
+    --feats_type raw                                   \
+    --token_type char                                  \
+    --use_lm ${use_lm}                                 \
+    --use_word_lm ${use_wordlm}                        \
+    --lm_config "${lm_config}"                         \
+    --asr_config "${asr_config}"                       \
+    --sa_asr_config "${sa_asr_config}"                 \
+    --inference_config "${inference_config}"           \
+    --train_set "${train_set}"                         \
+    --valid_set "${valid_set}"                         \
+    --test_sets "${test_sets}"                         \
+    --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
new file mode 100755
index 0000000..15e4563
--- /dev/null
+++ b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
@@ -0,0 +1,142 @@
+#!/usr/bin/env bash
+set -euo pipefail
+SECONDS=0
+log() {
+    local fname=${BASH_SOURCE[1]##*/}
+    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+help_message=$(cat << EOF
+Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
+e.g.
+$0 data/test/wav.scp data/test_format/
+
+Format 'wav.scp': In short words,
+changing "kaldi-datadir" to "modified-kaldi-datadir"
+
+The 'wav.scp' format in kaldi is very flexible,
+e.g. It can use unix-pipe as describing that wav file,
+but it sometime looks confusing and make scripts more complex.
+This tools creates actual wav files from 'wav.scp'
+and also segments wav files using 'segments'.
+
+Options
+  --fs <fs>
+  --segments <segments>
+  --nj <nj>
+  --cmd <cmd>
+EOF
+)
+
+out_filename=wav.scp
+cmd=utils/run.pl
+nj=30
+fs=none
+segments=
+
+ref_channels=
+utt2ref_channels=
+
+audio_format=wav
+write_utt2num_samples=true
+
+log "$0 $*"
+. utils/parse_options.sh
+
+if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
+    log "${help_message}"
+    log "Error: invalid command line arguments"
+    exit 1
+fi
+
+. ./path.sh  # Setup the environment
+
+scp=$1
+if [ ! -f "${scp}" ]; then
+    log "${help_message}"
+    echo "$0: Error: No such file: ${scp}"
+    exit 1
+fi
+dir=$2
+
+
+if [ $# -eq 2 ]; then
+    logdir=${dir}/logs
+    outdir=${dir}/data
+
+elif [ $# -eq 3 ]; then
+    logdir=$3
+    outdir=${dir}/data
+
+elif [ $# -eq 4 ]; then
+    logdir=$3
+    outdir=$4
+fi
+
+
+mkdir -p ${logdir}
+
+rm -f "${dir}/${out_filename}"
+
+
+opts=
+if [ -n "${utt2ref_channels}" ]; then
+    opts="--utt2ref-channels ${utt2ref_channels} "
+elif [ -n "${ref_channels}" ]; then
+    opts="--ref-channels ${ref_channels} "
+fi
+
+
+if [ -n "${segments}" ]; then
+    log "[info]: using ${segments}"
+    nutt=$(<${segments} wc -l)
+    nj=$((nj<nutt?nj:nutt))
+
+    split_segments=""
+    for n in $(seq ${nj}); do
+        split_segments="${split_segments} ${logdir}/segments.${n}"
+    done
+
+    utils/split_scp.pl "${segments}" ${split_segments}
+
+    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+        pyscripts/audio/format_wav_scp.py \
+            ${opts} \
+            --fs ${fs} \
+            --audio-format "${audio_format}" \
+            "--segment=${logdir}/segments.JOB" \
+            "${scp}" "${outdir}/format.JOB"
+
+else
+    log "[info]: without segments"
+    nutt=$(<${scp} wc -l)
+    nj=$((nj<nutt?nj:nutt))
+
+    split_scps=""
+    for n in $(seq ${nj}); do
+        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
+    done
+
+    utils/split_scp.pl "${scp}" ${split_scps}
+    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+        pyscripts/audio/format_wav_scp.py \
+        ${opts} \
+        --fs "${fs}" \
+        --audio-format "${audio_format}" \
+        "${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
+fi
+
+# Workaround for the NFS problem
+ls ${outdir}/format.* > /dev/null
+
+# concatenate the .scp files together.
+for n in $(seq ${nj}); do
+    cat "${outdir}/format.${n}/wav.scp" || exit 1;
+done > "${dir}/${out_filename}" || exit 1
+
+if "${write_utt2num_samples}"; then
+    for n in $(seq ${nj}); do
+        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
+    done > "${dir}/utt2num_samples"  || exit 1
+fi
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
new file mode 100755
index 0000000..9e08dba
--- /dev/null
+++ b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
@@ -0,0 +1,116 @@
+#!/usr/bin/env bash
+
+# 2020 @kamo-naoyuki
+# This file was copied from Kaldi and 
+# I deleted parts related to wav duration 
+# because we shouldn't use kaldi's command here
+# and we don't need the files actually.
+
+# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
+#           2014  Tom Ko
+#           2018  Emotech LTD (author: Pawel Swietojanski)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+#  wav.scp
+#  spk2utt
+#  utt2spk
+#  text
+#
+# It generates the files which are used for perturbing the speed of the original data.
+
+export LC_ALL=C
+set -euo pipefail
+
+if [[ $# != 3 ]]; then
+    echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
+    echo "e.g.:"
+    echo " $0 0.9 data/train_si284 data/train_si284p"
+    exit 1
+fi
+
+factor=$1
+srcdir=$2
+destdir=$3
+label="sp"
+spk_prefix="${label}${factor}-"
+utt_prefix="${label}${factor}-"
+
+#check is sox on the path
+
+! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
+
+if [[ ! -f ${srcdir}/utt2spk ]]; then
+  echo "$0: no such file ${srcdir}/utt2spk"
+  exit 1;
+fi
+
+if [[ ${destdir} == "${srcdir}" ]]; then
+  echo "$0: this script requires <srcdir> and <destdir> to be different."
+  exit 1
+fi
+
+mkdir -p "${destdir}"
+
+<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
+<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
+<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
+if [[ ! -f ${srcdir}/utt2uniq ]]; then
+    <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
+else
+    <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
+fi
+
+
+<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
+  utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+
+utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+
+if [[ -f ${srcdir}/segments ]]; then
+
+  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+      utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+          awk -v factor="${factor}" \
+            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
+            >"${destdir}"/segments
+
+  utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+      # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+      awk -v factor="${factor}" \
+          '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+            else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+            else  {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+             > "${destdir}"/wav.scp
+  if [[ -f ${srcdir}/reco2file_and_channel ]]; then
+      utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+       <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
+  fi
+
+else # no segments->wav indexed by utterance.
+    if [[ -f ${srcdir}/wav.scp ]]; then
+        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+         # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+         awk -v factor="${factor}" \
+           '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+             else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+             else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+                 > "${destdir}"/wav.scp
+    fi
+fi
+
+if [[ -f ${srcdir}/text ]]; then
+    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+fi
+if [[ -f ${srcdir}/spk2gender ]]; then
+    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+fi
+if [[ -f ${srcdir}/utt2lang ]]; then
+    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+fi
+
+rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
+echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
+
+utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
diff --git a/egs/alimeeting/sa-asr/utils/apply_map.pl b/egs/alimeeting/sa-asr/utils/apply_map.pl
new file mode 100755
index 0000000..725d346
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/apply_map.pl
@@ -0,0 +1,97 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright 2012  Johns Hopkins University (Author: Daniel Povey)
+# Apache 2.0.
+
+# This program is a bit like ./sym2int.pl in that it applies a map
+# to things in a file, but it's a bit more general in that it doesn't
+# assume the things being mapped to are single tokens, they could
+# be sequences of tokens.  See the usage message.
+
+
+$permissive = 0;
+
+for ($x = 0; $x <= 2; $x++) {
+
+  if (@ARGV > 0 && $ARGV[0] eq "-f") {
+    shift @ARGV;
+    $field_spec = shift @ARGV;
+    if ($field_spec =~ m/^\d+$/) {
+      $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
+    }
+    if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10)
+      if ($1 ne "") {
+        $field_begin = $1 - 1;  # Change to zero-based indexing.
+      }
+      if ($2 ne "") {
+        $field_end = $2 - 1;    # Change to zero-based indexing.
+      }
+    }
+    if (!defined $field_begin && !defined $field_end) {
+      die "Bad argument to -f option: $field_spec";
+    }
+  }
+
+  if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
+    shift @ARGV;
+    # Mapping is optional (missing key is printed to output)
+    $permissive = 1;
+  }
+}
+
+if(@ARGV != 1) {
+  print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n";
+  print STDERR <<'EOF';
+Usage: apply_map.pl [options] map <input >output
+ options: [-f <field-range> ] [--permissive]
+   This applies a map to some specified fields of some input text:
+   For each line in the map file: the first field is the thing we
+   map from, and the remaining fields are the sequence we map it to.
+   The -f (field-range) option says which fields of the input file the map
+   map should apply to.
+   If the --permissive option is supplied, fields which are not present
+   in the map will be left as they were.
+ Applies the map 'map' to all input text, where each line of the map
+ is interpreted as a map from the first field to the list of the other fields
+ Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field
+ range in the input to apply the map to.
+ e.g.: echo A B | apply_map.pl a.txt
+ where a.txt is:
+ A a1 a2
+ B b
+ will produce:
+ a1 a2 b
+EOF
+  exit(1);
+}
+
+($map_file) = @ARGV;
+open(M, "<$map_file") || die "Error opening map file $map_file: $!";
+
+while (<M>) {
+  @A = split(" ", $_);
+  @A >= 1 || die "apply_map.pl: empty line.";
+  $i = shift @A;
+  $o = join(" ", @A);
+  $map{$i} = $o;
+}
+
+while(<STDIN>) {
+  @A = split(" ", $_);
+  for ($x = 0; $x < @A; $x++) {
+    if ( (!defined $field_begin || $x >= $field_begin)
+         && (!defined $field_end || $x <= $field_end)) {
+      $a = $A[$x];
+      if (!defined $map{$a}) {
+        if (!$permissive) {
+          die "apply_map.pl: undefined key $a in $map_file\n";
+        } else {
+          print STDERR "apply_map.pl: warning! missing key $a in $map_file\n";
+        }
+      } else {
+        $A[$x] = $map{$a};
+      }
+    }
+  }
+  print join(" ", @A) . "\n";
+}
diff --git a/egs/alimeeting/sa-asr/utils/combine_data.sh b/egs/alimeeting/sa-asr/utils/combine_data.sh
new file mode 100755
index 0000000..e1eba85
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/combine_data.sh
@@ -0,0 +1,146 @@
+#!/usr/bin/env bash
+# Copyright 2012  Johns Hopkins University (Author: Daniel Povey).  Apache 2.0.
+#           2014  David Snyder
+
+# This script combines the data from multiple source directories into
+# a single destination directory.
+
+# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
+# about what these directories contain.
+
+# Begin configuration section.
+extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
+skip_fix=false # skip the fix_data_dir.sh in the end
+# End configuration section.
+
+echo "$0 $@"  # Print the command line for logging
+
+if [ -f path.sh ]; then . ./path.sh; fi
+. parse_options.sh || exit 1;
+
+if [ $# -lt 2 ]; then
+  echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
+  echo "Note, files that don't appear in all source dirs will not be combined,"
+  echo "with the exception of utt2uniq and segments, which are created where necessary."
+  exit 1
+fi
+
+dest=$1;
+shift;
+
+first_src=$1;
+
+rm -r $dest 2>/dev/null || true
+mkdir -p $dest;
+
+export LC_ALL=C
+
+for dir in $*; do
+  if [ ! -f $dir/utt2spk ]; then
+    echo "$0: no such file $dir/utt2spk"
+    exit 1;
+  fi
+done
+
+# Check that frame_shift are compatible, where present together with features.
+dir_with_frame_shift=
+for dir in $*; do
+  if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
+    if [[ $dir_with_frame_shift ]] &&
+       ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
+      echo "$0:error: different frame_shift in directories $dir and " \
+           "$dir_with_frame_shift. Cannot combine features."
+      exit 1;
+    fi
+    dir_with_frame_shift=$dir
+  fi
+done
+
+# W.r.t. utt2uniq file the script has different behavior compared to other files
+# it is not compulsary for it to exist in src directories, but if it exists in
+# even one it should exist in all. We will create the files where necessary
+has_utt2uniq=false
+for in_dir in $*; do
+  if [ -f $in_dir/utt2uniq ]; then
+    has_utt2uniq=true
+    break
+  fi
+done
+
+if $has_utt2uniq; then
+  # we are going to create an utt2uniq file in the destdir
+  for in_dir in $*; do
+    if [ ! -f $in_dir/utt2uniq ]; then
+      # we assume that utt2uniq is a one to one mapping
+      cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
+    else
+      cat $in_dir/utt2uniq
+    fi
+  done | sort -k1 > $dest/utt2uniq
+  echo "$0: combined utt2uniq"
+else
+  echo "$0 [info]: not combining utt2uniq as it does not exist"
+fi
+# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
+extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
+
+# segments are treated similarly to utt2uniq. If it exists in some, but not all
+# src directories, then we generate segments where necessary.
+has_segments=false
+for in_dir in $*; do
+  if [ -f $in_dir/segments ]; then
+    has_segments=true
+    break
+  fi
+done
+
+if $has_segments; then
+  for in_dir in $*; do
+    if [ ! -f $in_dir/segments ]; then
+      echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
+      utils/data/get_segments_for_data.sh $in_dir
+    else
+      cat $in_dir/segments
+    fi
+  done | sort -k1 > $dest/segments
+  echo "$0: combined segments"
+else
+  echo "$0 [info]: not combining segments as it does not exist"
+fi
+
+for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
+  exists_somewhere=false
+  absent_somewhere=false
+  for d in $*; do
+    if [ -f $d/$file ]; then
+      exists_somewhere=true
+    else
+      absent_somewhere=true
+      fi
+  done
+
+  if ! $absent_somewhere; then
+    set -o pipefail
+    ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
+    set +o pipefail
+    echo "$0: combined $file"
+  else
+    if ! $exists_somewhere; then
+      echo "$0 [info]: not combining $file as it does not exist"
+    else
+      echo "$0 [info]: **not combining $file as it does not exist everywhere**"
+    fi
+  fi
+done
+
+utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
+
+if [[ $dir_with_frame_shift ]]; then
+  cp $dir_with_frame_shift/frame_shift $dest
+fi
+
+if ! $skip_fix ; then
+  utils/fix_data_dir.sh $dest || exit 1;
+fi
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh
new file mode 100755
index 0000000..9fd420c
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh
@@ -0,0 +1,145 @@
+#!/usr/bin/env bash
+
+# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+#  feats.scp
+#  wav.scp
+#  vad.scp
+#  spk2utt
+#  utt2spk
+#  text
+#
+# It copies to another directory, possibly adding a specified prefix or a suffix
+# to the utterance and/or speaker names.  Note, the recording-ids stay the same.
+#
+
+
+# begin configuration section
+spk_prefix=
+utt_prefix=
+spk_suffix=
+utt_suffix=
+validate_opts=   # should rarely be needed.
+# end configuration section
+
+. utils/parse_options.sh
+
+if [ $# != 2 ]; then
+  echo "Usage: "
+  echo "  $0 [options] <srcdir> <destdir>"
+  echo "e.g.:"
+  echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1"
+  echo "Options"
+  echo "   --spk-prefix=<prefix>     # Prefix for speaker ids, default empty"
+  echo "   --utt-prefix=<prefix>     # Prefix for utterance ids, default empty"
+  echo "   --spk-suffix=<suffix>     # Suffix for speaker ids, default empty"
+  echo "   --utt-suffix=<suffix>     # Suffix for utterance ids, default empty"
+  exit 1;
+fi
+
+
+export LC_ALL=C
+
+srcdir=$1
+destdir=$2
+
+if [ ! -f $srcdir/utt2spk ]; then
+  echo "copy_data_dir.sh: no such file $srcdir/utt2spk"
+  exit 1;
+fi
+
+if [ "$destdir" == "$srcdir" ]; then
+  echo "$0: this script requires <srcdir> and <destdir> to be different."
+  exit 1
+fi
+
+set -e;
+
+mkdir -p $destdir
+
+cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map
+cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map
+
+if [ ! -f $srcdir/utt2uniq ]; then
+  if [[ ! -z $utt_prefix  ||  ! -z $utt_suffix ]]; then
+    cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq
+  fi
+else
+  cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
+fi
+
+cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map  | \
+  utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
+
+utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
+
+if [ -f $srcdir/feats.scp ]; then
+  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
+fi
+
+if [ -f $srcdir/vad.scp ]; then
+  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
+fi
+
+if [ -f $srcdir/segments ]; then
+  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
+  cp $srcdir/wav.scp $destdir
+else # no segments->wav indexed by utt.
+  if [ -f $srcdir/wav.scp ]; then
+    utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
+  fi
+fi
+
+if [ -f $srcdir/reco2file_and_channel ]; then
+  cp $srcdir/reco2file_and_channel $destdir/
+fi
+
+if [ -f $srcdir/text ]; then
+  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
+fi
+if [ -f $srcdir/utt2dur ]; then
+  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
+fi
+if [ -f $srcdir/utt2num_frames ]; then
+  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
+fi
+if [ -f $srcdir/reco2dur ]; then
+  if [ -f $srcdir/segments ]; then
+    cp $srcdir/reco2dur $destdir/reco2dur
+  else
+    utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
+  fi
+fi
+if [ -f $srcdir/spk2gender ]; then
+  utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
+fi
+if [ -f $srcdir/cmvn.scp ]; then
+  utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
+fi
+for f in frame_shift stm glm ctm; do
+  if [ -f $srcdir/$f ]; then
+    cp $srcdir/$f $destdir
+  fi
+done
+
+rm $destdir/spk_map $destdir/utt_map
+
+echo "$0: copied data from $srcdir to $destdir"
+
+for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do
+  if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then
+    echo "$0: file $f exists in dest $destdir but not in src $srcdir.  Moving it to"
+    echo " ... $destdir/.backup/$f"
+    mkdir -p $destdir/.backup
+    mv $destdir/$f $destdir/.backup/
+  fi
+done
+
+
+[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
+[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
+
+utils/validate_data_dir.sh $validate_opts $destdir
diff --git a/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh
new file mode 100755
index 0000000..24f51e7
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh
@@ -0,0 +1,143 @@
+#!/usr/bin/env bash
+
+# Copyright 2016  Johns Hopkins University (author: Daniel Povey)
+#           2018  Andrea Carmantini
+# Apache 2.0
+
+# This script operates on a data directory, such as in data/train/, and adds the
+# reco2dur file if it does not already exist.  The file 'reco2dur' maps from
+# recording to the duration of the recording in seconds.  This script works it
+# out from the 'wav.scp' file, or, if utterance-ids are the same as recording-ids, from the
+# utt2dur file (it first tries interrogating the headers, and if this fails, it reads the wave
+# files in entirely.)
+# We could use durations from segments file, but that's not the duration of the recordings
+# but the sum of utterance lenghts (silence in between could be excluded from segments)
+# For sum of utterance lenghts:
+# awk 'FNR==NR{uttdur[$1]=$2;next}
+# { for(i=2;i<=NF;i++){dur+=uttdur[$i];}
+#   print $1 FS dur; dur=0  }'  $data/utt2dur $data/reco2utt
+
+
+frame_shift=0.01
+cmd=run.pl
+nj=4
+
+. utils/parse_options.sh
+. ./path.sh
+
+if [ $# != 1 ]; then
+  echo "Usage: $0 [options] <datadir>"
+  echo "e.g.:"
+  echo " $0 data/train"
+  echo " Options:"
+  echo " --frame-shift      # frame shift in seconds. Only relevant when we are"
+  echo "                    # getting duration from feats.scp (default: 0.01). "
+  exit 1
+fi
+
+export LC_ALL=C
+
+data=$1
+
+
+if [ -s $data/reco2dur ] && \
+  [ $(wc -l < $data/wav.scp) -eq $(wc -l < $data/reco2dur) ]; then
+  echo "$0: $data/reco2dur already exists with the expected length.  We won't recompute it."
+  exit 0;
+fi
+
+if [ -s $data/utt2dur ] && \
+   [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ] && \
+   [ ! -s $data/segments ]; then
+
+  echo "$0: $data/wav.scp indexed by utt-id; copying utt2dur to reco2dur"
+  cp $data/utt2dur $data/reco2dur && exit 0;
+
+elif [ -f $data/wav.scp ]; then
+  echo "$0: obtaining durations from recordings"
+
+  # if the wav.scp contains only lines of the form
+  # utt1  /foo/bar/sph2pipe -f wav /baz/foo.sph |
+  if cat $data/wav.scp | perl -e '
+     while (<>) { s/\|\s*$/ |/;  # make sure final | is preceded by space.
+             @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
+                               $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
+             $reco = $A[0]; $sphere_file = $A[4];
+
+             if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
+             $sample_rate = -1;  $sample_count = -1;
+             for ($n = 0; $n <= 30; $n++) {
+                $line = <F>;
+                if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
+                if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
+                if ($line =~ m/end_head/) { break; }
+             }
+             close(F);
+             if ($sample_rate == -1 || $sample_count == -1) {
+               die "could not parse sphere header from $sphere_file";
+             }
+             $duration = $sample_count * 1.0 / $sample_rate;
+             print "$reco $duration\n";
+     } ' > $data/reco2dur; then
+    echo "$0: successfully obtained recording lengths from sphere-file headers"
+  else
+    echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration"
+    if ! command -v wav-to-duration >/dev/null; then
+      echo  "$0: wav-to-duration is not on your path"
+      exit 1;
+    fi
+
+    read_entire_file=false
+    if grep -q 'sox.*speed' $data/wav.scp; then
+      read_entire_file=true
+      echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
+      echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
+      echo "... perturb_data_dir_speed_3way.sh."
+    fi
+
+    num_recos=$(wc -l <$data/wav.scp)
+    if [ $nj -gt $num_recos ]; then
+      nj=$num_recos
+    fi
+
+    temp_data_dir=$data/wav${nj}split
+    wavscps=$(for n in `seq $nj`; do echo $temp_data_dir/$n/wav.scp; done)
+    subdirs=$(for n in `seq $nj`; do echo $temp_data_dir/$n; done)
+
+    if ! mkdir -p $subdirs >&/dev/null; then
+	for n in `seq $nj`; do
+	    mkdir -p $temp_data_dir/$n
+	done
+    fi
+
+    utils/split_scp.pl $data/wav.scp $wavscps
+
+
+    $cmd JOB=1:$nj $data/log/get_reco_durations.JOB.log \
+      wav-to-duration --read-entire-file=$read_entire_file \
+      scp:$temp_data_dir/JOB/wav.scp ark,t:$temp_data_dir/JOB/reco2dur || \
+        { echo "$0: there was a problem getting the durations"; exit 1; } # This could
+
+    for n in `seq $nj`; do
+      cat $temp_data_dir/$n/reco2dur
+    done > $data/reco2dur
+  fi
+  rm -r $temp_data_dir
+else
+  echo "$0: Expected $data/wav.scp to exist"
+  exit 1
+fi
+
+len1=$(wc -l < $data/wav.scp)
+len2=$(wc -l < $data/reco2dur)
+if [ "$len1" != "$len2" ]; then
+  echo "$0: warning: length of reco2dur does not equal that of wav.scp, $len2 != $len1"
+  if [ $len1 -gt $[$len2*2] ]; then
+    echo "$0: less than half of recordings got a duration: failing."
+    exit 1
+  fi
+fi
+
+echo "$0: computed $data/reco2dur"
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
new file mode 100755
index 0000000..6b161b3
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
+# This script operates on a data directory, such as in data/train/,
+# and writes new segments to stdout. The file 'segments' maps from
+# utterance to time offsets into a recording, with the format:
+#   <utterance-id> <recording-id> <segment-begin> <segment-end>
+# This script assumes utterance and recording ids are the same (i.e., that
+# wav.scp is indexed by utterance), and uses durations from 'utt2dur', 
+# created if necessary by get_utt2dur.sh.
+
+. ./path.sh
+
+if [ $# != 1 ]; then
+  echo "Usage: $0 [options] <datadir>"
+  echo "e.g.:"
+  echo " $0 data/train > data/train/segments"
+  exit 1
+fi
+
+data=$1
+
+if [ ! -s $data/utt2dur ]; then
+  utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
+fi
+
+# <utt-id> <utt-id> 0 <utt-dur>
+awk '{ print $1, $1, 0, $2 }' $data/utt2dur
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
new file mode 100755
index 0000000..5ee7ea3
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
@@ -0,0 +1,135 @@
+#!/usr/bin/env bash
+
+# Copyright 2016  Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+# This script operates on a data directory, such as in data/train/, and adds the
+# utt2dur file if it does not already exist.  The file 'utt2dur' maps from
+# utterance to the duration of the utterance in seconds.  This script works it
+# out from the 'segments' file, or, if not present, from the wav.scp file (it
+# first tries interrogating the headers, and if this fails, it reads the wave
+# files in entirely.)
+
+frame_shift=0.01
+cmd=run.pl
+nj=4
+read_entire_file=false
+
+. utils/parse_options.sh
+. ./path.sh
+
+if [ $# != 1 ]; then
+  echo "Usage: $0 [options] <datadir>"
+  echo "e.g.:"
+  echo " $0 data/train"
+  echo " Options:"
+  echo " --frame-shift      # frame shift in seconds. Only relevant when we are"
+  echo "                    # getting duration from feats.scp, and only if the "
+  echo "                    # file frame_shift does not exist (default: 0.01). "
+  exit 1
+fi
+
+export LC_ALL=C
+
+data=$1
+
+if [ -s $data/utt2dur ] && \
+  [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ]; then
+  echo "$0: $data/utt2dur already exists with the expected length.  We won't recompute it."
+  exit 0;
+fi
+
+if [ -s $data/segments ]; then
+  echo "$0: working out $data/utt2dur from $data/segments"
+  awk '{len=$4-$3; print $1, len;}' < $data/segments  > $data/utt2dur
+elif [[ -s $data/frame_shift && -f $data/utt2num_frames ]]; then
+  echo "$0: computing $data/utt2dur from $data/{frame_shift,utt2num_frames}."
+  frame_shift=$(cat $data/frame_shift) || exit 1
+  # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
+  awk -v fs=$frame_shift '{ $2=($2+1.5)*fs; print }' <$data/utt2num_frames  >$data/utt2dur
+elif [ -f $data/wav.scp ]; then
+  echo "$0: segments file does not exist so getting durations from wave files"
+
+  # if the wav.scp contains only lines of the form
+  # utt1  /foo/bar/sph2pipe -f wav /baz/foo.sph |
+  if perl <$data/wav.scp -e '
+     while (<>) { s/\|\s*$/ |/;  # make sure final | is preceded by space.
+             @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
+                               $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
+             $utt = $A[0]; $sphere_file = $A[4];
+
+             if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
+             $sample_rate = -1;  $sample_count = -1;
+             for ($n = 0; $n <= 30; $n++) {
+                $line = <F>;
+                if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
+                if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
+                if ($line =~ m/end_head/) { break; }
+             }
+             close(F);
+             if ($sample_rate == -1 || $sample_count == -1) {
+               die "could not parse sphere header from $sphere_file";
+             }
+             $duration = $sample_count * 1.0 / $sample_rate;
+             print "$utt $duration\n";
+     } ' > $data/utt2dur; then
+    echo "$0: successfully obtained utterance lengths from sphere-file headers"
+  else
+    echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration"
+    if ! command -v wav-to-duration >/dev/null; then
+      echo  "$0: wav-to-duration is not on your path"
+      exit 1;
+    fi
+
+    if grep -q 'sox.*speed' $data/wav.scp; then
+      read_entire_file=true
+      echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
+      echo "... It is much faster if you call get_utt2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
+      echo "... perturb_data_dir_speed_3way.sh."
+    fi
+
+
+    num_utts=$(wc -l <$data/utt2spk)
+    if [ $nj -gt $num_utts ]; then
+      nj=$num_utts
+    fi
+
+    utils/data/split_data.sh --per-utt $data $nj
+    sdata=$data/split${nj}utt
+
+    $cmd JOB=1:$nj $data/log/get_durations.JOB.log \
+      wav-to-duration --read-entire-file=$read_entire_file \
+      scp:$sdata/JOB/wav.scp ark,t:$sdata/JOB/utt2dur || \
+        { echo "$0: there was a problem getting the durations"; exit 1; }
+
+    for n in `seq $nj`; do
+      cat $sdata/$n/utt2dur
+    done > $data/utt2dur
+  fi
+elif [ -f $data/feats.scp ]; then
+  echo "$0: wave file does not exist so getting durations from feats files"
+  if [[ -s $data/frame_shift ]]; then
+    frame_shift=$(cat $data/frame_shift) || exit 1
+    echo "$0: using frame_shift=$frame_shift from file $data/frame_shift"
+  fi
+  # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
+  feat-to-len scp:$data/feats.scp ark,t:- |
+    awk -v frame_shift=$frame_shift '{print $1, ($2+1.5)*frame_shift}' >$data/utt2dur
+else
+  echo "$0: Expected $data/wav.scp, $data/segments or $data/feats.scp to exist"
+  exit 1
+fi
+
+len1=$(wc -l < $data/utt2spk)
+len2=$(wc -l < $data/utt2dur)
+if [ "$len1" != "$len2" ]; then
+  echo "$0: warning: length of utt2dur does not equal that of utt2spk, $len2 != $len1"
+  if [ $len1 -gt $[$len2*2] ]; then
+    echo "$0: less than half of utterances got a duration: failing."
+    exit 1
+  fi
+fi
+
+echo "$0: computed $data/utt2dur"
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/data/split_data.sh b/egs/alimeeting/sa-asr/utils/data/split_data.sh
new file mode 100755
index 0000000..8aa71a1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/split_data.sh
@@ -0,0 +1,160 @@
+#!/usr/bin/env bash
+# Copyright 2010-2013 Microsoft Corporation
+#                     Johns Hopkins University (Author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+split_per_spk=true
+if [ "$1" == "--per-utt" ]; then
+  split_per_spk=false
+  shift
+fi
+
+if [ $# != 2 ]; then
+  echo "Usage: $0 [--per-utt] <data-dir> <num-to-split>"
+  echo "E.g.: $0 data/train 50"
+  echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the "
+  echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}."
+  echo ""
+  echo "This script will not split the data-dir if it detects that the output is newer than the input."
+  echo "By default it splits per speaker (so each speaker is in only one split dir),"
+  echo "but with the --per-utt option it will ignore the speaker information while splitting."
+  exit 1
+fi
+
+data=$1
+numsplit=$2
+
+if ! [ "$numsplit" -gt 0 ]; then
+  echo "Invalid num-split argument $numsplit";
+  exit 1;
+fi
+
+if $split_per_spk; then
+  warning_opt=
+else
+  # suppress warnings from filter_scps.pl about 'some input lines were output
+  # to multiple files'.
+  warning_opt="--no-warn"
+fi
+
+n=0;
+feats=""
+wavs=""
+utt2spks=""
+texts=""
+
+nu=`cat $data/utt2spk | wc -l`
+nf=`cat $data/feats.scp 2>/dev/null | wc -l`
+nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
+if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
+  echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
+  echo "**  use utils/fix_data_dir.sh $data to fix this."
+fi
+if [ -f $data/text ] && [ $nu -ne $nt ]; then
+  echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
+  echo "** use utils/fix_data_dir.sh to fix this."
+fi
+
+
+if $split_per_spk; then
+  utt2spk_opt="--utt2spk=$data/utt2spk"
+  utt=""
+else
+  utt2spk_opt=
+  utt="utt"
+fi
+
+s1=$data/split${numsplit}${utt}/1
+if [ ! -d $s1 ]; then
+  need_to_split=true
+else
+  need_to_split=false
+  for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \
+    vad.scp segments reco2file_and_channel utt2lang; do
+    if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then
+      need_to_split=true
+    fi
+  done
+fi
+
+if ! $need_to_split; then
+  exit 0;
+fi
+
+utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done)
+
+directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done)
+
+# if this mkdir fails due to argument-list being too long, iterate.
+if ! mkdir -p $directories >&/dev/null; then
+  for n in `seq $numsplit`; do
+    mkdir -p $data/split${numsplit}${utt}/$n
+  done
+fi
+
+# If lockfile is not installed, just don't lock it.  It's not a big deal.
+which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock
+trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM
+
+utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
+
+for n in `seq $numsplit`; do
+  dsn=$data/split${numsplit}${utt}/$n
+  utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
+done
+
+maybe_wav_scp=
+if [ ! -f $data/segments ]; then
+  maybe_wav_scp=wav.scp  # If there is no segments file, then wav file is
+                         # indexed per utt.
+fi
+
+# split some things that are indexed by utterance.
+for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do
+  if [ -f $data/$f ]; then
+    utils/filter_scps.pl JOB=1:$numsplit \
+      $data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
+  fi
+done
+
+# split some things that are indexed by speaker
+for f in spk2gender spk2warp cmvn.scp; do
+  if [ -f $data/$f ]; then
+    utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+      $data/split${numsplit}${utt}/JOB/spk2utt $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
+  fi
+done
+
+if [ -f $data/segments ]; then
+  utils/filter_scps.pl JOB=1:$numsplit \
+     $data/split${numsplit}${utt}/JOB/utt2spk $data/segments $data/split${numsplit}${utt}/JOB/segments || exit 1
+  for n in `seq $numsplit`; do
+    dsn=$data/split${numsplit}${utt}/$n
+    awk '{print $2;}' $dsn/segments | sort | uniq > $dsn/tmp.reco # recording-ids.
+  done
+  if [ -f $data/reco2file_and_channel ]; then
+    utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+      $data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2file_and_channel \
+      $data/split${numsplit}${utt}/JOB/reco2file_and_channel || exit 1
+  fi
+  if [ -f $data/wav.scp ]; then
+    utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+      $data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \
+      $data/split${numsplit}${utt}/JOB/wav.scp || exit 1
+  fi
+  for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done
+fi
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/filter_scp.pl b/egs/alimeeting/sa-asr/utils/filter_scp.pl
new file mode 100755
index 0000000..b76d37f
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/filter_scp.pl
@@ -0,0 +1,87 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation
+#                     Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids or any file whose first field
+# of each line is an utterance-id, and filters an scp
+# file (or any file whose "n-th" field is an utterance id), printing
+# out only those lines whose "n-th" field is in id_list. The index of
+# the "n-th" field is 1, by default, but can be changed by using
+# the -f <n> switch
+
+$exclude = 0;
+$field = 1;
+$shifted = 0;
+
+do {
+  $shifted=0;
+  if ($ARGV[0] eq "--exclude") {
+    $exclude = 1;
+    shift @ARGV;
+    $shifted=1;
+  }
+  if ($ARGV[0] eq "-f") {
+    $field = $ARGV[1];
+    shift @ARGV; shift @ARGV;
+    $shifted=1
+  }
+} while ($shifted);
+
+if(@ARGV < 1 || @ARGV > 2) {
+  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
+      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
+      "Note: only the first field of each line in id_list matters.  With --exclude, prints\n" .
+      "only the lines that were *not* in id_list.\n" .
+      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
+      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
+      "-f option, add 1 to the argument.\n" .
+      "See also: utils/filter_scp.pl .\n";
+}
+
+
+$idlist = shift @ARGV;
+open(F, "<$idlist") || die "Could not open id-list file $idlist";
+while(<F>) {
+  @A = split;
+  @A>=1 || die "Invalid id-list file line $_";
+  $seen{$A[0]} = 1;
+}
+
+if ($field == 1) { # Treat this as special case, since it is common.
+  while(<>) {
+    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
+    # $1 is what we filter on.
+    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
+      print $_;
+    }
+  }
+} else {
+  while(<>) {
+    @A = split;
+    @A > 0 || die "Invalid scp file line $_";
+    @A >= $field || die "Invalid scp file line $_";
+    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
+      print $_;
+    }
+  }
+}
+
+# tests:
+# the following should print "foo 1"
+# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
+# the following should print "bar 2".
+# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
diff --git a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh
new file mode 100755
index 0000000..ed4710d
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh
@@ -0,0 +1,215 @@
+#!/usr/bin/env bash
+
+# This script makes sure that only the segments present in
+# all of "feats.scp", "wav.scp" [if present], segments [if present]
+# text, and utt2spk are present in any of them.
+# It puts the original contents of data-dir into
+# data-dir/.backup
+
+cmd="$@"
+
+utt_extra_files=
+spk_extra_files=
+
+. utils/parse_options.sh
+
+if [ $# != 1 ]; then
+  echo "Usage: utils/data/fix_data_dir.sh <data-dir>"
+  echo "e.g.: utils/data/fix_data_dir.sh data/train"
+  echo "This script helps ensure that the various files in a data directory"
+  echo "are correctly sorted and filtered, for example removing utterances"
+  echo "that have no features (if feats.scp is present)"
+  exit 1
+fi
+
+data=$1
+
+if [ -f $data/images.scp ]; then
+  image/fix_data_dir.sh $cmd
+  exit $?
+fi
+
+mkdir -p $data/.backup
+
+[ ! -d $data ] && echo "$0: no such directory $data" && exit 1;
+
+[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1;
+
+set -e -o pipefail -u
+
+tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
+trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
+
+export LC_ALL=C
+
+function check_sorted {
+  file=$1
+  sort -k1,1 -u <$file >$file.tmp
+  if ! cmp -s $file $file.tmp; then
+    echo "$0: file $1 is not in sorted order or not unique, sorting it"
+    mv $file.tmp $file
+  else
+    rm $file.tmp
+  fi
+}
+
+for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \
+    reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do
+  if [ -f $data/$x ]; then
+    cp $data/$x $data/.backup/$x
+    check_sorted $data/$x
+  fi
+done
+
+
+function filter_file {
+  filter=$1
+  file_to_filter=$2
+  cp $file_to_filter ${file_to_filter}.tmp
+  utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter
+  if ! cmp ${file_to_filter}.tmp  $file_to_filter >&/dev/null; then
+    length1=$(cat ${file_to_filter}.tmp | wc -l)
+    length2=$(cat ${file_to_filter} | wc -l)
+    if [ $length1 -ne $length2 ]; then
+      echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter."
+    fi
+  fi
+  rm $file_to_filter.tmp
+}
+
+function filter_recordings {
+  # We call this once before the stage when we filter on utterance-id, and once
+  # after.
+
+  if [ -f $data/segments ]; then
+  # We have a segments file -> we need to filter this and the file wav.scp, and
+  # reco2file_and_utt, if it exists, to make sure they have the same list of
+  # recording-ids.
+
+    if [ ! -f $data/wav.scp ]; then
+      echo "$0: $data/segments exists but not $data/wav.scp"
+      exit 1;
+    fi
+    awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings
+    n1=$(cat $tmpdir/recordings | wc -l)
+    [ ! -s $tmpdir/recordings ] && \
+      echo "Empty list of recordings (bad file $data/segments)?" && exit 1;
+    utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp
+    mv $tmpdir/recordings.tmp $tmpdir/recordings
+
+
+    cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
+    filter_file $tmpdir/recordings $data/segments
+    cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
+    rm $data/segments.tmp
+
+    filter_file $tmpdir/recordings $data/wav.scp
+    [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel
+    [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur
+    true
+  fi
+}
+
+function filter_speakers {
+  # throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
+  utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  for s in cmvn.scp spk2gender; do
+    f=$data/$s
+    if [ -f $f ]; then
+      filter_file $f $tmpdir/speakers
+    fi
+  done
+
+  filter_file $tmpdir/speakers $data/spk2utt
+  utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
+
+  for s in cmvn.scp spk2gender $spk_extra_files; do
+    f=$data/$s
+    if [ -f $f ]; then
+      filter_file $tmpdir/speakers $f
+    fi
+  done
+}
+
+function filter_utts {
+  cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
+
+  ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \
+    echo "utt2spk is not in sorted order (fix this yourself)" && exit 1;
+
+  ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \
+    echo "utt2spk is not in sorted order when sorted first on speaker-id " && \
+    echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
+
+  ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \
+    echo "spk2utt is not in sorted order (fix this yourself)" && exit 1;
+
+  if [ -f $data/utt2uniq ]; then
+    ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \
+      echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1;
+  fi
+
+  maybe_wav=
+  maybe_reco2dur=
+  [ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist.
+  [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts
+
+  maybe_utt2dur=
+  if [ -f $data/utt2dur ]; then
+    cat $data/utt2dur | \
+      awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2dur.ok || exit 1
+    maybe_utt2dur=utt2dur.ok
+  fi
+
+  maybe_utt2num_frames=
+  if [ -f $data/utt2num_frames ]; then
+    cat $data/utt2num_frames | \
+      awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2num_frames.ok || exit 1
+    maybe_utt2num_frames=utt2num_frames.ok
+  fi
+
+  for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do
+    if [ -f $data/$x ]; then
+      utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp
+      mv $tmpdir/utts.tmp $tmpdir/utts
+    fi
+  done
+  rm $data/utt2dur.ok 2>/dev/null || true
+  rm $data/utt2num_frames.ok 2>/dev/null || true
+
+  [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \
+    rm $tmpdir/utts && exit 1;
+
+
+  if [ -f $data/utt2spk ]; then
+    new_nutts=$(cat $tmpdir/utts | wc -l)
+    old_nutts=$(cat $data/utt2spk | wc -l)
+    if [ $new_nutts -ne $old_nutts ]; then
+      echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts"
+    else
+      echo "fix_data_dir.sh: kept all $old_nutts utterances."
+    fi
+  fi
+
+  for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do
+    if [ -f $data/$x ]; then
+      cp $data/$x $data/.backup/$x
+      if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then
+        utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x
+      fi
+    fi
+  done
+
+}
+
+filter_recordings
+filter_speakers
+filter_utts
+filter_speakers
+filter_recordings
+
+utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+
+echo "fix_data_dir.sh: old files are kept in $data/.backup"
diff --git a/egs/alimeeting/sa-asr/utils/parse_options.sh b/egs/alimeeting/sa-asr/utils/parse_options.sh
new file mode 100755
index 0000000..71fb9e5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/parse_options.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
+#                 Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+  if [ "${!argpos}" == "--config" ]; then
+    argpos_plus1=$((argpos+1))
+    config=${!argpos_plus1}
+    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+    . $config  # source the config file.
+  fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+  [ -z "${1:-}" ] && break;  # break if there are no arguments
+  case "$1" in
+    # If the enclosing script is called with --help option, print the help
+    # message and exit.  Scripts should put help messages in $help_message
+    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+      else printf "$help_message\n" 1>&2 ; fi;
+      exit 0 ;;
+    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+      exit 1 ;;
+    # If the first command-line argument begins with "--" (e.g. --foo-bar),
+    # then work out the variable name as $name, which will equal "foo_bar".
+    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+      # Next we test whether the variable in question is undefned-- if so it's
+      # an invalid option and we die.  Note: $0 evaluates to the name of the
+      # enclosing script.
+      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+      # is undefined.  We then have to wrap this test inside "eval" because
+      # foo_bar is itself inside a variable ($name).
+      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+      oldval="`eval echo \\$$name`";
+      # Work out whether we seem to be expecting a Boolean argument.
+      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+        was_bool=true;
+      else
+        was_bool=false;
+      fi
+
+      # Set the variable to the right value-- the escaped quotes make it work if
+      # the option had spaces, like --cmd "queue.pl -sync y"
+      eval $name=\"$2\";
+
+      # Check that Boolean-valued arguments are really Boolean.
+      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+        exit 1;
+      fi
+      shift 2;
+      ;;
+  *) break;
+  esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
diff --git a/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl
new file mode 100755
index 0000000..23992f2
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl
@@ -0,0 +1,27 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+while(<>){ 
+    @A = split(" ", $_);
+    @A > 1 || die "Invalid line in spk2utt file: $_";
+    $s = shift @A;
+    foreach $u ( @A ) {
+        print "$u $s\n";
+    }
+}
+
+
diff --git a/egs/alimeeting/sa-asr/utils/split_scp.pl b/egs/alimeeting/sa-asr/utils/split_scp.pl
new file mode 100755
index 0000000..0876dcb
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/split_scp.pl
@@ -0,0 +1,246 @@
+#!/usr/bin/env perl
+
+# Copyright 2010-2011 Microsoft Corporation
+
+# See ../../COPYING for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This program splits up any kind of .scp or archive-type file.
+# If there is no utt2spk option it will work on any text  file and
+# will split it up with an approximately equal number of lines in
+# each but.
+# With the --utt2spk option it will work on anything that has the
+# utterance-id as the first entry on each line; the utt2spk file is
+# of the form "utterance speaker" (on each line).
+# It splits it into equal size chunks as far as it can.  If you use the utt2spk
+# option it will make sure these chunks coincide with speaker boundaries.  In
+# this case, if there are more chunks than speakers (and in some other
+# circumstances), some of the resulting chunks will be empty and it will print
+# an error message and exit with nonzero status.
+# You will normally call this like:
+# split_scp.pl scp scp.1 scp.2 scp.3 ...
+# or
+# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
+# Note that you can use this script to split the utt2spk file itself,
+# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
+
+# You can also call the scripts like:
+# split_scp.pl -j 3 0 scp scp.0
+# [note: with this option, it assumes zero-based indexing of the split parts,
+# i.e. the second number must be 0 <= n < num-jobs.]
+
+use warnings;
+
+$num_jobs = 0;
+$job_id = 0;
+$utt2spk_file = "";
+$one_based = 0;
+
+for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
+    if ($ARGV[0] eq "-j") {
+        shift @ARGV;
+        $num_jobs = shift @ARGV;
+        $job_id = shift @ARGV;
+    }
+    if ($ARGV[0] =~ /--utt2spk=(.+)/) {
+        $utt2spk_file=$1;
+        shift;
+    }
+    if ($ARGV[0] eq '--one-based') {
+        $one_based = 1;
+        shift @ARGV;
+    }
+}
+
+if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
+                       $job_id - $one_based >= $num_jobs)) {
+  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
+      ($one_based ? " --one-based" : "") . "'\n"
+}
+
+$one_based
+    and $job_id--;
+
+if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
+    die
+"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
+   or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
+ ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
+}
+
+$error = 0;
+$inscp = shift @ARGV;
+if ($num_jobs == 0) { # without -j option
+    @OUTPUTS = @ARGV;
+} else {
+    for ($j = 0; $j < $num_jobs; $j++) {
+        if ($j == $job_id) {
+            if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
+            else { push @OUTPUTS, "-"; }
+        } else {
+            push @OUTPUTS, "/dev/null";
+        }
+    }
+}
+
+if ($utt2spk_file ne "") {  # We have the --utt2spk option...
+    open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
+    while(<$u_fh>) {
+        @A = split;
+        @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
+        ($u,$s) = @A;
+        $utt2spk{$u} = $s;
+    }
+    close $u_fh;
+    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+    @spkrs = ();
+    while(<$i_fh>) {
+        @A = split;
+        if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
+        $u = $A[0];
+        $s = $utt2spk{$u};
+        defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
+        if(!defined $spk_count{$s}) {
+            push @spkrs, $s;
+            $spk_count{$s} = 0;
+            $spk_data{$s} = [];  # ref to new empty array.
+        }
+        $spk_count{$s}++;
+        push @{$spk_data{$s}}, $_;
+    }
+    # Now split as equally as possible ..
+    # First allocate spks to files by allocating an approximately
+    # equal number of speakers.
+    $numspks = @spkrs;  # number of speakers.
+    $numscps = @OUTPUTS; # number of output files.
+    if ($numspks < $numscps) {
+      die "$0: Refusing to split data because number of speakers $numspks " .
+          "is less than the number of output .scp files $numscps\n";
+    }
+    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+        $scparray[$scpidx] = []; # [] is array reference.
+    }
+    for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
+        $scpidx = int(($spkidx*$numscps) / $numspks);
+        $spk = $spkrs[$spkidx];
+        push @{$scparray[$scpidx]}, $spk;
+        $scpcount[$scpidx] += $spk_count{$spk};
+    }
+
+    # Now will try to reassign beginning + ending speakers
+    # to different scp's and see if it gets more balanced.
+    # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
+    # We can show that if considering changing just 2 scp's, we minimize
+    # this by minimizing the squared difference in sizes.  This is
+    # equivalent to minimizing the absolute difference in sizes.  This
+    # shows this method is bound to converge.
+
+    $changed = 1;
+    while($changed) {
+        $changed = 0;
+        for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+            # First try to reassign ending spk of this scp.
+            if($scpidx < $numscps-1) {
+                $sz = @{$scparray[$scpidx]};
+                if($sz > 0) {
+                    $spk = $scparray[$scpidx]->[$sz-1];
+                    $count = $spk_count{$spk};
+                    $nutt1 = $scpcount[$scpidx];
+                    $nutt2 = $scpcount[$scpidx+1];
+                    if( abs( ($nutt2+$count) - ($nutt1-$count))
+                        < abs($nutt2 - $nutt1))  { # Would decrease
+                        # size-diff by reassigning spk...
+                        $scpcount[$scpidx+1] += $count;
+                        $scpcount[$scpidx] -= $count;
+                        pop @{$scparray[$scpidx]};
+                        unshift @{$scparray[$scpidx+1]}, $spk;
+                        $changed = 1;
+                    }
+                }
+            }
+            if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
+                $spk = $scparray[$scpidx]->[0];
+                $count = $spk_count{$spk};
+                $nutt1 = $scpcount[$scpidx-1];
+                $nutt2 = $scpcount[$scpidx];
+                if( abs( ($nutt2-$count) - ($nutt1+$count))
+                    < abs($nutt2 - $nutt1))  { # Would decrease
+                    # size-diff by reassigning spk...
+                    $scpcount[$scpidx-1] += $count;
+                    $scpcount[$scpidx] -= $count;
+                    shift @{$scparray[$scpidx]};
+                    push @{$scparray[$scpidx-1]}, $spk;
+                    $changed = 1;
+                }
+            }
+        }
+    }
+    # Now print out the files...
+    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+        $scpfile = $OUTPUTS[$scpidx];
+        ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
+                         : open($f_fh, '>&', \*STDOUT)) ||
+            die "$0: Could not open scp file $scpfile for writing: $!\n";
+        $count = 0;
+        if(@{$scparray[$scpidx]} == 0) {
+            print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
+                         "$scpfile (too many splits and too few speakers?)\n";
+            $error = 1;
+        } else {
+            foreach $spk ( @{$scparray[$scpidx]} ) {
+                print $f_fh @{$spk_data{$spk}};
+                $count += $spk_count{$spk};
+            }
+            $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
+        }
+        close($f_fh);
+    }
+} else {
+   # This block is the "normal" case where there is no --utt2spk
+   # option and we just break into equal size chunks.
+
+    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+
+    $numscps = @OUTPUTS;  # size of array.
+    @F = ();
+    while(<$i_fh>) {
+        push @F, $_;
+    }
+    $numlines = @F;
+    if($numlines == 0) {
+        print STDERR "$0: error: empty input scp file $inscp\n";
+        $error = 1;
+    }
+    $linesperscp = int( $numlines / $numscps); # the "whole part"..
+    $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
+    $remainder = $numlines - ($linesperscp * $numscps);
+    ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
+    # [just doing int() rounds down].
+    $n = 0;
+    for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
+        $scpfile = $OUTPUTS[$scpidx];
+        ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
+                         : open($o_fh, '>&', \*STDOUT)) ||
+            die "$0: Could not open scp file $scpfile for writing: $!\n";
+        for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
+            print $o_fh $F[$n++];
+        }
+        close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
+    }
+    $n == $numlines || die "$n != $numlines [code error]";
+}
+
+exit ($error);
diff --git a/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl
new file mode 100755
index 0000000..6e0e438
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl
@@ -0,0 +1,38 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# converts an utt2spk file to a spk2utt file.
+# Takes input from the stdin or from a file argument;
+# output goes to the standard out.
+
+if ( @ARGV > 1 ) {
+    die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
+}
+
+while(<>){ 
+    @A = split(" ", $_);
+    @A == 2 || die "Invalid line in utt2spk file: $_";
+    ($u,$s) = @A;
+    if(!$seen_spk{$s}) {
+        $seen_spk{$s} = 1;
+        push @spklist, $s;
+    }
+    push (@{$spk_hash{$s}}, "$u");
+}
+foreach $s (@spklist) {
+    $l = join(' ',@{$spk_hash{$s}});
+    print "$s $l\n";
+}
diff --git a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh
new file mode 100755
index 0000000..3eec443
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh
@@ -0,0 +1,404 @@
+#!/usr/bin/env bash
+
+cmd="$@"
+
+no_feats=false
+no_wav=false
+no_text=false
+no_spk_sort=false
+non_print=false
+
+
+function show_help
+{
+      echo "Usage: $0 [--no-feats] [--no-text] [--non-print] [--no-wav] [--no-spk-sort] <data-dir>"
+      echo "The --no-xxx options mean that the script does not require "
+      echo "xxx.scp to be present, but it will check it if it is present."
+      echo "--no-spk-sort means that the script does not require the utt2spk to be "
+      echo "sorted by the speaker-id in addition to being sorted by utterance-id."
+      echo "--non-print ignore the presence of non-printable characters."
+      echo "By default, utt2spk is expected to be sorted by both, which can be "
+      echo "achieved by making the speaker-id prefixes of the utterance-ids"
+      echo "e.g.: $0 data/train"
+}
+
+while [ $# -ne 0 ] ; do
+  case "$1" in
+    "--no-feats")
+      no_feats=true;
+      ;;
+    "--no-text")
+      no_text=true;
+      ;;
+    "--non-print")
+      non_print=true;
+      ;;
+    "--no-wav")
+      no_wav=true;
+      ;;
+    "--no-spk-sort")
+      no_spk_sort=true;
+      ;;
+    *)
+      if ! [ -z "$data" ] ; then
+        show_help;
+        exit 1
+      fi
+      data=$1
+      ;;
+  esac
+  shift
+done
+
+
+
+if [ ! -d $data ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+if [ -f $data/images.scp ]; then
+  cmd=${cmd/--no-wav/}  # remove --no-wav if supplied
+  image/validate_data_dir.sh $cmd
+  exit $?
+fi
+
+for f in spk2utt utt2spk; do
+  if [ ! -f $data/$f ]; then
+    echo "$0: no such file $f"
+    exit 1;
+  fi
+  if [ ! -s $data/$f ]; then
+    echo "$0: empty file $f"
+    exit 1;
+  fi
+done
+
+! cat $data/utt2spk | awk '{if (NF != 2) exit(1); }' && \
+  echo "$0: $data/utt2spk has wrong format." && exit;
+
+ns=$(wc -l < $data/spk2utt)
+if [ "$ns" == 1 ]; then
+  echo "$0: WARNING: you have only one speaker.  This probably a bad idea."
+  echo "   Search for the word 'bold' in http://kaldi-asr.org/doc/data_prep.html"
+  echo "   for more information."
+fi
+
+
+tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
+trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
+
+export LC_ALL=C
+
+function check_sorted_and_uniq {
+  ! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' $1 && exit 1;
+  ! awk '{print $1}' < $1 | sort -uC && echo "$0: file $1 is not sorted or has duplicates" && exit 1;
+}
+
+function partial_diff {
+  diff -U1 $1 $2 | (head -n 6; echo "..."; tail -n 6)
+  n1=`cat $1 | wc -l`
+  n2=`cat $2 | wc -l`
+  echo "[Lengths are $1=$n1 versus $2=$n2]"
+}
+
+check_sorted_and_uniq $data/utt2spk
+
+if ! $no_spk_sort; then
+  ! sort -k2 -C $data/utt2spk && \
+     echo "$0: utt2spk is not in sorted order when sorted first on speaker-id " && \
+     echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
+fi
+
+check_sorted_and_uniq $data/spk2utt
+
+! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
+     <(utils/spk2utt_to_utt2spk.pl $data/spk2utt)  && \
+   echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
+
+cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
+
+if [ ! -f $data/text ] && ! $no_text; then
+  echo "$0: no such file $data/text (if this is by design, specify --no-text)"
+  exit 1;
+fi
+
+num_utts=`cat $tmpdir/utts | wc -l`
+if ! $no_text; then
+  if ! $non_print; then
+    if locale -a | grep "C.UTF-8" >/dev/null; then
+      L=C.UTF-8
+    else
+      L=en_US.UTF-8
+    fi
+    n_non_print=$(LC_ALL="$L" grep -c '[^[:print:][:space:]]' $data/text) && \
+    echo "$0: text contains $n_non_print lines with non-printable characters" &&\
+    exit 1;
+  fi
+  utils/validate_text.pl $data/text || exit 1;
+  check_sorted_and_uniq $data/text
+  text_len=`cat $data/text | wc -l`
+  illegal_sym_list="<s> </s> #0"
+  for x in $illegal_sym_list; do
+    if grep -w "$x" $data/text > /dev/null; then
+      echo "$0: Error: in $data, text contains illegal symbol $x"
+      exit 1;
+    fi
+  done
+  awk '{print $1}' < $data/text > $tmpdir/utts.txt
+  if ! cmp -s $tmpdir/utts{,.txt}; then
+    echo "$0: Error: in $data, utterance lists extracted from utt2spk and text"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.txt}
+    exit 1;
+  fi
+fi
+
+if [ -f $data/segments ] && [ ! -f $data/wav.scp ]; then
+  echo "$0: in directory $data, segments file exists but no wav.scp"
+  exit 1;
+fi
+
+
+if [ ! -f $data/wav.scp ] && ! $no_wav; then
+  echo "$0: no such file $data/wav.scp (if this is by design, specify --no-wav)"
+  exit 1;
+fi
+
+if [ -f $data/wav.scp ]; then
+  check_sorted_and_uniq $data/wav.scp
+
+  if grep -E -q '^\S+\s+~' $data/wav.scp; then
+    # note: it's not a good idea to have any kind of tilde in wav.scp, even if
+    # part of a command, as it would cause compatibility problems if run by
+    # other users, but this used to be not checked for so we let it slide unless
+    # it's something of the form "foo ~/foo.wav" (i.e. a plain file name) which
+    # would definitely cause problems as the fopen system call does not do
+    # tilde expansion.
+    echo "$0: Please do not use tilde (~) in your wav.scp."
+    exit 1;
+  fi
+
+  if [ -f $data/segments ]; then
+
+    check_sorted_and_uniq $data/segments
+    # We have a segments file -> interpret wav file as "recording-ids" not utterance-ids.
+    ! cat $data/segments | \
+      awk '{if (NF != 4 || $4 <= $3) { print "Bad line in segments file", $0; exit(1); }}' && \
+      echo "$0: badly formatted segments file" && exit 1;
+
+    segments_len=`cat $data/segments | wc -l`
+    if [ -f $data/text ]; then
+      ! cmp -s $tmpdir/utts <(awk '{print $1}' <$data/segments) && \
+        echo "$0: Utterance list differs between $data/utt2spk and $data/segments " && \
+        echo "$0: Lengths are $segments_len vs $num_utts" && \
+        exit 1
+    fi
+
+    cat $data/segments | awk '{print $2}' | sort | uniq > $tmpdir/recordings
+    awk '{print $1}' $data/wav.scp > $tmpdir/recordings.wav
+    if ! cmp -s $tmpdir/recordings{,.wav}; then
+      echo "$0: Error: in $data, recording-ids extracted from segments and wav.scp"
+      echo "$0: differ, partial diff is:"
+      partial_diff $tmpdir/recordings{,.wav}
+      exit 1;
+    fi
+    if [ -f $data/reco2file_and_channel ]; then
+      # this file is needed only for ctm scoring; it's indexed by recording-id.
+      check_sorted_and_uniq $data/reco2file_and_channel
+      ! cat $data/reco2file_and_channel | \
+        awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
+                if ( NF == 3 && $3 == "1" ) {
+                  warning_issued = 1;
+                } else {
+                  print "Bad line ", $0; exit 1;
+                }
+              }
+            }
+            END {
+              if (warning_issued == 1) {
+                print "The channel should be marked as A or B, not 1! You should change it ASAP! "
+              }
+            }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
+      cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/recordings.r2fc
+      if ! cmp -s $tmpdir/recordings{,.r2fc}; then
+        echo "$0: Error: in $data, recording-ids extracted from segments and reco2file_and_channel"
+        echo "$0: differ, partial diff is:"
+        partial_diff $tmpdir/recordings{,.r2fc}
+        exit 1;
+      fi
+    fi
+  else
+    # No segments file -> assume wav.scp indexed by utterance.
+    cat $data/wav.scp | awk '{print $1}' > $tmpdir/utts.wav
+    if ! cmp -s $tmpdir/utts{,.wav}; then
+      echo "$0: Error: in $data, utterance lists extracted from utt2spk and wav.scp"
+      echo "$0: differ, partial diff is:"
+      partial_diff $tmpdir/utts{,.wav}
+      exit 1;
+    fi
+
+    if [ -f $data/reco2file_and_channel ]; then
+      # this file is needed only for ctm scoring; it's indexed by recording-id.
+      check_sorted_and_uniq $data/reco2file_and_channel
+      ! cat $data/reco2file_and_channel | \
+        awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
+                if ( NF == 3 && $3 == "1" ) {
+                  warning_issued = 1;
+                } else {
+                  print "Bad line ", $0; exit 1;
+                }
+              }
+            }
+            END {
+              if (warning_issued == 1) {
+                print "The channel should be marked as A or B, not 1! You should change it ASAP! "
+              }
+            }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
+      cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/utts.r2fc
+      if ! cmp -s $tmpdir/utts{,.r2fc}; then
+        echo "$0: Error: in $data, utterance-ids extracted from segments and reco2file_and_channel"
+        echo "$0: differ, partial diff is:"
+        partial_diff $tmpdir/utts{,.r2fc}
+        exit 1;
+      fi
+    fi
+  fi
+fi
+
+if [ ! -f $data/feats.scp ] && ! $no_feats; then
+  echo "$0: no such file $data/feats.scp (if this is by design, specify --no-feats)"
+  exit 1;
+fi
+
+if [ -f $data/feats.scp ]; then
+  check_sorted_and_uniq $data/feats.scp
+  cat $data/feats.scp | awk '{print $1}' > $tmpdir/utts.feats
+  if ! cmp -s $tmpdir/utts{,.feats}; then
+    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and features"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.feats}
+    exit 1;
+  fi
+fi
+
+
+if [ -f $data/cmvn.scp ]; then
+  check_sorted_and_uniq $data/cmvn.scp
+  cat $data/cmvn.scp | awk '{print $1}' > $tmpdir/speakers.cmvn
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  if ! cmp -s $tmpdir/speakers{,.cmvn}; then
+    echo "$0: Error: in $data, speaker lists extracted from spk2utt and cmvn"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/speakers{,.cmvn}
+    exit 1;
+  fi
+fi
+
+if [ -f $data/spk2gender ]; then
+  check_sorted_and_uniq $data/spk2gender
+  ! cat $data/spk2gender | awk '{if (!((NF == 2 && ($2 == "m" || $2 == "f")))) exit 1; }' && \
+     echo "$0: Mal-formed spk2gender file" && exit 1;
+  cat $data/spk2gender | awk '{print $1}' > $tmpdir/speakers.spk2gender
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  if ! cmp -s $tmpdir/speakers{,.spk2gender}; then
+    echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2gender"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/speakers{,.spk2gender}
+    exit 1;
+  fi
+fi
+
+if [ -f $data/spk2warp ]; then
+  check_sorted_and_uniq $data/spk2warp
+  ! cat $data/spk2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
+     echo "$0: Mal-formed spk2warp file" && exit 1;
+  cat $data/spk2warp | awk '{print $1}' > $tmpdir/speakers.spk2warp
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  if ! cmp -s $tmpdir/speakers{,.spk2warp}; then
+    echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2warp"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/speakers{,.spk2warp}
+    exit 1;
+  fi
+fi
+
+if [ -f $data/utt2warp ]; then
+  check_sorted_and_uniq $data/utt2warp
+  ! cat $data/utt2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
+     echo "$0: Mal-formed utt2warp file" && exit 1;
+  cat $data/utt2warp | awk '{print $1}' > $tmpdir/utts.utt2warp
+  cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
+  if ! cmp -s $tmpdir/utts{,.utt2warp}; then
+    echo "$0: Error: in $data, utterance lists extracted from utt2spk and utt2warp"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.utt2warp}
+    exit 1;
+  fi
+fi
+
+# check some optionally-required things
+for f in vad.scp utt2lang utt2uniq; do
+  if [ -f $data/$f ]; then
+    check_sorted_and_uniq $data/$f
+    if ! cmp -s <( awk '{print $1}' $data/utt2spk ) \
+      <( awk '{print $1}' $data/$f ); then
+      echo "$0: error: in $data, $f and utt2spk do not have identical utterance-id list"
+      exit 1;
+    fi
+  fi
+done
+
+
+if [ -f $data/utt2dur ]; then
+  check_sorted_and_uniq $data/utt2dur
+  cat $data/utt2dur | awk '{print $1}' > $tmpdir/utts.utt2dur
+  if ! cmp -s $tmpdir/utts{,.utt2dur}; then
+    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2dur file"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.utt2dur}
+    exit 1;
+  fi
+  cat $data/utt2dur | \
+    awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line utt2dur:" NR ":" $0; exit(1) }}' || exit 1
+fi
+
+if [ -f $data/utt2num_frames ]; then
+  check_sorted_and_uniq $data/utt2num_frames
+  cat $data/utt2num_frames | awk '{print $1}' > $tmpdir/utts.utt2num_frames
+  if ! cmp -s $tmpdir/utts{,.utt2num_frames}; then
+    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2num_frames file"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.utt2num_frames}
+    exit 1
+  fi
+  awk <$data/utt2num_frames '{
+    if (NF != 2 || !($2 > 0) || $2 != int($2)) {
+      print "Bad line utt2num_frames:" NR ":" $0
+      exit 1 } }' || exit 1
+fi
+
+if [ -f $data/reco2dur ]; then
+  check_sorted_and_uniq $data/reco2dur
+  cat $data/reco2dur | awk '{print $1}' > $tmpdir/recordings.reco2dur
+  if [ -f $tmpdir/recordings ]; then
+    if ! cmp -s $tmpdir/recordings{,.reco2dur}; then
+      echo "$0: Error: in $data, recording-ids extracted from segments and reco2dur file"
+      echo "$0: differ, partial diff is:"
+      partial_diff $tmpdir/recordings{,.reco2dur}
+    exit 1;
+    fi
+  else
+    if ! cmp -s $tmpdir/{utts,recordings.reco2dur}; then
+      echo "$0: Error: in $data, recording-ids extracted from wav.scp and reco2dur file"
+      echo "$0: differ, partial diff is:"
+      partial_diff $tmpdir/{utts,recordings.reco2dur}
+    exit 1;
+    fi
+  fi
+  cat $data/reco2dur | \
+    awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line : " $0; exit(1) }}' || exit 1
+fi
+
+
+echo "$0: Successfully validated data-directory $data"
diff --git a/egs/alimeeting/sa-asr/utils/validate_text.pl b/egs/alimeeting/sa-asr/utils/validate_text.pl
new file mode 100755
index 0000000..7f75cf1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/validate_text.pl
@@ -0,0 +1,136 @@
+#!/usr/bin/env perl
+#
+#===============================================================================
+# Copyright 2017  Johns Hopkins University (author: Yenda Trmal <jtrmal@gmail.com>)
+#                 Johns Hopkins University (author: Daniel Povey)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+#===============================================================================
+
+# validation script for data/<dataset>/text
+# to be called (preferably) from utils/validate_data_dir.sh
+use strict;
+use warnings;
+use utf8;
+use Fcntl qw< SEEK_SET >;
+
+# this function reads the opened file (supplied as a first
+# parameter) into an array of lines. For each
+# line, it tests whether it's a valid utf-8 compatible
+# line. If all lines are valid utf-8, it returns the lines
+# decoded as utf-8, otherwise it assumes the file's encoding
+# is one of those 1-byte encodings, such as ISO-8859-x
+# or Windows CP-X.
+# Please recall we do not really care about
+# the actually encoding, we just need to
+# make sure the length of the (decoded) string
+# is correct (to make the output formatting looking right).
+sub get_utf8_or_bytestream {
+  use Encode qw(decode encode);
+  my $is_utf_compatible = 1;
+  my @unicode_lines;
+  my @raw_lines;
+  my $raw_text;
+  my $lineno = 0;
+  my $file = shift;
+
+  while (<$file>) {
+    $raw_text = $_;
+    last unless $raw_text;
+    if ($is_utf_compatible) {
+      my $decoded_text = eval { decode("UTF-8", $raw_text, Encode::FB_CROAK) } ;
+      $is_utf_compatible = $is_utf_compatible && defined($decoded_text);
+      push @unicode_lines, $decoded_text;
+    } else {
+      #print STDERR "WARNING: the line $raw_text cannot be interpreted as UTF-8: $decoded_text\n";
+      ;
+    }
+    push @raw_lines, $raw_text;
+    $lineno += 1;
+  }
+
+  if (!$is_utf_compatible) {
+    return (0, @raw_lines);
+  } else {
+    return (1, @unicode_lines);
+  }
+}
+
+# check if the given unicode string contain unicode whitespaces
+# other than the usual four: TAB, LF, CR and SPACE
+sub validate_utf8_whitespaces {
+  my $unicode_lines = shift;
+  use feature 'unicode_strings';
+  for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) {
+    my $current_line = $unicode_lines->[$i];
+    if ((substr $current_line, -1) ne "\n"){
+      print STDERR "$0: The current line (nr. $i) has invalid newline\n";
+      return 1;
+    }
+    my @A = split(" ", $current_line);
+    my $utt_id = $A[0];
+    # we replace TAB, LF, CR, and SPACE
+    # this is to simplify the test
+    if ($current_line =~ /\x{000d}/) {
+      print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n";
+      return 1;
+    }
+    $current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g;
+    if ($current_line =~/\s/) {
+      print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n";
+      return 1;
+    }
+  }
+  return 0;
+}
+
+# checks if the text in the file (supplied as the argument) is utf-8 compatible
+# if yes, checks if it contains only allowed whitespaces. If no, then does not
+# do anything. The function seeks to the original position in the file after
+# reading the text.
+sub check_allowed_whitespace {
+  my $file = shift;
+  my $filename = shift;
+  my $pos = tell($file);
+  (my $is_utf, my @lines) = get_utf8_or_bytestream($file);
+  seek($file, $pos, SEEK_SET);
+  if ($is_utf) {
+    my $has_invalid_whitespaces = validate_utf8_whitespaces(\@lines);
+    if ($has_invalid_whitespaces) {
+      print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n";
+      return 0;
+    }
+  }
+  return 1;
+}
+
+if(@ARGV != 1) {
+  die "Usage: validate_text.pl <text-file>\n" .
+      "e.g.: validate_text.pl data/train/text\n";
+}
+
+my $text = shift @ARGV;
+
+if (-z "$text") {
+  print STDERR "$0: ERROR: file '$text' is empty or does not exist\n";
+  exit 1;
+}
+
+if(!open(FILE, "<$text")) {
+  print STDERR "$0: ERROR: failed to open $text\n";
+  exit 1;
+}
+
+check_allowed_whitespace(\*FILE, $text) or exit 1;
+close(FILE);
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 4722602..c18472f 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -40,7 +40,6 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
 
 
 header_colors = '\033[95m'
@@ -91,8 +90,6 @@
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
 
         logging.info("asr_model: {}".format(asr_model))
         logging.info("asr_train_args: {}".format(asr_train_args))
@@ -111,7 +108,7 @@
         # 2. Build Language model
         if lm_train_config is not None:
             lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+                lm_train_config, lm_file, None, device
             )
             scorers["lm"] = lm.lm
 
@@ -141,6 +138,13 @@
             token_list=token_list,
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
+
+        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+        for scorer in scorers.values():
+            if isinstance(scorer, torch.nn.Module):
+                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+        logging.info(f"Beam_search: {beam_search}")
+        logging.info(f"Decoding device={device}, dtype={dtype}")
 
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
@@ -198,16 +202,7 @@
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
 
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
 
         # a. To device
         batch = to_device(batch, device=self.device)
@@ -355,6 +350,9 @@
     if ngpu > 1:
         raise NotImplementedError("only single GPU decoding is supported")
     
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
     logging.basicConfig(
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -408,6 +406,7 @@
             data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
+            mc=True,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
@@ -452,7 +451,7 @@
                     
                     # Write the result to each file
                     ibest_writer["token"][key] = " ".join(token)
-                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                     ibest_writer["score"][key] = str(hyp.score)
                 
                 if text is not None:
@@ -463,6 +462,9 @@
                     asr_utils.print_progress(finish_count / file_count)
                     if writer is not None:
                         ibest_writer["text"][key] = text
+
+                logging.info("uttid: {}".format(key))
+                logging.info("text predictions: {}\n".format(text))
         return asr_result_list
     
     return _forward
@@ -637,4 +639,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e10ebf4..e165531 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -288,6 +288,9 @@
     if mode == "asr":
         from funasr.bin.asr_inference import inference
         return inference(**kwargs)
+    elif mode == "sa_asr":
+        from funasr.bin.sa_asr_inference import inference
+        return inference(**kwargs)
     elif mode == "uniasr":
         from funasr.bin.asr_inference_uniasr import inference
         return inference(**kwargs)
@@ -342,4 +345,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
index bba50da..c1e2cb2 100755
--- a/funasr/bin/asr_train.py
+++ b/funasr/bin/asr_train.py
@@ -2,6 +2,14 @@
 
 import os
 
+import logging
+
+logging.basicConfig(
+    level='INFO',
+    format=f"[{os.uname()[1].split('.')[0]}]"
+           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+)
+
 from funasr.tasks.asr import ASRTask
 
 
@@ -27,7 +35,8 @@
     args = parse_args()
 
     # setup local gpu_id
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+    if args.ngpu > 0:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
 
     # DDP settings
     if args.ngpu > 1:
@@ -38,9 +47,9 @@
 
     # re-compute batch size: when dataset type is small
     if args.dataset_type == "small":
-        if args.batch_size is not None:
+        if args.batch_size is not None and args.ngpu > 0:
             args.batch_size = args.batch_size * args.ngpu
-        if args.batch_bins is not None:
+        if args.batch_bins is not None and args.ngpu > 0:
             args.batch_bins = args.batch_bins * args.ngpu
 
     main(args=args)
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
new file mode 100644
index 0000000..be63af1
--- /dev/null
+++ b/funasr/bin/sa_asr_inference.py
@@ -0,0 +1,674 @@
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
+from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.sa_asr import ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+
+class Speech2Text:
+    """Speech2Text class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(audio)
+        [(text, token, token_int, hypothesis object), ...]
+
+    """
+
+    def __init__(
+            self,
+            asr_train_config: Union[Path, str] = None,
+            asr_model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            lm_train_config: Union[Path, str] = None,
+            lm_file: Union[Path, str] = None,
+            token_type: str = None,
+            bpemodel: str = None,
+            device: str = "cpu",
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            batch_size: int = 1,
+            dtype: str = "float32",
+            beam_size: int = 20,
+            ctc_weight: float = 0.5,
+            lm_weight: float = 1.0,
+            ngram_weight: float = 0.9,
+            penalty: float = 0.0,
+            nbest: int = 1,
+            streaming: bool = False,
+            frontend_conf: dict = None,
+            **kwargs,
+    ):
+        assert check_argument_types()
+
+        # 1. Build ASR model
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+        frontend = None
+
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        decoder = asr_model.decoder
+
+        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        token_list = asr_model.token_list
+        scorers.update(
+            decoder=decoder,
+            ctc=ctc,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        # 2. Build Language model
+        if lm_train_config is not None:
+            lm, lm_train_args = LMTask.build_model_from_file(
+                lm_train_config, lm_file, None, device
+            )
+            scorers["lm"] = lm.lm
+
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+
+        # 4. Build BeamSearch object
+        # transducer is not supported now
+        beam_search_transducer = None
+
+        weights = dict(
+            decoder=1.0 - ctc_weight,
+            ctc=ctc_weight,
+            lm=lm_weight,
+            ngram=ngram_weight,
+            length_bonus=penalty,
+        )
+        beam_search = BeamSearch(
+            beam_size=beam_size,
+            weights=weights,
+            scorers=scorers,
+            sos=asr_model.sos,
+            eos=asr_model.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+        )
+
+        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+        for scorer in scorers.values():
+            if isinstance(scorer, torch.nn.Module):
+                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+        logging.info(f"Beam_search: {beam_search}")
+        logging.info(f"Decoding device={device}, dtype={dtype}")
+
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        converter = TokenIDConverter(token_list=token_list)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.converter = converter
+        self.tokenizer = tokenizer
+        self.beam_search = beam_search
+        self.beam_search_transducer = beam_search_transducer
+        self.maxlenratio = maxlenratio
+        self.minlenratio = minlenratio
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+        self.frontend = frontend
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+    ) -> List[
+        Tuple[
+            Optional[str],
+            Optional[str],
+            List[str],
+            List[int],
+            Union[Hypothesis],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            text, text_id, token, token_int, hyp
+
+        """
+        assert check_argument_types()
+
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        if isinstance(profile, np.ndarray):
+            profile = torch.tensor(profile)
+
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+
+        # a. To device
+        batch = to_device(batch, device=self.device)
+
+        # b. Forward Encoder
+        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
+        if isinstance(asr_enc, tuple):
+            asr_enc = asr_enc[0]
+        if isinstance(spk_enc, tuple):
+            spk_enc = spk_enc[0]
+        assert len(asr_enc) == 1, len(asr_enc)
+        assert len(spk_enc) == 1, len(spk_enc)
+
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        for hyp in nbest_hyps:
+            assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1: last_pos]
+            else:
+                token_int = hyp.yseq[1: last_pos].tolist()
+
+            spk_weigths=torch.stack(hyp.spk_weigths, dim=0)
+
+            token_ori = self.converter.ids2tokens(token_int)
+            text_ori = self.tokenizer.tokens2text(token_ori)
+
+            text_ori_spklist = text_ori.split('$')
+            cur_index = 0
+            spk_choose = []
+            for i in range(len(text_ori_spklist)):
+                text_ori_split = text_ori_spklist[i]
+                n = len(text_ori_split)
+                spk_weights_local = spk_weigths[cur_index: cur_index + n]
+                cur_index = cur_index + n + 1
+                spk_weights_local = spk_weights_local.mean(dim=0)
+                spk_choose_local = spk_weights_local.argmax(-1)
+                spk_choose.append(spk_choose_local.item() + 1)
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+            # Change integer-ids to tokens
+            token = self.converter.ids2tokens(token_int)
+
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+
+            text_spklist = text.split('$')
+            assert len(spk_choose) == len(text_spklist)
+
+            spk_list=[]
+            for i in range(len(text_spklist)):
+                text_split = text_spklist[i]
+                n = len(text_split)
+                spk_list.append(str(spk_choose[i]) * n)
+            
+            text_id = '$'.join(spk_list)
+            
+            assert len(text) == len(text_id)
+
+            results.append((text, text_id, token, token_int, hyp))
+
+        assert check_return_type(results)
+        return results
+
+def inference(
+        maxlenratio: float,
+        minlenratio: float,
+        batch_size: int,
+        beam_size: int,
+        ngpu: int,
+        ctc_weight: float,
+        lm_weight: float,
+        penalty: float,
+        log_level: Union[int, str],
+        data_path_and_name_and_type,
+        asr_train_config: Optional[str],
+        asr_model_file: Optional[str],
+        cmvn_file: Optional[str] = None,
+        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+        lm_train_config: Optional[str] = None,
+        lm_file: Optional[str] = None,
+        token_type: Optional[str] = None,
+        key_file: Optional[str] = None,
+        word_lm_train_config: Optional[str] = None,
+        bpemodel: Optional[str] = None,
+        allow_variable_data_keys: bool = False,
+        streaming: bool = False,
+        output_dir: Optional[str] = None,
+        dtype: str = "float32",
+        seed: int = 0,
+        ngram_weight: float = 0.9,
+        nbest: int = 1,
+        num_workers: int = 1,
+        **kwargs,
+):
+    inference_pipeline = inference_modelscope(
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        batch_size=batch_size,
+        beam_size=beam_size,
+        ngpu=ngpu,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        penalty=penalty,
+        log_level=log_level,
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        raw_inputs=raw_inputs,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        key_file=key_file,
+        word_lm_train_config=word_lm_train_config,
+        bpemodel=bpemodel,
+        allow_variable_data_keys=allow_variable_data_keys,
+        streaming=streaming,
+        output_dir=output_dir,
+        dtype=dtype,
+        seed=seed,
+        ngram_weight=ngram_weight,
+        nbest=nbest,
+        num_workers=num_workers,
+        **kwargs,
+    )
+    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+def inference_modelscope(
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # data_path_and_name_and_type,
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    streaming: bool = False,
+    output_dir: Optional[str] = None,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+    
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        streaming=streaming,
+    )
+    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+    speech2text = Speech2Text(**speech2text_kwargs)
+    
+    def _forward(data_path_and_name_and_type,
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            mc=True,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+        
+        finish_count = 0
+        file_count = 1
+        # 7 .Start for-loop
+        # FIXME(kamo): The output format should be discussed about
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+        
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+            # N-best list of (text, token, token_int, hyp_object)
+            try:
+                results = speech2text(**batch)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+                results = [[" ", ["sil"], [2], hyp]] * nbest
+            
+            # Only supporting batch_size==1
+            key = keys[0]
+            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                # Create a directory: outdir/{n}best_recog
+                if writer is not None:
+                    ibest_writer = writer[f"{n}best_recog"]
+                    
+                    # Write the result to each file
+                    ibest_writer["token"][key] = " ".join(token)
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["score"][key] = str(hyp.score)
+                    ibest_writer["text_id"][key] = text_id
+                
+                if text is not None:
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                    item = {'key': key, 'value': text_postprocessed}
+                    asr_result_list.append(item)
+                    finish_count += 1
+                    asr_utils.print_progress(finish_count / file_count)
+                    if writer is not None:
+                        ibest_writer["text"][key] = text
+
+                logging.info("uttid: {}".format(key))
+                logging.info("text predictions: {}".format(text))
+                logging.info("text_id predictions: {}\n".format(text_id))
+        return asr_result_list
+    
+    return _forward
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=False,
+        action="append",
+    )
+    group.add_argument("--raw_inputs", type=list, default=None)
+    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If specify this option, *_train_config and "
+             "*_file will be overwritten",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain max output length. "
+             "If maxlenratio=0.0 (default), it uses a end-detect "
+             "function "
+             "to automatically find maximum hypothesis lengths."
+             "If maxlenratio<0.0, its absolute value is interpreted"
+             "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.5,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
+        help="The token type for ASR model. "
+             "If not given, refers from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model path of sentencepiece. "
+             "If not given, refers from the training args",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+    inference(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
new file mode 100755
index 0000000..c7c7c42
--- /dev/null
+++ b/funasr/bin/sa_asr_train.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+
+import os
+
+import logging
+
+logging.basicConfig(
+    level='INFO',
+    format=f"[{os.uname()[1].split('.')[0]}]"
+           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+)
+
+from funasr.tasks.sa_asr import ASRTask
+
+
+# for ASR Training
+def parse_args():
+    parser = ASRTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # for ASR Training
+    ASRTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    if args.ngpu > 0:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small":
+        if args.batch_size is not None and args.ngpu > 0:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None and args.ngpu > 0:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index dc872b0..d757f7f 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -46,13 +46,15 @@
         if self.normalize:
             # soundfile.read normalizes data to [-1,1] if dtype is not given
             array, rate = librosa.load(
-                wav, sr=self.dest_sample_rate, mono=not self.always_2d
+                wav, sr=self.dest_sample_rate, mono=self.always_2d
             )
         else:
             array, rate = librosa.load(
-                wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype
+                wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
             )
 
+        if array.ndim==2:
+            array=array.transpose((1, 0))
         return rate, array
 
     def get_path(self, key):
diff --git a/funasr/losses/nll_loss.py b/funasr/losses/nll_loss.py
new file mode 100644
index 0000000..7e4e294
--- /dev/null
+++ b/funasr/losses/nll_loss.py
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+
+class NllLoss(nn.Module):
+    """Nll loss.
+
+    :param int size: the number of class
+    :param int padding_idx: ignored class id
+    :param bool normalize_length: normalize loss by sequence length if True
+    :param torch.nn.Module criterion: loss function
+    """
+
+    def __init__(
+        self,
+        size,
+        padding_idx,
+        normalize_length=False,
+        criterion=nn.NLLLoss(reduction='none'),
+    ):
+        """Construct an LabelSmoothingLoss object."""
+        super(NllLoss, self).__init__()
+        self.criterion = criterion
+        self.padding_idx = padding_idx
+        self.size = size
+        self.true_dist = None
+        self.normalize_length = normalize_length
+
+    def forward(self, x, target):
+        """Compute loss between x and target.
+
+        :param torch.Tensor x: prediction (batch, seqlen, class)
+        :param torch.Tensor target:
+            target signal masked with self.padding_id (batch, seqlen)
+        :return: scalar float value
+        :rtype torch.Tensor
+        """
+        assert x.size(2) == self.size
+        batch_size = x.size(0)
+        x = x.view(-1, self.size)
+        target = target.view(-1)
+        with torch.no_grad():
+            ignore = target == self.padding_idx  # (B,)
+            total = len(target) - ignore.sum().item()
+            target = target.masked_fill(ignore, 0)  # avoid -1 index
+        kl = self.criterion(x , target)
+        denom = total if self.normalize_length else batch_size
+        return kl.masked_fill(ignore, 0).sum() / denom
diff --git a/funasr/models/decoder/decoder_layer_sa_asr.py b/funasr/models/decoder/decoder_layer_sa_asr.py
new file mode 100644
index 0000000..80afc51
--- /dev/null
+++ b/funasr/models/decoder/decoder_layer_sa_asr.py
@@ -0,0 +1,169 @@
+import torch
+from torch import nn
+
+from funasr.modules.layer_norm  import LayerNorm
+
+
+class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
+        
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == (
+                tgt.shape[0],
+                tgt.shape[1] - 1,
+                self.size,
+            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = torch.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+            )
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+        z = x
+        
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
+
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, skip), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(skip)
+        if not self.normalize_before:
+            x = self.norm1(x)
+        
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+            
+        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
+
+class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
+    
+    def __init__(
+        self,
+        size,
+        d_size,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        self.spk_linear = nn.Linear(d_size, size, bias=False)
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
+        
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        x = tgt_q
+        if self.normalize_before:
+            x = self.norm2(x)
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+        if not self.normalize_before:
+            x = self.norm2(x)
+        residual = x
+
+        if dn!=None:
+            x = x + self.spk_linear(dn)
+        if self.normalize_before:
+            x = self.norm3(x)
+        
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm3(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask
+
+
+
diff --git a/funasr/models/decoder/transformer_decoder_sa_asr.py b/funasr/models/decoder/transformer_decoder_sa_asr.py
new file mode 100644
index 0000000..949f9c8
--- /dev/null
+++ b/funasr/models/decoder/transformer_decoder_sa_asr.py
@@ -0,0 +1,291 @@
+from typing import Any
+from typing import List
+from typing import Sequence
+from typing import Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.attention import CosineDistanceAttention
+from funasr.models.decoder.transformer_decoder import DecoderLayer
+from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
+from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
+from funasr.modules.dynamic_conv import DynamicConvolution
+from funasr.modules.dynamic_conv2d import DynamicConvolution2D
+from funasr.modules.embedding import PositionalEncoding
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.lightconv import LightweightConvolution
+from funasr.modules.lightconv2d import LightweightConvolution2D
+from funasr.modules.mask import subsequent_mask
+from funasr.modules.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.models.decoder.abs_decoder import AbsDecoder
+
+class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
+    
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_asr_output_layer:
+            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.asr_output_layer = None
+
+        if use_spk_output_layer:
+            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
+        else:
+            self.spk_output_layer = None
+
+        self.cos_distance_att = CosineDistanceAttention()
+
+        self.decoder1 = None
+        self.decoder2 = None
+        self.decoder3 = None
+        self.decoder4 = None
+
+    def forward(
+        self,
+        asr_hs_pad: torch.Tensor,
+        spk_hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        asr_memory = asr_hs_pad
+        spk_memory = spk_hs_pad
+        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
+        # Spk decoder
+        x = self.embed(tgt)
+
+        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, memory_mask
+        )
+        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
+            x, tgt_mask, spk_memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        dn, weights = self.cos_distance_att(x, profile, profile_lens)
+        # Asr decoder
+        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
+            z, tgt_mask, asr_memory, memory_mask, dn
+        )
+        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
+            x, tgt_mask, asr_memory, memory_mask
+        )
+
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.asr_output_layer is not None:
+            x = self.asr_output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, weights, olens
+
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        asr_memory: torch.Tensor,
+        spk_memory: torch.Tensor,
+        profile: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        
+        x = self.embed(tgt)
+
+        if cache is None:
+            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
+        new_cache = []
+        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
+                x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
+        )
+        new_cache.append(x)
+        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
+            x, tgt_mask, spk_memory, _ = decoder(
+                x, tgt_mask, spk_memory, None, cache=c
+            )
+            new_cache.append(x)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        else:
+            x = x
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        dn, weights = self.cos_distance_att(x, profile, None)
+
+        x, tgt_mask, asr_memory, _ = self.decoder3(
+            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
+        )
+        new_cache.append(x)
+
+        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
+            x, tgt_mask, asr_memory, _ = decoder(
+                x, tgt_mask, asr_memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.asr_output_layer is not None:
+            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
+
+        return y, weights, new_cache
+
+    def score(self, ys, state, asr_enc, spk_enc, profile):
+        """Score."""
+        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
+        logp, weights, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), weights.squeeze(), state
+
+class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        asr_num_blocks: int = 6,
+        spk_num_blocks: int = 3,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+    ):
+        assert check_argument_types()
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            spker_embedding_dim=spker_embedding_dim,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_asr_output_layer=use_asr_output_layer,
+            use_spk_output_layer=use_spk_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+
+        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
+            attention_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, self_attention_dropout_rate
+            ),
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        self.decoder2 = repeat(
+            spk_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        
+        
+        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
+            attention_dim,
+            spker_embedding_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        self.decoder4 = repeat(
+            asr_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
new file mode 100644
index 0000000..0d4097e
--- /dev/null
+++ b/funasr/models/e2e_sa_asr.py
@@ -0,0 +1,521 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+    LabelSmoothingLoss,  # noqa: H301
+)
+from funasr.losses.nll_loss import NllLoss
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.e2e_asr_common import ErrorCalculator
+from funasr.modules.nets_utils import th_accuracy
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class ESPnetASRModel(AbsESPnetModel):
+    """CTC-attention hybrid Encoder-Decoder model"""
+
+    def __init__(
+            self,
+            vocab_size: int,
+            max_spk_num: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
+            asr_encoder: AbsEncoder,
+            spk_encoder: torch.nn.Module,
+            postencoder: Optional[AbsPostEncoder],
+            decoder: AbsDecoder,
+            ctc: CTC,
+            spk_weight: float = 0.5,
+            ctc_weight: float = 0.5,
+            interctc_weight: float = 0.0,
+            ignore_id: int = -1,
+            lsm_weight: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
+            extract_feats_in_collect_stats: bool = True,
+    ):
+        assert check_argument_types()
+        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+        assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+        super().__init__()
+        # note that eos is the same as sos (equivalent ID)
+        self.blank_id = 0
+        self.sos = 1
+        self.eos = 2
+        self.vocab_size = vocab_size
+        self.max_spk_num=max_spk_num
+        self.ignore_id = ignore_id
+        self.spk_weight = spk_weight
+        self.ctc_weight = ctc_weight
+        self.interctc_weight = interctc_weight
+        self.token_list = token_list.copy()
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.postencoder = postencoder
+        self.asr_encoder = asr_encoder
+        self.spk_encoder = spk_encoder
+
+        if not hasattr(self.asr_encoder, "interctc_use_conditioning"):
+            self.asr_encoder.interctc_use_conditioning = False
+        if self.asr_encoder.interctc_use_conditioning:
+            self.asr_encoder.conditioning_layer = torch.nn.Linear(
+                vocab_size, self.asr_encoder.output_size()
+            )
+
+        self.error_calculator = None
+
+
+        # we set self.decoder = None in the CTC mode since
+        # self.decoder parameters were never used and PyTorch complained
+        # and threw an Exception in the multi-GPU experiment.
+        # thanks Jeff Farris for pointing out the issue.
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        self.criterion_spk = NllLoss(
+            size=max_spk_num,
+            padding_idx=ignore_id,
+            normalize_length=length_normalized_loss,
+        )
+
+        if report_cer or report_wer:
+            self.error_calculator = ErrorCalculator(
+                token_list, sym_space, sym_blank, report_cer, report_wer
+            )
+
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+            profile: torch.Tensor,
+            profile_lengths: torch.Tensor,
+            text_id: torch.Tensor,
+            text_id_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+            profile: (Batch, Length, Dim)
+            profile_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        batch_size = speech.shape[0]
+
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder
+        asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(asr_encoder_out, tuple):
+            intermediate_outs = asr_encoder_out[1]
+            asr_encoder_out = asr_encoder_out[0]
+
+        loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None
+        loss_ctc, cer_ctc = None, None
+        stats = dict()
+
+        # 1. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                asr_encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+
+        # Intermediate CTC (optional)
+        loss_interctc = 0.0
+        if self.interctc_weight != 0.0 and intermediate_outs is not None:
+            for layer_idx, intermediate_out in intermediate_outs:
+                # we assume intermediate_out has the same length & padding
+                # as those of encoder_out
+                loss_ic, cer_ic = self._calc_ctc_loss(
+                    intermediate_out, encoder_out_lens, text, text_lengths
+                )
+                loss_interctc = loss_interctc + loss_ic
+
+                # Collect Intermedaite CTC stats
+                stats["loss_interctc_layer{}".format(layer_idx)] = (
+                    loss_ic.detach() if loss_ic is not None else None
+                )
+                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+            loss_interctc = loss_interctc / len(intermediate_outs)
+
+            # calculate whole encoder loss
+            loss_ctc = (
+                               1 - self.interctc_weight
+                       ) * loss_ctc + self.interctc_weight * loss_interctc
+
+
+        # 2b. Attention decoder branch
+        if self.ctc_weight != 1.0:
+            loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
+                asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
+            )
+
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss_asr = loss_att
+        elif self.ctc_weight == 1.0:
+            loss_asr = loss_ctc
+        else:
+            loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+        if self.spk_weight == 0.0:
+            loss = loss_asr
+        else:
+            loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr
+
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_asr=loss_asr.detach(),
+            loss_att=loss_att.detach() if loss_att is not None else None,
+            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
+            loss_spk=loss_spk.detach() if loss_spk is not None else None,
+            acc=acc_att,
+            acc_spk=acc_spk,
+            cer=cer_att,
+            wer=wer_att,
+            cer_ctc=cer_ctc,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+            feats, feats_lengths = speech, speech_lengths
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            feats_raw = feats.clone()
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.asr_encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.asr_encoder(
+                feats, feats_lengths, ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
+        # import ipdb;ipdb.set_trace()
+        if encoder_out_spk_ori.size(1)!=encoder_out.size(1):
+            encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
+        else:
+            encoder_out_spk=encoder_out_spk_ori
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+        assert encoder_out_spk.size(0) == speech.size(0), (
+            encoder_out_spk.size(),
+            speech.size(0),
+        )
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, encoder_out_lens, encoder_out_spk
+
+    def _extract_feats(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            # Frontend
+            #  e.g. STFT and Feature extract
+            #       data_loader may send time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extract
+            feats, feats_lengths = speech, speech_lengths
+        return feats, feats_lengths
+
+    def nll(
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute negative log likelihood(nll) from transformer-decoder
+
+        Normally, this function is called in batchify_nll.
+
+        Args:
+            encoder_out: (Batch, Length, Dim)
+            encoder_out_lens: (Batch,)
+            ys_pad: (Batch, Length)
+            ys_pad_lens: (Batch,)
+        """
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, _ = self.decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )  # [batch, seqlen, dim]
+        batch_size = decoder_out.size(0)
+        decoder_num_class = decoder_out.size(2)
+        # nll: negative log-likelihood
+        nll = torch.nn.functional.cross_entropy(
+            decoder_out.view(-1, decoder_num_class),
+            ys_out_pad.view(-1),
+            ignore_index=self.ignore_id,
+            reduction="none",
+        )
+        nll = nll.view(batch_size, -1)
+        nll = nll.sum(dim=1)
+        assert nll.size(0) == batch_size
+        return nll
+
+    def batchify_nll(
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+            batch_size: int = 100,
+    ):
+        """Compute negative log likelihood(nll) from transformer-decoder
+
+        To avoid OOM, this fuction seperate the input into batches.
+        Then call nll for each batch and combine and return results.
+        Args:
+            encoder_out: (Batch, Length, Dim)
+            encoder_out_lens: (Batch,)
+            ys_pad: (Batch, Length)
+            ys_pad_lens: (Batch,)
+            batch_size: int, samples each batch contain when computing nll,
+                        you may change this to avoid OOM or increase
+                        GPU memory usage
+        """
+        total_num = encoder_out.size(0)
+        if total_num <= batch_size:
+            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+        else:
+            nll = []
+            start_idx = 0
+            while True:
+                end_idx = min(start_idx + batch_size, total_num)
+                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
+                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
+                batch_ys_pad = ys_pad[start_idx:end_idx, :]
+                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
+                batch_nll = self.nll(
+                    batch_encoder_out,
+                    batch_encoder_out_lens,
+                    batch_ys_pad,
+                    batch_ys_pad_lens,
+                )
+                nll.append(batch_nll)
+                start_idx = end_idx
+                if start_idx == total_num:
+                    break
+            nll = torch.cat(nll)
+        assert nll.size(0) == total_num
+        return nll
+
+    def _calc_att_loss(
+            self,
+            asr_encoder_out: torch.Tensor,
+            spk_encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+            profile: torch.Tensor,
+            profile_lens: torch.Tensor,
+            text_id: torch.Tensor,
+            text_id_lengths: torch.Tensor
+    ):
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, weights_no_pad, _ = self.decoder(
+            asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
+        )
+
+        spk_num_no_pad=weights_no_pad.size(-1)
+        pad=(0,self.max_spk_num-spk_num_no_pad)
+        weights=F.pad(weights_no_pad, pad, mode='constant', value=0)
+
+        # pre_id=weights.argmax(-1)
+        # pre_text=decoder_out.argmax(-1)
+        # id_mask=(pre_id==text_id).to(dtype=text_id.dtype)
+        # pre_text_mask=pre_text*id_mask+1-id_mask #鐩稿悓鐨勫湴鏂逛笉鍙橈紝涓嶅悓鐨勫湴鏂硅涓�1(<unk>)
+        # padding_mask= ys_out_pad != self.ignore_id
+        # numerator = torch.sum(pre_text_mask.masked_select(padding_mask) == ys_out_pad.masked_select(padding_mask))
+        # denominator = torch.sum(padding_mask)
+        # sd_acc = float(numerator) / float(denominator)
+
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        loss_spk = self.criterion_spk(torch.log(weights), text_id)
+
+        acc_spk= th_accuracy(
+            weights.view(-1, self.max_spk_num),
+            text_id,
+            ignore_label=self.ignore_id,
+        )
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
+
+    def _calc_ctc_loss(
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 9671fe9..2e1b0c4 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -38,6 +38,7 @@
             htk: bool = False,
             frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
             apply_stft: bool = True,
+            use_channel: int = None,
     ):
         assert check_argument_types()
         super().__init__()
@@ -77,6 +78,7 @@
         )
         self.n_mels = n_mels
         self.frontend_type = "default"
+        self.use_channel = use_channel
 
     def output_size(self) -> int:
         return self.n_mels
@@ -100,9 +102,12 @@
         if input_stft.dim() == 4:
             # h: (B, T, C, F) -> h: (B, T, F)
             if self.training:
-                # Select 1ch randomly
-                ch = np.random.randint(input_stft.size(2))
-                input_stft = input_stft[:, :, ch, :]
+                if self.use_channel == None:
+                    input_stft = input_stft[:, :, 0, :]
+                else:
+                    # Select 1ch randomly
+                    ch = np.random.randint(input_stft.size(2))
+                    input_stft = input_stft[:, :, ch, :]
             else:
                 # Use the first channel
                 input_stft = input_stft[:, :, 0, :]
diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py
index 8f85de9..39d94be 100644
--- a/funasr/models/pooling/statistic_pooling.py
+++ b/funasr/models/pooling/statistic_pooling.py
@@ -83,9 +83,9 @@
     num_chunk = int(math.ceil(tt / pooling_stride))
     pad = pooling_size // 2
     if len(xs_pad.shape) == 4:
-        features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
+        features = F.pad(xs_pad, (0, 0, pad, pad), "replicate")
     else:
-        features = F.pad(xs_pad, (pad, pad), "reflect")
+        features = F.pad(xs_pad, (pad, pad), "replicate")
     stat_list = []
 
     for i in range(num_chunk):
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 6202079..fcb3ed4 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -13,6 +13,9 @@
 from torch import nn
 from typing import Optional, Tuple
 
+import torch.nn.functional as F
+from funasr.modules.nets_utils import make_pad_mask
+
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
 
@@ -959,3 +962,37 @@
         q, k, v = self.forward_qkv(query, key, value)
         scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
         return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
+
+class CosineDistanceAttention(nn.Module):
+    """ Compute Cosine Distance between spk decoder output and speaker profile 
+    Args:
+        profile_path: speaker profile file path (.npy file)
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, spk_decoder_out, profile, profile_lens=None):
+        """
+        Args:
+            spk_decoder_out(torch.Tensor):(B, L, D)
+            spk_profiles(torch.Tensor):(B, N, D)
+        """
+        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
+        if profile_lens is not None:
+            
+            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+            )
+            weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
+        else:
+            x = x[:, -1:, :, :]
+            weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
+        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
+
+        return spk_embedding, weights
diff --git a/funasr/modules/beam_search/beam_search_sa_asr.py b/funasr/modules/beam_search/beam_search_sa_asr.py
new file mode 100755
index 0000000..b2b6833
--- /dev/null
+++ b/funasr/modules/beam_search/beam_search_sa_asr.py
@@ -0,0 +1,525 @@
+"""Beam search module."""
+
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.modules.e2e_asr_common import end_detect
+from funasr.modules.scorers.scorer_interface import PartialScorerInterface
+from funasr.modules.scorers.scorer_interface import ScorerInterface
+from funasr.models.decoder.abs_decoder import AbsDecoder
+
+
+class Hypothesis(NamedTuple):
+    """Hypothesis data type."""
+
+    yseq: torch.Tensor
+    spk_weigths : List
+    score: Union[float, torch.Tensor] = 0
+    scores: Dict[str, Union[float, torch.Tensor]] = dict()
+    states: Dict[str, Any] = dict()
+
+    def asdict(self) -> dict:
+        """Convert data to JSON-friendly dict."""
+        return self._replace(
+            yseq=self.yseq.tolist(),
+            score=float(self.score),
+            scores={k: float(v) for k, v in self.scores.items()},
+        )._asdict()
+
+
+class BeamSearch(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos: int,
+        eos: int,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM
+                The scorer will be ignored if it is `None`
+            weights (dict[str, float]): Dict of weights for each scorers
+                The scorer will be ignored if its weight is 0
+            beam_size (int): The number of hypotheses kept during search
+            vocab_size (int): The number of vocabulary
+            sos (int): Start of sequence id
+            eos (int): End of sequence id
+            token_list (list[str]): List of tokens for debug log
+            pre_beam_score_key (str): key of scores to perform pre-beam search
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`
+
+        """
+        super().__init__()
+        # set scorers
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+            pre_beam_score_key is not None
+            and pre_beam_score_key != "full"
+            and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+            self.pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+                spk_weigths=[],
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+        """
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+    def score_full(
+        self, hyp: Hypothesis, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            if isinstance(d, AbsDecoder):
+                scores[k], spk_weigths, states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc, spk_enc, profile)
+            else:
+                scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc)
+        return scores, spk_weigths, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            if isinstance(d, AbsDecoder):
+                scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc, spk_enc, profile)
+            else:
+                scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc)
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+            Its shape is `(self.n_vocab,)`.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+        self, running_hyps: List[Hypothesis], asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            List[Hypotheses]: Best sorted hypotheses
+
+        """
+        # import ipdb;ipdb.set_trace()
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=asr_enc.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=asr_enc.dtype, device=asr_enc.device)
+            scores, spk_weigths, states = self.score_full(hyp, asr_enc, spk_enc, profile)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, asr_enc, spk_enc, profile)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(
+                            hyp.scores, scores, j, part_scores, part_j
+                        ),
+                        states=self.merge_states(states, part_states, part_j),
+                        spk_weigths=hyp.spk_weigths+[spk_weigths],
+                    )
+                )
+
+            # sort and prune 2 x beam -> beam
+            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+                : min(len(best_hyps), self.beam_size)
+            ]
+        return best_hyps
+
+    def forward(
+        self, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses a end-detect function
+                to automatically find maximum hypothesis lengths
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        # import ipdb;ipdb.set_trace()
+        # set length bounds
+        if maxlenratio == 0:
+            maxlen = asr_enc.shape[0]
+        else:
+            maxlen = max(1, int(maxlenratio * asr_enc.size(0)))
+        minlen = int(minlenratio * asr_enc.size(0))
+        logging.info("decoder input length: " + str(asr_enc.shape[0]))
+        logging.info("max output length: " + str(maxlen))
+        logging.info("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        running_hyps = self.init_hyp(asr_enc)
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            best = self.search(running_hyps, asr_enc, spk_enc, profile)
+            #import pdb;pdb.set_trace()
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching to eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there is no N-best results, perform recognition "
+                "again with smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(asr_enc, spk_enc, profile, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (int): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos))
+                for h in running_hyps
+            ]
+
+        # add ended hypotheses to a final list, and removed them from current hypotheses
+        # (this will be a problem, number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps
+
+
+def beam_search(
+    x: torch.Tensor,
+    sos: int,
+    eos: int,
+    beam_size: int,
+    vocab_size: int,
+    scorers: Dict[str, ScorerInterface],
+    weights: Dict[str, float],
+    token_list: List[str] = None,
+    maxlenratio: float = 0.0,
+    minlenratio: float = 0.0,
+    pre_beam_ratio: float = 1.5,
+    pre_beam_score_key: str = "full",
+) -> list:
+    """Perform beam search with scorers.
+
+    Args:
+        x (torch.Tensor): Encoded speech feature (T, D)
+        sos (int): Start of sequence id
+        eos (int): End of sequence id
+        beam_size (int): The number of hypotheses kept during search
+        vocab_size (int): The number of vocabulary
+        scorers (dict[str, ScorerInterface]): Dict of decoder modules
+            e.g., Decoder, CTCPrefixScorer, LM
+            The scorer will be ignored if it is `None`
+        weights (dict[str, float]): Dict of weights for each scorers
+            The scorer will be ignored if its weight is 0
+        token_list (list[str]): List of tokens for debug log
+        maxlenratio (float): Input length ratio to obtain max output length.
+            If maxlenratio=0.0 (default), it uses a end-detect function
+            to automatically find maximum hypothesis lengths
+        minlenratio (float): Input length ratio to obtain min output length.
+        pre_beam_score_key (str): key of scores to perform pre-beam search
+        pre_beam_ratio (float): beam size in the pre-beam search
+            will be `int(pre_beam_ratio * beam_size)`
+
+    Returns:
+        list: N-best decoding results
+
+    """
+    ret = BeamSearch(
+        scorers,
+        weights,
+        beam_size=beam_size,
+        vocab_size=vocab_size,
+        pre_beam_ratio=pre_beam_ratio,
+        pre_beam_score_key=pre_beam_score_key,
+        sos=sos,
+        eos=eos,
+        token_list=token_list,
+    ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
+    return [h.asdict() for h in ret]
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 3d2004c..f8c1009 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -445,6 +445,12 @@
             help='Perform on "collect stats" mode',
         )
         group.add_argument(
+            "--mc",
+            type=bool,
+            default=False,
+            help="MultiChannel input",
+        )
+        group.add_argument(
             "--write_collected_feats",
             type=str2bool,
             default=False,
@@ -635,8 +641,8 @@
         group.add_argument(
             "--init_param",
             type=str,
+            action="append",
             default=[],
-            nargs="*",
             help="Specify the file path used for initialization of parameters. "
                  "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
                  "where file_path is the model file path, "
@@ -662,7 +668,7 @@
             "--freeze_param",
             type=str,
             default=[],
-            nargs="*",
+            action="append",
             help="Freeze parameters",
         )
 
@@ -1153,10 +1159,10 @@
         elif args.distributed and args.simple_ddp:
             distributed_option.init_torch_distributed_pai(args)
             args.ngpu = dist.get_world_size()
-            if args.dataset_type == "small":
+            if args.dataset_type == "small" and args.ngpu > 0:
                 if args.batch_size is not None:
                     args.batch_size = args.batch_size * args.ngpu
-                if args.batch_bins is not None:
+                if args.batch_bins is not None and args.ngpu > 0:
                     args.batch_bins = args.batch_bins * args.ngpu
 
         # filter samples if wav.scp and text are mismatch
@@ -1316,6 +1322,7 @@
                     data_path_and_name_and_type=args.train_data_path_and_name_and_type,
                     key_file=train_key_file,
                     batch_size=args.batch_size,
+                    mc=args.mc,
                     dtype=args.train_dtype,
                     num_workers=args.num_workers,
                     allow_variable_data_keys=args.allow_variable_data_keys,
@@ -1327,6 +1334,7 @@
                     data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
                     key_file=valid_key_file,
                     batch_size=args.valid_batch_size,
+                    mc=args.mc,
                     dtype=args.train_dtype,
                     num_workers=args.num_workers,
                     allow_variable_data_keys=args.allow_variable_data_keys,
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
new file mode 100644
index 0000000..738ec52
--- /dev/null
+++ b/funasr/tasks/sa_asr.py
@@ -0,0 +1,623 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+import yaml
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
+from funasr.models.decoder.transformer_decoder import (
+    DynamicConvolution2DTransformerDecoder,  # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
+from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
+from funasr.models.decoder.transformer_decoder import (
+    LightweightConvolution2DTransformerDecoder,  # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import (
+    LightweightConvolutionTransformerDecoder,  # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
+from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34,ResNet34Diar
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+    HuggingFaceTransformersPostEncoder,  # noqa: H301
+)
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+        multichannelfrontend=MultiChannelFrontend,
+    ),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(
+        specaug=SpecAug,
+        specaug_lfr=SpecAugLFR,
+    ),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        asr=ESPnetASRModel,
+        uniasr=UniASR,
+        paraformer=Paraformer,
+        paraformer_bert=ParaformerBert,
+        bicif_paraformer=BiCifParaformer,
+        contextual_paraformer=ContextualParaformer,
+        mfcca=MFCCA,
+        timestamp_prediction=TimestampPredictor,
+    ),
+    type_check=AbsESPnetModel,
+    default="asr",
+)
+preencoder_choices = ClassChoices(
+    name="preencoder",
+    classes=dict(
+        sinc=LightweightSincConvs,
+        linear=LinearProjection,
+    ),
+    type_check=AbsPreEncoder,
+    default=None,
+    optional=True,
+)
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
+
+encoder_choices2 = ClassChoices(
+    "encoder2",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+postencoder_choices = ClassChoices(
+    name="postencoder",
+    classes=dict(
+        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+    ),
+    type_check=AbsPostEncoder,
+    default=None,
+    optional=True,
+)
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        transformer=TransformerDecoder,
+        lightweight_conv=LightweightConvolutionTransformerDecoder,
+        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+        dynamic_conv=DynamicConvolutionTransformerDecoder,
+        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+        rnn=RNNDecoder,
+        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+        paraformer_decoder_sanm=ParaformerSANMDecoder,
+        paraformer_decoder_san=ParaformerDecoderSAN,
+        contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
+    ),
+    type_check=AbsDecoder,
+    default="sa_decoder",
+)
+decoder_choices2 = ClassChoices(
+    "decoder2",
+    classes=dict(
+        transformer=TransformerDecoder,
+        lightweight_conv=LightweightConvolutionTransformerDecoder,
+        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+        dynamic_conv=DynamicConvolutionTransformerDecoder,
+        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+        rnn=RNNDecoder,
+        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+        paraformer_decoder_sanm=ParaformerSANMDecoder,
+    ),
+    type_check=AbsDecoder,
+    default="rnn",
+)
+predictor_choices = ClassChoices(
+    name="predictor",
+    classes=dict(
+        cif_predictor=CifPredictor,
+        ctc_predictor=None,
+        cif_predictor_v2=CifPredictorV2,
+        cif_predictor_v3=CifPredictorV3,
+    ),
+    type_check=None,
+    default="cif_predictor",
+    optional=True,
+)
+predictor_choices2 = ClassChoices(
+    name="predictor2",
+    classes=dict(
+        cif_predictor=CifPredictor,
+        ctc_predictor=None,
+        cif_predictor_v2=CifPredictorV2,
+    ),
+    type_check=None,
+    default="cif_predictor",
+    optional=True,
+)
+stride_conv_choices = ClassChoices(
+    name="stride_conv",
+    classes=dict(
+        stride_conv1d=Conv1dSubsampling
+    ),
+    type_check=None,
+    default="stride_conv1d",
+    optional=True,
+)
+
+
+class ASRTask(AbsTask):
+    # If you need more than one optimizers, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --asr_encoder and --asr_encoder_conf
+        asr_encoder_choices,
+        # --spk_encoder and --spk_encoder_conf
+        spk_encoder_choices,
+        # --postencoder and --postencoder_conf
+        postencoder_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide --print_config mode. Instead of it, do as
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
+            "--max_spk_num",
+            type=int_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+        group.add_argument(
+            "--ctc_conf",
+            action=NestedDictAction,
+            default=get_default_kwargs(CTC),
+            help="The keyword arguments for CTC class.",
+        )
+        group.add_argument(
+            "--joint_net_conf",
+            action=NestedDictAction,
+            default=None,
+            help="The keyword arguments for joint network class.",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Apply preprocessing to data or not",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="bpe",
+            choices=["bpe", "char", "word", "phn"],
+            help="The text will be tokenized " "in the specified level token",
+        )
+        group.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The model file of sentencepiece",
+        )
+        parser.add_argument(
+            "--non_linguistic_symbols",
+            type=str_or_none,
+            default=None,
+            help="non_linguistic_symbols file path",
+        )
+        parser.add_argument(
+            "--cleaner",
+            type=str_or_none,
+            choices=[None, "tacotron", "jaconv", "vietnamese"],
+            default=None,
+            help="Apply text cleaning",
+        )
+        parser.add_argument(
+            "--g2p",
+            type=str_or_none,
+            choices=g2p_choices,
+            default=None,
+            help="Specify g2p method if --token_type=phn",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="THe probability for applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability applying Noise adding.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        assert check_argument_types()
+        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=args.bpemodel,
+                non_linguistic_symbols=args.non_linguistic_symbols,
+                text_cleaner=args.cleaner,
+                g2p_type=args.g2p,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "rir_scp")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        if not inference:
+            retval = ("speech", "text")
+        else:
+            # Recognition mode
+            retval = ("speech",)
+        return retval
+
+    @classmethod
+    def optional_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend':
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Pre-encoder input block
+        # NOTE(kan-bayashi): Use getattr to keep the compatibility
+        if getattr(args, "preencoder", None) is not None:
+            preencoder_class = preencoder_choices.get_class(args.preencoder)
+            preencoder = preencoder_class(**args.preencoder_conf)
+            input_size = preencoder.output_size()
+        else:
+            preencoder = None
+
+        # 5. Encoder
+        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+
+        # 6. Post-encoder block
+        # NOTE(kan-bayashi): Use getattr to keep the compatibility
+        asr_encoder_output_size = asr_encoder.output_size()
+        if getattr(args, "postencoder", None) is not None:
+            postencoder_class = postencoder_choices.get_class(args.postencoder)
+            postencoder = postencoder_class(
+                input_size=asr_encoder_output_size, **args.postencoder_conf
+            )
+            asr_encoder_output_size = postencoder.output_size()
+        else:
+            postencoder = None
+
+        # 7. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=asr_encoder_output_size,
+            **args.decoder_conf,
+        )
+
+        # 8. CTC
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=asr_encoder_output_size, **args.ctc_conf
+        )
+
+        max_spk_num=int(args.max_spk_num)
+
+        # import ipdb;ipdb.set_trace()
+        # 9. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("asr")
+        model = model_class(
+            vocab_size=vocab_size,
+            max_spk_num=max_spk_num,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            asr_encoder=asr_encoder,
+            spk_encoder=spk_encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index b607e1d..014a79f 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -106,18 +106,17 @@
         if num in abbr_begin:
             if time_stamp is not None:
                 begin = time_stamp[ts_nums[num]][0]
-            abbr_word = words[num].upper()
+            word_lists.append(words[num].upper())
             num += 1
             while num < words_size:
                 if num in abbr_end:
-                    abbr_word += words[num].upper()
+                    word_lists.append(words[num].upper())
                     last_num = num
                     break
                 else:
                     if words[num].encode('utf-8').isalpha():
-                        abbr_word += words[num].upper()
+                        word_lists.append(words[num].upper())
                 num += 1
-            word_lists.append(abbr_word)
             if time_stamp is not None:
                 end = time_stamp[ts_nums[num]][1]
                 ts_lists.append([begin, end])
diff --git a/setup.py b/setup.py
index e837637..ea55606 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
     "install": [
         "setuptools>=38.5.1",
         # "configargparse>=1.2.1",
-        "typeguard<=2.13.3",
+        "typeguard==2.13.3",
         "humanfriendly",
         "scipy>=1.4.1",
         # "filelock",
@@ -42,7 +42,10 @@
         "oss2",
         # "kaldi-native-fbank",
         # timestamp
-        "edit-distance"
+        "edit-distance",
+        # textgrid
+        "textgrid",
+        "protobuf==3.20.0",
     ],
     # train: The modules invoked when training only.
     "train": [

--
Gitblit v1.9.1