From a73123bcfc14370b74b17084bc124f00c48613e4 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: Sat, 06 May 2023 16:17:48 +0800
Subject: [PATCH] add speaker-attributed ASR task for alimeeting
---
egs/alimeeting/sa-asr/utils/validate_data_dir.sh | 404 ++
egs/alimeeting/sa-asr/asr_local.sh | 1562 +++++++++
egs/alimeeting/sa-asr/local/process_text_spk_merge.py | 55
egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py | 243 +
egs/alimeeting/sa-asr/utils/data/split_data.sh | 160
egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh | 143
setup.py | 7
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh | 129
egs/alimeeting/sa-asr/utils/combine_data.sh | 146
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh | 162
egs/alimeeting/sa-asr/path.sh | 6
egs/alimeeting/sa-asr/utils/validate_text.pl | 136
funasr/tasks/sa_asr.py | 623 +++
egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl | 38
funasr/losses/nll_loss.py | 47
egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl | 27
egs/alimeeting/sa-asr/utils/split_scp.pl | 246 +
egs/alimeeting/sa-asr/utils/parse_options.sh | 97
funasr/bin/sa_asr_train.py | 55
egs/alimeeting/sa-asr/local/process_text_id.py | 24
egs/alimeeting/sa-asr/local/gen_oracle_embedding.py | 70
egs/alimeeting/sa-asr/local/compute_wer.py | 157
funasr/models/pooling/statistic_pooling.py | 4
egs/alimeeting/sa-asr/local/proce_text.py | 32
egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py | 167 +
egs/alimeeting/sa-asr/utils/copy_data_dir.sh | 145
egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py | 86
funasr/tasks/abs_task.py | 16
egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh | 29
egs/alimeeting/sa-asr/utils/filter_scp.pl | 87
funasr/utils/postprocess_utils.py | 7
funasr/bin/asr_inference.py | 34
egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py | 158
funasr/fileio/sound_scp.py | 6
funasr/bin/asr_train.py | 15
funasr/models/e2e_sa_asr.py | 521 +++
egs/alimeeting/sa-asr/asr_local_infer.sh | 590 +++
funasr/bin/sa_asr_inference.py | 674 ++++
egs/alimeeting/sa-asr/run_m2met_2023_infer.sh | 50
egs/alimeeting/sa-asr/local/download_xvector_model.py | 6
egs/alimeeting/sa-asr/local/compute_cpcer.py | 91
egs/alimeeting/sa-asr/local/text_normalize.pl | 38
funasr/modules/attention.py | 37
egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py | 22
egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py | 235 +
egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py | 68
funasr/models/decoder/decoder_layer_sa_asr.py | 169 +
egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py | 127
egs/alimeeting/sa-asr/utils/fix_data_dir.sh | 215 +
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml | 88
egs/alimeeting/sa-asr/run_m2met_2023.sh | 51
egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh | 135
funasr/modules/beam_search/beam_search_sa_asr.py | 525 +++
egs/alimeeting/sa-asr/local/text_format.pl | 14
egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml | 29
egs/alimeeting/sa-asr/utils/apply_map.pl | 97
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml | 116
egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py | 59
egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh | 142
egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml | 6
funasr/models/decoder/transformer_decoder_sa_asr.py | 291 +
funasr/models/frontend/default.py | 11
egs/alimeeting/sa-asr/pyscripts/utils/print_args.py | 45
funasr/bin/asr_inference_launch.py | 5
egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh | 116
65 files changed, 9,859 insertions(+), 37 deletions(-)
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
new file mode 100755
index 0000000..c0359eb
--- /dev/null
+++ b/egs/alimeeting/sa-asr/asr_local.sh
@@ -0,0 +1,1562 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+min() {
+ local a b
+ a=$1
+ for b in "$@"; do
+ if [ "${b}" -le "${a}" ]; then
+ a="${b}"
+ fi
+ done
+ echo "${a}"
+}
+SECONDS=0
+
+# General configuration
+stage=1 # Processes starts from the specified stage.
+stop_stage=10000 # Processes is stopped at the specified stage.
+skip_data_prep=false # Skip data preparation stages.
+skip_train=false # Skip training stages.
+skip_eval=false # Skip decoding and evaluation stages.
+skip_upload=true # Skip packing and uploading stages.
+ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
+num_nodes=1 # The number of nodes.
+nj=16 # The number of parallel jobs.
+inference_nj=16 # The number of parallel jobs in decoding.
+gpu_inference=false # Whether to perform gpu decoding.
+njob_infer=4
+dumpdir=dump2 # Directory to dump features.
+expdir=exp # Directory to save experiments.
+python=python3 # Specify python to execute espnet commands.
+device=0
+
+# Data preparation related
+local_data_opts= # The options given to local/data.sh.
+
+# Speed perturbation related
+speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
+
+# Feature extraction related
+feats_type=raw # Feature type (raw or fbank_pitch).
+audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
+fs=16000 # Sampling rate.
+min_wav_duration=0.1 # Minimum duration in second.
+max_wav_duration=20 # Maximum duration in second.
+
+# Tokenization related
+token_type=bpe # Tokenization type (char or bpe).
+nbpe=30 # The number of BPE vocabulary.
+bpemode=unigram # Mode of BPE (unigram or bpe).
+oov="<unk>" # Out of vocabulary symbol.
+blank="<blank>" # CTC blank symbol
+sos_eos="<sos/eos>" # sos and eos symbol
+bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
+bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
+bpe_char_cover=1.0 # character coverage when modeling BPE
+
+# Language model related
+use_lm=true # Use language model for ASR decoding.
+lm_tag= # Suffix to the result dir for language model training.
+lm_exp=    # Specify the directory path for LM experiment.
+ # If this option is specified, lm_tag is ignored.
+lm_stats_dir= # Specify the directory path for LM statistics.
+lm_config= # Config for language model training.
+lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in lm config.
+use_word_lm=false # Whether to use word language model.
+num_splits_lm=1 # Number of splitting for lm corpus.
+# shellcheck disable=SC2034
+word_vocab_size=10000 # Size of word vocabulary.
+
+# ASR model related
+asr_tag= # Suffix to the result dir for asr model training.
+asr_exp=       # Specify the directory path for ASR experiment.
+ # If this option is specified, asr_tag is ignored.
+sa_asr_exp=
+asr_stats_dir= # Specify the directory path for ASR statistics.
+asr_config= # Config for asr model training.
+sa_asr_config=
+asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in asr config.
+feats_normalize=global_mvn # Normalization layer type.
+num_splits_asr=1 # Number of splitting for lm corpus.
+
+# Decoding related
+inference_tag= # Suffix to the result dir for decoding.
+inference_config= # Config for decoding.
+inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
+ # Note that it will overwrite args in inference config.
+sa_asr_inference_tag=
+sa_asr_inference_args=
+
+inference_lm=valid.loss.ave.pb        # Language model path for decoding.
+inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
+ # e.g.
+ # inference_asr_model=train.loss.best.pth
+ # inference_asr_model=3epoch.pth
+ # inference_asr_model=valid.acc.best.pth
+ # inference_asr_model=valid.loss.ave.pth
+inference_sa_asr_model=valid.acc_spk.ave.pb
+download_model= # Download a model from Model Zoo and use it for decoding.
+
+# [Task dependent] Set the datadir name created by local/data.sh
+train_set= # Name of training set.
+valid_set= # Name of validation set used for monitoring/tuning network training.
+test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
+bpe_train_text= # Text file path of bpe training set.
+lm_train_text= # Text file path of language model training set.
+lm_dev_text= # Text file path of language model development set.
+lm_test_text= # Text file path of language model evaluation set.
+nlsyms_txt=none # Non-linguistic symbol list if existing.
+cleaner=none # Text cleaner.
+g2p=none # g2p method (needed if token_type=phn).
+lang=zh # The language type of corpus.
+score_opts= # The options given to sclite scoring
+local_score_opts= # The options given to local/score.sh.
+
+
+help_message=$(cat << EOF
+Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
+
+Options:
+ # General configuration
+ --stage # Processes starts from the specified stage (default="${stage}").
+ --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}").
+ --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
+ --skip_train # Skip training stages (default="${skip_train}").
+ --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
+ --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
+ --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
+ --num_nodes # The number of nodes (default="${num_nodes}").
+ --nj # The number of parallel jobs (default="${nj}").
+ --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
+ --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
+ --dumpdir # Directory to dump features (default="${dumpdir}").
+ --expdir # Directory to save experiments (default="${expdir}").
+ --python # Specify python to execute espnet commands (default="${python}").
+    --device        # Which GPUs are used for local training (default="${device}").
+
+ # Data preparation related
+ --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
+
+ # Speed perturbation related
+ --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
+
+ # Feature extraction related
+ --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
+ --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
+ --fs # Sampling rate (default="${fs}").
+ --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
+ --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
+
+ # Tokenization related
+ --token_type # Tokenization type (char or bpe, default="${token_type}").
+ --nbpe # The number of BPE vocabulary (default="${nbpe}").
+ --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
+ --oov # Out of vocabulary symbol (default="${oov}").
+ --blank # CTC blank symbol (default="${blank}").
+    --sos_eos                 # sos and eos symbol (default="${sos_eos}").
+ --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
+ --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
+ --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
+
+ # Language model related
+ --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
+    --lm_exp          # Specify the directory path for LM experiment.
+                      # If this option is specified, lm_tag is ignored (default="${lm_exp}").
+    --lm_stats_dir    # Specify the directory path for LM statistics (default="${lm_stats_dir}").
+ --lm_config # Config for language model training (default="${lm_config}").
+ --lm_args # Arguments for language model training (default="${lm_args}").
+ # e.g., --lm_args "--max_epoch 10"
+ # Note that it will overwrite args in lm config.
+ --use_word_lm # Whether to use word language model (default="${use_word_lm}").
+ --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
+ --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
+
+ # ASR model related
+ --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
+    --asr_exp          # Specify the directory path for ASR experiment.
+                       # If this option is specified, asr_tag is ignored (default="${asr_exp}").
+    --asr_stats_dir    # Specify the directory path for ASR statistics (default="${asr_stats_dir}").
+ --asr_config # Config for asr model training (default="${asr_config}").
+ --asr_args # Arguments for asr model training (default="${asr_args}").
+ # e.g., --asr_args "--max_epoch 10"
+ # Note that it will overwrite args in asr config.
+    --feats_normalize  # Normalization layer type (default="${feats_normalize}").
+ --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").
+
+ # Decoding related
+ --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
+ --inference_config # Config for decoding (default="${inference_config}").
+ --inference_args # Arguments for decoding (default="${inference_args}").
+ # e.g., --inference_args "--lm_weight 0.1"
+ # Note that it will overwrite args in inference config.
+    --inference_lm        # Language model path for decoding (default="${inference_lm}").
+ --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
+ --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
+
+ # [Task dependent] Set the datadir name created by local/data.sh
+ --train_set # Name of training set (required).
+ --valid_set # Name of validation set used for monitoring/tuning network training (required).
+ --test_sets # Names of test sets.
+ # Multiple items (e.g., both dev and eval sets) can be specified (required).
+ --bpe_train_text # Text file path of bpe training set.
+ --lm_train_text # Text file path of language model training set.
+ --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
+ --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
+ --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
+ --cleaner # Text cleaner (default="${cleaner}").
+ --g2p # g2p method (default="${g2p}").
+ --lang # The language type of corpus (default=${lang}).
+    --score_opts             # The options given to sclite scoring (default="${score_opts}").
+    --local_score_opts       # The options given to local/score.sh (default="${local_score_opts}").
+EOF
+)
+
+log "$0 $*"
+# Save command line args for logging (they will be lost after utils/parse_options.sh)
+run_args=$(python -m funasr.utils.cli_utils $0 "$@")
+. utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+ log "${help_message}"
+ log "Error: No positional arguments are required."
+ exit 2
+fi
+
+. ./path.sh
+
+
+# Check required arguments
+[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
+[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
+[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
+
+# Check feature type
+if [ "${feats_type}" = raw ]; then
+ data_feats=${dumpdir}/raw
+elif [ "${feats_type}" = fbank_pitch ]; then
+ data_feats=${dumpdir}/fbank_pitch
+elif [ "${feats_type}" = fbank ]; then
+ data_feats=${dumpdir}/fbank
+elif [ "${feats_type}" == extracted ]; then
+ data_feats=${dumpdir}/extracted
+else
+ log "${help_message}"
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+fi
+
+# Use the same text as ASR for bpe training if not specified.
+[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
+# Use the text of the 1st evaldir if lm_test is not specified
+[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
+
+# Check tokenization type
+if [ "${lang}" != noinfo ]; then
+ token_listdir=data/${lang}_token_list
+else
+ token_listdir=data/token_list
+fi
+bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
+bpeprefix="${bpedir}"/bpe
+bpemodel="${bpeprefix}".model
+bpetoken_list="${bpedir}"/tokens.txt
+chartoken_list="${token_listdir}"/char/tokens.txt
+# NOTE: keep for future development.
+# shellcheck disable=SC2034
+wordtoken_list="${token_listdir}"/word/tokens.txt
+
+if [ "${token_type}" = bpe ]; then
+ token_list="${bpetoken_list}"
+elif [ "${token_type}" = char ]; then
+ token_list="${chartoken_list}"
+ bpemodel=none
+elif [ "${token_type}" = word ]; then
+ token_list="${wordtoken_list}"
+ bpemodel=none
+else
+ log "Error: not supported --token_type '${token_type}'"
+ exit 2
+fi
+if ${use_word_lm}; then
+ log "Error: Word LM is not supported yet"
+ exit 2
+
+ lm_token_list="${wordtoken_list}"
+ lm_token_type=word
+else
+ lm_token_list="${token_list}"
+ lm_token_type="${token_type}"
+fi
+
+
+# Set tag for naming of model directory
+if [ -z "${asr_tag}" ]; then
+ if [ -n "${asr_config}" ]; then
+ asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
+ else
+ asr_tag="train_${feats_type}"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ asr_tag+="_${lang}_${token_type}"
+ else
+ asr_tag+="_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${asr_args}" ]; then
+ asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_tag+="_sp"
+ fi
+fi
+if [ -z "${lm_tag}" ]; then
+ if [ -n "${lm_config}" ]; then
+ lm_tag="$(basename "${lm_config}" .yaml)"
+ else
+ lm_tag="train"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ lm_tag+="_${lang}_${lm_token_type}"
+ else
+ lm_tag+="_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${lm_args}" ]; then
+ lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+fi
+
+# The directory used for collect-stats mode
+if [ -z "${asr_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
+ else
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_stats_dir+="${nbpe}"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_stats_dir+="_sp"
+ fi
+fi
+if [ -z "${lm_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
+ else
+ lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_stats_dir+="${nbpe}"
+ fi
+fi
+# The directory used for training commands
+if [ -z "${asr_exp}" ]; then
+ asr_exp="${expdir}/asr_${asr_tag}"
+fi
+if [ -z "${lm_exp}" ]; then
+ lm_exp="${expdir}/lm_${lm_tag}"
+fi
+
+
+if [ -z "${inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ inference_tag=inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${inference_args}" ]; then
+ inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+if [ -z "${sa_asr_inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ sa_asr_inference_tag=sa_asr_inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${sa_asr_inference_args}" ]; then
+ sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+train_cmd="run.pl"
+cuda_cmd="run.pl"
+decode_cmd="run.pl"
+
+# ========================== Main stages start from here. ==========================
+
+if ! "${skip_data_prep}"; then
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ log "Stage 1: Data preparation for data/${train_set}, data/${valid_set}, etc."
+
+ ./local/alimeeting_data_prep.sh --tgt Test
+ ./local/alimeeting_data_prep.sh --tgt Eval
+ ./local/alimeeting_data_prep.sh --tgt Train
+ fi
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ if [ -n "${speed_perturb_factors}" ]; then
+ log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
+ for factor in ${speed_perturb_factors}; do
+ if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
+ scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
+ _dirs+="data/${train_set}_sp${factor} "
+ else
+ # If speed factor is 1, same as the original
+ _dirs+="data/${train_set} "
+ fi
+ done
+ utils/combine_data.sh "data/${train_set}_sp" ${_dirs}
+ else
+ log "Skip stage 2: Speed perturbation"
+ fi
+ fi
+
+ if [ -n "${speed_perturb_factors}" ]; then
+ train_set="${train_set}_sp"
+ fi
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ if [ "${feats_type}" = raw ]; then
+ log "Stage 3: Format wav.scp: data/ -> ${data_feats}"
+
+ # ====== Recreating "wav.scp" ======
+ # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
+ # shouldn't be used in training process.
+ # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
+ # and it can also change the audio-format and sampling rate.
+ # If nothing is need, then format_wav_scp.sh does nothing:
+ # i.e. the input file format and rate is same as the output.
+
+ for dset in "${train_set}" "${valid_set}" "${test_sets}" ; do
+ if [ "${dset}" = "${train_set}" ] || [ "${dset}" = "${valid_set}" ]; then
+ _suf="/org"
+ else
+ if [ "${dset}" = "${test_sets}" ] && [ "${test_sets}" = "Test_Ali_far" ]; then
+ _suf="/org"
+ else
+ _suf=""
+ fi
+ fi
+ utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+
+ cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+
+ rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
+ _opts=
+ if [ -e data/"${dset}"/segments ]; then
+ # "segments" is used for splitting wav files which are written in "wav".scp
+ # into utterances. The file format of segments:
+ # <segment_id> <record_id> <start_time> <end_time>
+ # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
+ # Where the time is written in seconds.
+ _opts+="--segments data/${dset}/segments "
+ fi
+ # shellcheck disable=SC2086
+ scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
+ "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
+
+ echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
+ done
+
+ else
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+ fi
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ log "Stage 4: Remove long/short data: ${data_feats}/org -> ${data_feats}"
+
+ # NOTE(kamo): Not applying to test_sets to keep original data
+ if [ "${test_sets}" = "Test_Ali_far" ]; then
+ rm_dset="${train_set} ${valid_set} ${test_sets}"
+ else
+ rm_dset="${train_set} ${valid_set}"
+ fi
+
+ for dset in $rm_dset; do
+
+ # Copy data dir
+ utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
+ cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
+
+ # Remove short utterances
+ _feats_type="$(<${data_feats}/${dset}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))")
+ _min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))")
+ _max_length=$(python3 -c "print(int(${max_wav_duration} * ${_fs}))")
+
+ # utt2num_samples is created by format_wav_scp.sh
+ <"${data_feats}/org/${dset}/utt2num_samples" \
+ awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
+ '{ if ($2 > min_length && $2 < max_length ) print $0; }' \
+ >"${data_feats}/${dset}/utt2num_samples"
+ <"${data_feats}/org/${dset}/wav.scp" \
+ utils/filter_scp.pl "${data_feats}/${dset}/utt2num_samples" \
+ >"${data_feats}/${dset}/wav.scp"
+ else
+ # Get frame shift in ms from conf/fbank.conf
+ _frame_shift=
+ if [ -f conf/fbank.conf ] && [ "$(<conf/fbank.conf grep -c frame-shift)" -gt 0 ]; then
+ # Assume using conf/fbank.conf for feature extraction
+ _frame_shift="$(<conf/fbank.conf grep frame-shift | sed -e 's/[-a-z =]*\([0-9]*\)/\1/g')"
+ fi
+ if [ -z "${_frame_shift}" ]; then
+ # If not existing, use the default number in Kaldi (=10ms).
+ # If you are using different number, you have to change the following value manually.
+ _frame_shift=10
+ fi
+
+ _min_length=$(python3 -c "print(int(${min_wav_duration} / ${_frame_shift} * 1000))")
+ _max_length=$(python3 -c "print(int(${max_wav_duration} / ${_frame_shift} * 1000))")
+
+ cp "${data_feats}/org/${dset}/feats_dim" "${data_feats}/${dset}/feats_dim"
+ <"${data_feats}/org/${dset}/feats_shape" awk -F, ' { print $1 } ' \
+ | awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
+ '{ if ($2 > min_length && $2 < max_length) print $0; }' \
+ >"${data_feats}/${dset}/feats_shape"
+ <"${data_feats}/org/${dset}/feats.scp" \
+ utils/filter_scp.pl "${data_feats}/${dset}/feats_shape" \
+ >"${data_feats}/${dset}/feats.scp"
+ fi
+
+ # Remove empty text
+ <"${data_feats}/org/${dset}/text" \
+ awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
+
+ # fix_data_dir.sh leaves only utts which exist in all files
+ utils/fix_data_dir.sh "${data_feats}/${dset}"
+
+ # generate uttid
+ cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
+ # filter utt2spk_all_fifo
+ python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+ done
+
+ # shellcheck disable=SC2002
+ cat ${lm_train_text} | awk ' { if( NF != 1 ) print $0; } ' > "${data_feats}/lm_train.txt"
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ log "Stage 5: Dictionary Preparation"
+ mkdir -p data/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ num_token=$(cat ${token_list} | wc -l)
+ echo "<unk>" >> ${token_list}
+ vocab_size=$(cat ${token_list} | wc -l)
+ fi
+
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ log "Stage 6: Generate speaker settings"
+ mkdir -p "profile_log"
+ for dset in "${train_set}" "${valid_set}" "${test_sets}"; do
+ # generate text_id spk2id
+ python local/process_sot_fifo_textchar2spk.py --path ${data_feats}/${dset}
+ log "Successfully generate ${data_feats}/${dset}/text_id ${data_feats}/${dset}/spk2id"
+ # generate text_id_train for sot
+ python local/process_text_id.py ${data_feats}/${dset}
+ log "Successfully generate ${data_feats}/${dset}/text_id_train"
+ # generate oracle_embedding from single-speaker audio segment
+ python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
+ log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
+ # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
+ if [ "${dset}" = "${train_set}" ]; then
+ python local/gen_oracle_profile_padding.py ${data_feats}/${dset}
+ log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_padding.scp)"
+ else
+ python local/gen_oracle_profile_nopadding.py ${data_feats}/${dset}
+ log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_nopadding.scp)"
+ fi
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
+ python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+ log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
+ fi
+
+ done
+ fi
+
+else
+ log "Skip the stages for data preparation"
+fi
+
+
+# ========================== Data preparation is done here. ==========================
+
+
+if ! "${skip_train}"; then
+ if "${use_lm}"; then
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ log "Stage 7: LM collect stats: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
+
+ _opts=
+ if [ -n "${lm_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.lm_train --print_config --optim adam
+ _opts+="--config ${lm_config} "
+ fi
+
+ # 1. Split the key file
+ _logdir="${lm_stats_dir}/logdir"
+ mkdir -p "${_logdir}"
+ # Get the minimum number among ${nj} and the number lines of input files
+ # Cap the parallel job count at the shorter of the train/dev text files
+ # so that no split produced below is empty.
+ _nj=$(min "${nj}" "$(<${data_feats}/lm_train.txt wc -l)" "$(<${lm_dev_text} wc -l)")
+
+ # Split the train keys into ${_nj} chunks, one per stats job.
+ key_file="${data_feats}/lm_train.txt"
+ split_scps=""
+ for n in $(seq ${_nj}); do
+ split_scps+=" ${_logdir}/train.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # Same split for the dev keys.
+ key_file="${lm_dev_text}"
+ split_scps=""
+ for n in $(seq ${_nj}); do
+ split_scps+=" ${_logdir}/dev.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Generate run.sh
+ log "Generate '${lm_stats_dir}/run.sh'. You can resume the process from stage 6 using this script"
+ mkdir -p "${lm_stats_dir}"; echo "${run_args} --stage 6 \"\$@\"; exit \$?" > "${lm_stats_dir}/run.sh"; chmod +x "${lm_stats_dir}/run.sh"
+
+ # 3. Submit jobs
+ log "LM collect-stats started... log: '${_logdir}/stats.*.log'"
+ # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
+ # but it's used only for deciding the sample ids.
+ # On failure, dump the first job's log so the error is visible in the console.
+ # shellcheck disable=SC2086
+ ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
+ ${python} -m funasr.bin.lm_train \
+ --collect_stats true \
+ --use_preprocessor true \
+ --bpemodel "${bpemodel}" \
+ --token_type "${lm_token_type}"\
+ --token_list "${lm_token_list}" \
+ --non_linguistic_symbols "${nlsyms_txt}" \
+ --cleaner "${cleaner}" \
+ --g2p "${g2p}" \
+ --train_data_path_and_name_and_type "${data_feats}/lm_train.txt,text,text" \
+ --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
+ --train_shape_file "${_logdir}/train.JOB.scp" \
+ --valid_shape_file "${_logdir}/dev.JOB.scp" \
+ --output_dir "${_logdir}/stats.JOB" \
+ ${_opts} ${lm_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
+
+ # 4. Aggregate shape files
+ _opts=
+ for i in $(seq "${_nj}"); do
+ _opts+="--input_dir ${_logdir}/stats.${i} "
+ done
+ # shellcheck disable=SC2086
+ ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${lm_stats_dir}"
+
+ # Append the num-tokens at the last dimensions. This is used for batch-bins count
+ <"${lm_stats_dir}/train/text_shape" \
+ awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
+ >"${lm_stats_dir}/train/text_shape.${lm_token_type}"
+
+ <"${lm_stats_dir}/valid/text_shape" \
+ awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
+ >"${lm_stats_dir}/valid/text_shape.${lm_token_type}"
+ fi
+
+
+ # Stage 8: train the language model with single-node multi-GPU DDP.
+ if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ log "Stage 8: LM Training: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
+
+ _opts=
+ if [ -n "${lm_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.lm_train --print_config --optim adam
+ _opts+="--config ${lm_config} "
+ fi
+
+ if [ "${num_splits_lm}" -gt 1 ]; then
+ # If you met a memory error when parsing text files, this option may help you.
+ # The corpus is split into subsets and each subset is used for training one by one in order,
+ # so the memory footprint can be limited to the memory required for each dataset.
+
+ _split_dir="${lm_stats_dir}/splits${num_splits_lm}"
+ if [ ! -f "${_split_dir}/.done" ]; then
+ rm -f "${_split_dir}/.done"
+ ${python} -m espnet2.bin.split_scps \
+ --scps "${data_feats}/lm_train.txt" "${lm_stats_dir}/train/text_shape.${lm_token_type}" \
+ --num_splits "${num_splits_lm}" \
+ --output_dir "${_split_dir}"
+ touch "${_split_dir}/.done"
+ else
+ log "${_split_dir}/.done exists. Spliting is skipped"
+ fi
+
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/lm_train.txt,text,text "
+ _opts+="--train_shape_file ${_split_dir}/text_shape.${lm_token_type} "
+ _opts+="--multiple_iterator true "
+
+ else
+ _opts+="--train_data_path_and_name_and_type ${data_feats}/lm_train.txt,text,text "
+ _opts+="--train_shape_file ${lm_stats_dir}/train/text_shape.${lm_token_type} "
+ fi
+
+ # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+
+ log "Generate '${lm_exp}/run.sh'. You can resume the process from stage 8 using this script"
+ mkdir -p "${lm_exp}"; echo "${run_args} --stage 8 \"\$@\"; exit \$?" > "${lm_exp}/run.sh"; chmod +x "${lm_exp}/run.sh"
+
+ log "LM training started... log: '${lm_exp}/train.log'"
+ if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+ # SGE can't include "/" in a job name
+ jobname="$(basename ${lm_exp})"
+ else
+ jobname="${lm_exp}/train.log"
+ fi
+
+ # File-based DDP rendezvous: a fresh init file is created per run and
+ # shared by all worker processes launched below.
+ mkdir -p ${lm_exp}
+ mkdir -p ${lm_exp}/log
+ INIT_FILE=${lm_exp}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ # Launch one background trainer per GPU; rank == GPU index.
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ # The i-th entry of the comma-separated ${device} list is this worker's GPU.
+ gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+ # NOTE(review): this passes ${token_type}/${token_list}, while the LM
+ # collect-stats stage used ${lm_token_type}/${lm_token_list} — confirm
+ # the two pairs are identical in this recipe, otherwise shapes mismatch.
+ lm_train.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --bpemodel ${bpemodel} \
+ --token_type ${token_type} \
+ --token_list ${token_list} \
+ --non_linguistic_symbols ${nlsyms_txt} \
+ --cleaner ${cleaner} \
+ --g2p ${g2p} \
+ --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
+ --valid_shape_file "${lm_stats_dir}/valid/text_shape.${lm_token_type}" \
+ --resume true \
+ --output_dir ${lm_exp} \
+ --config $lm_config \
+ --ngpu $ngpu \
+ --num_worker_count 1 \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $ngpu \
+ --dist_rank $rank \
+ --local_rank $local_rank \
+ ${_opts} 1> ${lm_exp}/log/train.log.$i 2>&1
+ } &
+ done
+ # Block until every per-GPU trainer has finished.
+ wait
+
+ fi
+
+
+ # Stage 9: evaluate the trained LM by computing perplexity on the test text.
+ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
+ log "Stage 9: Calc perplexity: ${lm_test_text}"
+ _opts=
+ # TODO(kamo): Parallelize?
+ log "Perplexity calculation started... log: '${lm_exp}/perplexity_test/lm_calc_perplexity.log'"
+ # The result file ${lm_exp}/perplexity_test/ppl is written by
+ # funasr.bin.lm_calc_perplexity and echoed below.
+ # shellcheck disable=SC2086
+ CUDA_VISIBLE_DEVICES=${device}\
+ ${cuda_cmd} --gpu "${ngpu}" "${lm_exp}"/perplexity_test/lm_calc_perplexity.log \
+ ${python} -m funasr.bin.lm_calc_perplexity \
+ --ngpu "${ngpu}" \
+ --data_path_and_name_and_type "${lm_test_text},text,text" \
+ --train_config "${lm_exp}"/config.yaml \
+ --model_file "${lm_exp}/${inference_lm}" \
+ --output_dir "${lm_exp}/perplexity_test" \
+ ${_opts}
+ log "PPL: ${lm_test_text}: $(cat ${lm_exp}/perplexity_test/ppl)"
+
+ fi
+
+ else
+ log "Stage 7-9: Skip lm-related stages: use_lm=${use_lm}"
+ fi
+
+
+ # Stage 10: run the ASR model in collect-stats mode to gather the
+ # speech/text shape files needed by the batch sampler in later stages.
+ if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
+ _asr_train_dir="${data_feats}/${train_set}"
+ _asr_valid_dir="${data_feats}/${valid_set}"
+ log "Stage 10: ASR collect stats: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+ _opts=
+ if [ -n "${asr_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam
+ _opts+="--config ${asr_config} "
+ fi
+
+ # Pick the input pipeline from the dumped feature type (raw audio vs kaldi feats).
+ _feats_type="$(<${_asr_train_dir}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ # "sound" supports "wav", "flac", etc.
+ _type=sound
+ fi
+ _opts+="--frontend_conf fs=${fs} "
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ _input_size="$(<${_asr_train_dir}/feats_dim)"
+ _opts+="--input_size=${_input_size} "
+ fi
+
+ # 1. Split the key file
+ _logdir="${asr_stats_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ # Get the minimum number among ${nj} and the number lines of input files
+ _nj=$(min "${nj}" "$(<${_asr_train_dir}/${_scp} wc -l)" "$(<${_asr_valid_dir}/${_scp} wc -l)")
+
+ key_file="${_asr_train_dir}/${_scp}"
+ split_scps=""
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/train.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ key_file="${_asr_valid_dir}/${_scp}"
+ split_scps=""
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/valid.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Generate run.sh
+ # BUGFIX: this block is stage 10, but the generated run.sh previously
+ # resumed from "--stage 9", which would rerun the wrong stage.
+ log "Generate '${asr_stats_dir}/run.sh'. You can resume the process from stage 10 using this script"
+ mkdir -p "${asr_stats_dir}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_stats_dir}/run.sh"; chmod +x "${asr_stats_dir}/run.sh"
+
+ # 3. Submit jobs
+ log "ASR collect-stats started... log: '${_logdir}/stats.*.log'"
+
+ # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
+ # but it's used only for deciding the sample ids.
+
+ # On failure, dump the first job's log so the error is visible in the console.
+ # shellcheck disable=SC2086
+ ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
+ ${python} -m funasr.bin.asr_train \
+ --collect_stats true \
+ --mc true \
+ --use_preprocessor true \
+ --bpemodel "${bpemodel}" \
+ --token_type "${token_type}" \
+ --token_list "${token_list}" \
+ --split_with_space false \
+ --non_linguistic_symbols "${nlsyms_txt}" \
+ --cleaner "${cleaner}" \
+ --g2p "${g2p}" \
+ --train_data_path_and_name_and_type "${_asr_train_dir}/${_scp},speech,${_type}" \
+ --train_data_path_and_name_and_type "${_asr_train_dir}/text,text,text" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
+ --train_shape_file "${_logdir}/train.JOB.scp" \
+ --valid_shape_file "${_logdir}/valid.JOB.scp" \
+ --output_dir "${_logdir}/stats.JOB" \
+ ${_opts} ${asr_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
+
+ # 4. Aggregate shape files
+ _opts=
+ for i in $(seq "${_nj}"); do
+ _opts+="--input_dir ${_logdir}/stats.${i} "
+ done
+ # shellcheck disable=SC2086
+ ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${asr_stats_dir}"
+
+ # Append the num-tokens at the last dimensions. This is used for batch-bins count
+ <"${asr_stats_dir}/train/text_shape" \
+ awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
+ >"${asr_stats_dir}/train/text_shape.${token_type}"
+
+ <"${asr_stats_dir}/valid/text_shape" \
+ awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
+ >"${asr_stats_dir}/valid/text_shape.${token_type}"
+ fi
+
+
+ # Stage 11: train the multi-channel ASR model with single-node multi-GPU DDP.
+ if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
+ _asr_train_dir="${data_feats}/${train_set}"
+ _asr_valid_dir="${data_feats}/${valid_set}"
+ log "Stage 11: ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+ _opts=
+ if [ -n "${asr_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam
+ _opts+="--config ${asr_config} "
+ fi
+
+ # Pick the input pipeline from the dumped feature type (raw audio vs kaldi feats).
+ _feats_type="$(<${_asr_train_dir}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ # "sound" supports "wav", "flac", etc.
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ _opts+="--frontend_conf fs=${fs} "
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ _input_size="$(<${_asr_train_dir}/feats_dim)"
+ _opts+="--input_size=${_input_size} "
+
+ fi
+ if [ "${feats_normalize}" = global_mvn ]; then
+ # Default normalization is utterance_mvn and changes to global_mvn
+ _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
+ fi
+
+ if [ "${num_splits_asr}" -gt 1 ]; then
+ # If you met a memory error when parsing text files, this option may help you.
+ # The corpus is split into subsets and each subset is used for training one by one in order,
+ # so the memory footprint can be limited to the memory required for each dataset.
+
+ _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
+ if [ ! -f "${_split_dir}/.done" ]; then
+ rm -f "${_split_dir}/.done"
+ ${python} -m espnet2.bin.split_scps \
+ --scps \
+ "${_asr_train_dir}/${_scp}" \
+ "${_asr_train_dir}/text" \
+ "${asr_stats_dir}/train/speech_shape" \
+ "${asr_stats_dir}/train/text_shape.${token_type}" \
+ --num_splits "${num_splits_asr}" \
+ --output_dir "${_split_dir}"
+ touch "${_split_dir}/.done"
+ else
+ log "${_split_dir}/.done exists. Spliting is skipped"
+ fi
+
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
+ _opts+="--train_shape_file ${_split_dir}/speech_shape "
+ _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
+ _opts+="--multiple_iterator true "
+
+ else
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
+ fi
+
+ # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
+ # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
+
+ # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+ log "ASR training started... log: '${asr_exp}/log/train.log'"
+ # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+ # # SGE can't include "/" in a job name
+ # jobname="$(basename ${asr_exp})"
+ # else
+ # jobname="${asr_exp}/train.log"
+ # fi
+
+ # File-based DDP rendezvous shared by all per-GPU worker processes.
+ mkdir -p ${asr_exp}
+ mkdir -p ${asr_exp}/log
+ INIT_FILE=${asr_exp}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ # Launch one background trainer per GPU; rank == GPU index.
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ # The i-th entry of the comma-separated ${device} list is this worker's GPU.
+ gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+ # NOTE(review): asr_train.py is assumed to be on PATH (funasr console
+ # script) — confirm the installation provides it.
+ asr_train.py \
+ --mc true \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --bpemodel ${bpemodel} \
+ --token_type ${token_type} \
+ --token_list ${token_list} \
+ --split_with_space false \
+ --non_linguistic_symbols ${nlsyms_txt} \
+ --cleaner ${cleaner} \
+ --g2p ${g2p} \
+ --valid_data_path_and_name_and_type ${_asr_valid_dir}/${_scp},speech,${_type} \
+ --valid_data_path_and_name_and_type ${_asr_valid_dir}/text,text,text \
+ --valid_shape_file ${asr_stats_dir}/valid/speech_shape \
+ --valid_shape_file ${asr_stats_dir}/valid/text_shape.${token_type} \
+ --resume true \
+ --output_dir ${asr_exp} \
+ --config $asr_config \
+ --ngpu $ngpu \
+ --num_worker_count 1 \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $ngpu \
+ --dist_rank $rank \
+ --local_rank $local_rank \
+ ${_opts} 1> ${asr_exp}/log/train.log.$i 2>&1
+ } &
+ done
+ # Block until every per-GPU trainer has finished.
+ wait
+
+ fi
+
+ # Stage 12: train the speaker-attributed ASR (SA-ASR) model, warm-started
+ # from the stage-11 ASR model and a pretrained x-vector speaker encoder.
+ if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
+ _asr_train_dir="${data_feats}/${train_set}"
+ _asr_valid_dir="${data_feats}/${valid_set}"
+ log "Stage 12: SA-ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+ _opts=
+ if [ -n "${sa_asr_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam
+ _opts+="--config ${sa_asr_config} "
+ fi
+
+ # Pick the input pipeline from the dumped feature type (raw audio vs kaldi feats).
+ _feats_type="$(<${_asr_train_dir}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ # "sound" supports "wav", "flac", etc.
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ _opts+="--frontend_conf fs=${fs} "
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ _input_size="$(<${_asr_train_dir}/feats_dim)"
+ _opts+="--input_size=${_input_size} "
+
+ fi
+ if [ "${feats_normalize}" = global_mvn ]; then
+ # Default normalization is utterance_mvn and changes to global_mvn
+ _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
+ fi
+
+ if [ "${num_splits_asr}" -gt 1 ]; then
+ # If you met a memory error when parsing text files, this option may help you.
+ # The corpus is split into subsets and each subset is used for training one by one in order,
+ # so the memory footprint can be limited to the memory required for each dataset.
+
+ _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
+ if [ ! -f "${_split_dir}/.done" ]; then
+ rm -f "${_split_dir}/.done"
+ ${python} -m espnet2.bin.split_scps \
+ --scps \
+ "${_asr_train_dir}/${_scp}" \
+ "${_asr_train_dir}/text" \
+ "${asr_stats_dir}/train/speech_shape" \
+ "${asr_stats_dir}/train/text_shape.${token_type}" \
+ --num_splits "${num_splits_asr}" \
+ --output_dir "${_split_dir}"
+ touch "${_split_dir}/.done"
+ else
+ log "${_split_dir}/.done exists. Spliting is skipped"
+ fi
+
+ # SA-ASR additionally consumes per-token speaker ids (text_id) and
+ # speaker profile embeddings (npy) alongside speech and text.
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/text_id_train,text_id,text_int "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/oracle_profile_padding.scp,profile,npy "
+ _opts+="--train_shape_file ${_split_dir}/speech_shape "
+ _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
+ _opts+="--multiple_iterator true "
+
+ else
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/oracle_profile_padding.scp,profile,npy "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text_id_train,text_id,text_int "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
+ fi
+
+ # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
+ # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
+
+ # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+ log "SA-ASR training started... log: '${sa_asr_exp}/log/train.log'"
+ # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+ # # SGE can't include "/" in a job name
+ # jobname="$(basename ${asr_exp})"
+ # else
+ # jobname="${asr_exp}/train.log"
+ # fi
+
+ mkdir -p ${sa_asr_exp}
+ mkdir -p ${sa_asr_exp}/log
+ INIT_FILE=${sa_asr_exp}/ddp_init
+
+ # Fetch the pretrained x-vector extractor once; later runs reuse it.
+ if [ ! -f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+ # download xvector extractor model file
+ python local/download_xvector_model.py exp
+ log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+ fi
+
+ # File-based DDP rendezvous shared by all per-GPU worker processes.
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ # Launch one background trainer per GPU; rank == GPU index.
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ # The i-th entry of the comma-separated ${device} list is this worker's GPU.
+ gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+ # The --init_param flags below map pretrained ASR modules
+ # (encoder/ctc/decoder pieces) into the SA-ASR model, and load the
+ # pretrained x-vector model as the speaker encoder.
+ sa_asr_train.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --unused_parameters true \
+ --bpemodel ${bpemodel} \
+ --token_type ${token_type} \
+ --token_list ${token_list} \
+ --max_spk_num 4 \
+ --split_with_space false \
+ --non_linguistic_symbols ${nlsyms_txt} \
+ --cleaner ${cleaner} \
+ --g2p ${g2p} \
+ --allow_variable_data_keys true \
+ --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/oracle_profile_nopadding.scp,profile,npy" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/text_id_train,text_id,text_int" \
+ --valid_shape_file "${asr_stats_dir}/valid/speech_shape" \
+ --valid_shape_file "${asr_stats_dir}/valid/text_shape.${token_type}" \
+ --resume true \
+ --output_dir ${sa_asr_exp} \
+ --config $sa_asr_config \
+ --ngpu $ngpu \
+ --num_worker_count 1 \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $ngpu \
+ --dist_rank $rank \
+ --local_rank $local_rank \
+ ${_opts} 1> ${sa_asr_exp}/log/train.log.$i 2>&1
+ } &
+ done
+ # Block until every per-GPU trainer has finished.
+ wait
+
+ fi
+
+else
+ log "Skip the training stages"
+fi
+
+
+if ! "${skip_eval}"; then
+ # Stage 13: decode the test sets with the multi-talker ASR model and
+ # concatenate the per-job outputs into one file per field.
+ if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
+ log "Stage 13: Decoding multi-talker ASR: training_dir=${asr_exp}"
+
+ # Choose the scheduler command and per-job GPU count; on GPU, run
+ # ngpu*njob_infer concurrent decoders.
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+ inference_nj=$[${ngpu}*${njob_infer}]
+ _ngpu=1
+
+ else
+ # CPU decoding keeps the user-configured ${inference_nj}.
+ _cmd="${decode_cmd}"
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+ log "Generate '${asr_exp}/${inference_tag}/run.sh'. You can resume the process from stage 13 using this script"
+ mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 13 \"\$@\"; exit \$?" > "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${asr_exp}/${inference_tag}/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ # Pick the input pipeline from the dumped feature type.
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ log "Number of inference jobs: ${_nj}"
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+ log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
+
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode asr \
+ ${_opts}
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+
+ # Stage 14: score the multi-talker ASR hypotheses against the references
+ # (CER via local/compute_wer.py after text normalization).
+ if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
+ log "Stage 14: Scoring multi-talker ASR"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${asr_exp}/${inference_tag}/${dset}"
+
+ # Normalize reference and hypothesis text before scoring.
+ python local/proce_text.py ${_data}/text ${_data}/text.proc
+ python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+ # The last 3 lines of text.cer hold the summary, echoed for convenience.
+ python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ done
+
+ fi
+
+ # Stage 15: decode the test sets with the SA-ASR model using ORACLE
+ # speaker profiles, then concatenate per-job outputs.
+ if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
+ log "Stage 15: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}"
+
+ # Choose the scheduler command and per-job GPU count; on GPU, run
+ # ngpu*njob_infer concurrent decoders.
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+ inference_nj=$[${ngpu}*${njob_infer}]
+ _ngpu=1
+
+ else
+ # CPU decoding keeps the user-configured ${inference_nj}.
+ _cmd="${decode_cmd}"
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+ log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh'. You can resume the process from stage 15 using this script"
+ mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.oracle"; echo "${run_args} --stage 15 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ # Pick the input pipeline from the dumped feature type.
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+ # BUGFIX: the message previously pointed at 'sa_asr_inference.*.log',
+ # but the jobs actually log to asr_inference.JOB.log (see ${_cmd} below).
+ log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
+ # shellcheck disable=SC2086
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --data_path_and_name_and_type "${_data}/oracle_profile_nopadding.scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --allow_variable_data_keys true \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text text_id; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+ # Stage 16: score SA-ASR (oracle profile): CER on normalized text, plus
+ # concatenated-permutation CER (cp-CER) on speaker-merged text.
+ if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
+ log "Stage 16: Scoring SA-ASR (oracle profile)"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
+
+ # Normalize reference and hypothesis text before scoring.
+ python local/proce_text.py ${_data}/text ${_data}/text.proc
+ python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+ # The last 3 lines of text.cer hold the summary, echoed for convenience.
+ python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ # Merge per-speaker text streams, then compute cp-CER over them.
+ python local/process_text_spk_merge.py ${_dir}
+ python local/process_text_spk_merge.py ${_data}
+
+ python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+ tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+ cat ${_dir}/text.cpcer.txt
+
+ done
+
+ fi
+
+ # Stage 17: decode the test sets with the SA-ASR model using CLUSTERED
+ # speaker profiles (no oracle speaker labels), then concatenate outputs.
+ if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then
+ log "Stage 17: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
+
+ # Choose the scheduler command and per-job GPU count; on GPU, run
+ # ngpu*njob_infer concurrent decoders.
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+ inference_nj=$[${ngpu}*${njob_infer}]
+ _ngpu=1
+
+ else
+ # CPU decoding keeps the user-configured ${inference_nj}.
+ _cmd="${decode_cmd}"
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+ log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
+ mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ # Pick the input pipeline from the dumped feature type.
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+ # BUGFIX: the message previously pointed at 'sa_asr_inference.*.log',
+ # but the jobs actually log to asr_inference.JOB.log (see ${_cmd} below).
+ log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
+ # shellcheck disable=SC2086
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --allow_variable_data_keys true \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text text_id; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+ # Stage 18: score SA-ASR (cluster profile): CER on normalized text, plus
+ # concatenated-permutation CER (cp-CER) on speaker-merged text.
+ if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then
+ log "Stage 18: Scoring SA-ASR (cluster profile)"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+
+ # Normalize reference and hypothesis text before scoring.
+ python local/proce_text.py ${_data}/text ${_data}/text.proc
+ python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+ # The last 3 lines of text.cer hold the summary, echoed for convenience.
+ python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ # Merge per-speaker text streams, then compute cp-CER over them.
+ python local/process_text_spk_merge.py ${_dir}
+ python local/process_text_spk_merge.py ${_data}
+
+ python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+ tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+ cat ${_dir}/text.cpcer.txt
+
+ done
+
+ fi
+
+else
+ log "Skip the evaluation stages"
+fi
+
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_infer.sh
new file mode 100755
index 0000000..8e8148f
--- /dev/null
+++ b/egs/alimeeting/sa-asr/asr_local_infer.sh
@@ -0,0 +1,590 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+min() {
+ local a b
+ a=$1
+ for b in "$@"; do
+ if [ "${b}" -le "${a}" ]; then
+ a="${b}"
+ fi
+ done
+ echo "${a}"
+}
+SECONDS=0
+
+# General configuration
+stage=1 # Processes starts from the specified stage.
+stop_stage=10000 # Processes is stopped at the specified stage.
+skip_data_prep=false # Skip data preparation stages.
+skip_train=false # Skip training stages.
+skip_eval=false # Skip decoding and evaluation stages.
+skip_upload=true # Skip packing and uploading stages.
+ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
+num_nodes=1 # The number of nodes.
+nj=16 # The number of parallel jobs.
+inference_nj=16 # The number of parallel jobs in decoding.
+gpu_inference=false # Whether to perform gpu decoding.
+njob_infer=4
+dumpdir=dump2 # Directory to dump features.
+expdir=exp # Directory to save experiments.
+python=python3 # Specify python to execute espnet commands.
+device=0
+
+# Data preparation related
+local_data_opts= # The options given to local/data.sh.
+
+# Speed perturbation related
+speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
+
+# Feature extraction related
+feats_type=raw # Feature type (raw or fbank_pitch).
+audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
+fs=16000 # Sampling rate.
+min_wav_duration=0.1 # Minimum duration in second.
+max_wav_duration=20 # Maximum duration in second.
+
+# Tokenization related
+token_type=bpe # Tokenization type (char or bpe).
+nbpe=30 # The number of BPE vocabulary.
+bpemode=unigram # Mode of BPE (unigram or bpe).
+oov="<unk>" # Out of vocabulary symbol.
+blank="<blank>" # CTC blank symbol
+sos_eos="<sos/eos>" # sos and eos symbol
+bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
+bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
+bpe_char_cover=1.0 # character coverage when modeling BPE
+
+# Language model related
+use_lm=true # Use language model for ASR decoding.
+lm_tag= # Suffix to the result dir for language model training.
+lm_exp=             # Specify the directory path for LM experiment.
+                    # If this option is specified, lm_tag is ignored.
+lm_stats_dir=       # Specify the directory path for LM statistics.
+lm_config= # Config for language model training.
+lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in lm config.
+use_word_lm=false # Whether to use word language model.
+num_splits_lm=1 # Number of splitting for lm corpus.
+# shellcheck disable=SC2034
+word_vocab_size=10000 # Size of word vocabulary.
+
+# ASR model related
+asr_tag= # Suffix to the result dir for asr model training.
+asr_exp=       # Specify the directory path for ASR experiment.
+               # If this option is specified, asr_tag is ignored.
+sa_asr_exp=
+asr_stats_dir= # Specify the directory path for ASR statistics.
+asr_config= # Config for asr model training.
+sa_asr_config=
+asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in asr config.
+feats_normalize=global_mvn # Normalization layer type.
+num_splits_asr=1 # Number of splitting for lm corpus.
+
+# Decoding related
+inference_tag= # Suffix to the result dir for decoding.
+inference_config= # Config for decoding.
+inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
+ # Note that it will overwrite args in inference config.
+sa_asr_inference_tag=
+sa_asr_inference_args=
+
+inference_lm=valid.loss.ave.pb # Language model path for decoding.
+inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
+ # e.g.
+ # inference_asr_model=train.loss.best.pth
+ # inference_asr_model=3epoch.pth
+ # inference_asr_model=valid.acc.best.pth
+ # inference_asr_model=valid.loss.ave.pth
+inference_sa_asr_model=valid.acc_spk.ave.pb
+download_model= # Download a model from Model Zoo and use it for decoding.
+
+# [Task dependent] Set the datadir name created by local/data.sh
+train_set= # Name of training set.
+valid_set= # Name of validation set used for monitoring/tuning network training.
+test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
+bpe_train_text= # Text file path of bpe training set.
+lm_train_text= # Text file path of language model training set.
+lm_dev_text= # Text file path of language model development set.
+lm_test_text= # Text file path of language model evaluation set.
+nlsyms_txt=none # Non-linguistic symbol list if existing.
+cleaner=none # Text cleaner.
+g2p=none # g2p method (needed if token_type=phn).
+lang=zh # The language type of corpus.
+score_opts= # The options given to sclite scoring
+local_score_opts= # The options given to local/score.sh.
+
+help_message=$(cat << EOF
+Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
+
+Options:
+ # General configuration
+ --stage # Processes starts from the specified stage (default="${stage}").
+ --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}").
+ --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
+ --skip_train # Skip training stages (default="${skip_train}").
+ --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
+ --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
+ --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
+ --num_nodes # The number of nodes (default="${num_nodes}").
+ --nj # The number of parallel jobs (default="${nj}").
+ --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
+ --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
+ --dumpdir # Directory to dump features (default="${dumpdir}").
+ --expdir # Directory to save experiments (default="${expdir}").
+ --python # Specify python to execute espnet commands (default="${python}").
+    --device        # Which GPUs are use for local training (default="${device}").
+
+ # Data preparation related
+ --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
+
+ # Speed perturbation related
+ --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
+
+ # Feature extraction related
+ --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
+ --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
+ --fs # Sampling rate (default="${fs}").
+ --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
+ --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
+
+ # Tokenization related
+ --token_type # Tokenization type (char or bpe, default="${token_type}").
+ --nbpe # The number of BPE vocabulary (default="${nbpe}").
+ --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
+ --oov # Out of vocabulary symbol (default="${oov}").
+ --blank # CTC blank symbol (default="${blank}").
+    --sos_eos                 # sos and eos symbol (default="${sos_eos}").
+ --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
+ --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
+ --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
+
+ # Language model related
+ --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
+ --lm_exp # Specify the direcotry path for LM experiment.
+ # If this option is specified, lm_tag is ignored (default="${lm_exp}").
+ --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
+ --lm_config # Config for language model training (default="${lm_config}").
+ --lm_args # Arguments for language model training (default="${lm_args}").
+ # e.g., --lm_args "--max_epoch 10"
+ # Note that it will overwrite args in lm config.
+ --use_word_lm # Whether to use word language model (default="${use_word_lm}").
+ --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
+ --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
+
+ # ASR model related
+ --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
+ --asr_exp # Specify the direcotry path for ASR experiment.
+ # If this option is specified, asr_tag is ignored (default="${asr_exp}").
+ --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
+ --asr_config # Config for asr model training (default="${asr_config}").
+ --asr_args # Arguments for asr model training (default="${asr_args}").
+ # e.g., --asr_args "--max_epoch 10"
+ # Note that it will overwrite args in asr config.
+    --feats_normalize # Normalization layer type (default="${feats_normalize}").
+ --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").
+
+ # Decoding related
+ --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
+ --inference_config # Config for decoding (default="${inference_config}").
+ --inference_args # Arguments for decoding (default="${inference_args}").
+ # e.g., --inference_args "--lm_weight 0.1"
+ # Note that it will overwrite args in inference config.
+    --inference_lm    # Language model path for decoding (default="${inference_lm}").
+ --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
+ --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
+
+ # [Task dependent] Set the datadir name created by local/data.sh
+ --train_set # Name of training set (required).
+ --valid_set # Name of validation set used for monitoring/tuning network training (required).
+ --test_sets # Names of test sets.
+ # Multiple items (e.g., both dev and eval sets) can be specified (required).
+ --bpe_train_text # Text file path of bpe training set.
+ --lm_train_text # Text file path of language model training set.
+ --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
+ --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
+ --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
+ --cleaner # Text cleaner (default="${cleaner}").
+ --g2p # g2p method (default="${g2p}").
+ --lang # The language type of corpus (default=${lang}).
+    --score_opts             # The options given to sclite scoring (default="${score_opts}").
+    --local_score_opts       # The options given to local/score.sh (default="${local_score_opts}").
+EOF
+)
+
+log "$0 $*"
+# Save command line args for logging (they will be lost after utils/parse_options.sh)
+run_args=$(python -m funasr.utils.cli_utils $0 "$@")
+. utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+ log "${help_message}"
+ log "Error: No positional arguments are required."
+ exit 2
+fi
+
+. ./path.sh
+
+
+# Check required arguments
+[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
+[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
+[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
+
+# Check feature type
+if [ "${feats_type}" = raw ]; then
+ data_feats=${dumpdir}/raw
+elif [ "${feats_type}" = fbank_pitch ]; then
+ data_feats=${dumpdir}/fbank_pitch
+elif [ "${feats_type}" = fbank ]; then
+ data_feats=${dumpdir}/fbank
+elif [ "${feats_type}" == extracted ]; then
+ data_feats=${dumpdir}/extracted
+else
+ log "${help_message}"
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+fi
+
+# Use the same text as ASR for bpe training if not specified.
+[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
+# Use the text of the 1st evaldir if lm_test is not specified
+[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
+
+# Check tokenization type
+if [ "${lang}" != noinfo ]; then
+ token_listdir=data/${lang}_token_list
+else
+ token_listdir=data/token_list
+fi
+bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
+bpeprefix="${bpedir}"/bpe
+bpemodel="${bpeprefix}".model
+bpetoken_list="${bpedir}"/tokens.txt
+chartoken_list="${token_listdir}"/char/tokens.txt
+# NOTE: keep for future development.
+# shellcheck disable=SC2034
+wordtoken_list="${token_listdir}"/word/tokens.txt
+
+if [ "${token_type}" = bpe ]; then
+ token_list="${bpetoken_list}"
+elif [ "${token_type}" = char ]; then
+ token_list="${chartoken_list}"
+ bpemodel=none
+elif [ "${token_type}" = word ]; then
+ token_list="${wordtoken_list}"
+ bpemodel=none
+else
+ log "Error: not supported --token_type '${token_type}'"
+ exit 2
+fi
+if ${use_word_lm}; then
+ log "Error: Word LM is not supported yet"
+ exit 2
+
+ lm_token_list="${wordtoken_list}"
+ lm_token_type=word
+else
+ lm_token_list="${token_list}"
+ lm_token_type="${token_type}"
+fi
+
+
+# Set tag for naming of model directory
+if [ -z "${asr_tag}" ]; then
+ if [ -n "${asr_config}" ]; then
+ asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
+ else
+ asr_tag="train_${feats_type}"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ asr_tag+="_${lang}_${token_type}"
+ else
+ asr_tag+="_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${asr_args}" ]; then
+ asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_tag+="_sp"
+ fi
+fi
+if [ -z "${lm_tag}" ]; then
+ if [ -n "${lm_config}" ]; then
+ lm_tag="$(basename "${lm_config}" .yaml)"
+ else
+ lm_tag="train"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ lm_tag+="_${lang}_${lm_token_type}"
+ else
+ lm_tag+="_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${lm_args}" ]; then
+ lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+fi
+
+# The directory used for collect-stats mode
+if [ -z "${asr_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
+ else
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_stats_dir+="${nbpe}"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_stats_dir+="_sp"
+ fi
+fi
+if [ -z "${lm_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
+ else
+ lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_stats_dir+="${nbpe}"
+ fi
+fi
+# The directory used for training commands
+if [ -z "${asr_exp}" ]; then
+ asr_exp="${expdir}/asr_${asr_tag}"
+fi
+if [ -z "${lm_exp}" ]; then
+ lm_exp="${expdir}/lm_${lm_tag}"
+fi
+
+
+if [ -z "${inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ inference_tag=inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${inference_args}" ]; then
+ inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+if [ -z "${sa_asr_inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ sa_asr_inference_tag=sa_asr_inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${sa_asr_inference_args}" ]; then
+ sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+train_cmd="run.pl"
+cuda_cmd="run.pl"
+decode_cmd="run.pl"
+
+# ========================== Main stages start from here. ==========================
+
+if ! "${skip_data_prep}"; then
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ if [ "${feats_type}" = raw ]; then
+ log "Stage 1: Format wav.scp: data/ -> ${data_feats}"
+
+ # ====== Recreating "wav.scp" ======
+ # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
+ # shouldn't be used in training process.
+ # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
+ # and it can also change the audio-format and sampling rate.
+ # If nothing is need, then format_wav_scp.sh does nothing:
+ # i.e. the input file format and rate is same as the output.
+
+        for dset in ${test_sets}; do
+
+ _suf=""
+
+ utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+
+ rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
+ _opts=
+ if [ -e data/"${dset}"/segments ]; then
+ # "segments" is used for splitting wav files which are written in "wav".scp
+ # into utterances. The file format of segments:
+ # <segment_id> <record_id> <start_time> <end_time>
+ # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
+ # Where the time is written in seconds.
+ _opts+="--segments data/${dset}/segments "
+ fi
+ # shellcheck disable=SC2086
+ scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
+ "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
+
+ echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
+ done
+
+ else
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+ fi
+ fi
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ log "Stage 2: Generate speaker profile by spectral-cluster"
+ mkdir -p "profile_log"
+    for dset in ${test_sets}; do
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+ log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
+ done
+ fi
+
+else
+ log "Skip the stages for data preparation"
+fi
+
+
+# ========================== Data preparation is done here. ==========================
+
+if ! "${skip_eval}"; then
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
+
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+            inference_nj=$((ngpu * njob_infer))
+ _ngpu=1
+
+ else
+ _cmd="${decode_cmd}"
+ inference_nj=$njob_infer
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+        log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 3 using this script"
+        mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 3 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+            log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
+ # shellcheck disable=SC2086
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --allow_variable_data_keys true \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text text_id; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ log "Stage 4: Generate SA-ASR results (cluster profile)"
+
+ for dset in ${test_sets}; do
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+
+ python local/process_text_spk_merge.py ${_dir}
+ done
+
+ fi
+
+else
+ log "Skip the evaluation stages"
+fi
+
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
new file mode 100644
index 0000000..88fdbc2
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
@@ -0,0 +1,6 @@
+beam_size: 20
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.6
+lm_weight: 0.3
diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
new file mode 100644
index 0000000..a8c9968
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
@@ -0,0 +1,88 @@
+# network architecture
+frontend: default
+frontend_conf:
+ n_fft: 400
+ win_length: 400
+ hop_length: 160
+ use_channel: 0
+
+# encoder related
+encoder: conformer
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ rel_pos_type: latest
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+# decoder related
+decoder: transformer
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# ctc related
+ctc_conf:
+ ignore_nan_grad: true
+
+# hybrid CTC/attention
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+# minibatch related
+batch_type: numel
+batch_bins: 10000000 # reduce/increase this number according to your GPU memory
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+ - valid
+ - acc
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
new file mode 100644
index 0000000..68520ae
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
@@ -0,0 +1,29 @@
+lm: transformer
+lm_conf:
+ pos_enc: null
+ embed_unit: 128
+ att_unit: 512
+ head: 8
+ unit: 2048
+ layer: 16
+ dropout_rate: 0.1
+
+# optimization related
+grad_clip: 5.0
+batch_type: numel
+batch_bins: 500000 # 4gpus * 500000
+accum_grad: 1
+max_epoch: 15 # 15 epochs are enough
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+best_model_criterion:
+- - valid
+ - loss
+ - min
+keep_nbest_models: 10 # 10 is good.
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
new file mode 100644
index 0000000..e91db18
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
@@ -0,0 +1,116 @@
+# network architecture
+frontend: default
+frontend_conf:
+ n_fft: 400
+ win_length: 400
+ hop_length: 160
+ use_channel: 0
+
+# encoder related
+asr_encoder: conformer
+asr_encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+spk_encoder: resnet34_diar
+spk_encoder_conf:
+ use_head_conv: true
+ batchnorm_momentum: 0.5
+ use_head_maxpool: false
+ num_nodes_pooling_layer: 256
+ layers_in_block:
+ - 3
+ - 4
+ - 6
+ - 3
+ filters_in_block:
+ - 32
+ - 64
+ - 128
+ - 256
+ pooling_type: statistic
+ num_nodes_resnet1: 256
+ num_nodes_last_layer: 256
+    # batchnorm_momentum: 0.5  # duplicate mapping key; already set to 0.5 above
+
+# decoder related
+decoder: sa_decoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ asr_num_blocks: 6
+ spk_num_blocks: 3
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+ spk_weight: 0.5
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+ctc_conf:
+ ignore_nan_grad: true
+
+# minibatch related
+batch_type: numel
+batch_bins: 10000000
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 60
+val_scheduler_criterion:
+ - valid
+ - loss
+best_model_criterion:
+- - valid
+ - acc
+ - max
+- - valid
+ - acc_spk
+ - max
+- - valid
+ - loss
+ - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 8000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
new file mode 100755
index 0000000..8151bae
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
@@ -0,0 +1,162 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+help_message=$(cat << EOF
+Usage: $0
+
+Options:
+ --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
+ --tgt (string): Which set to process, test or train.
+EOF
+)
+
+SECONDS=0
+tgt=Train #Train or Eval
+
+
+log "$0 $*"
+echo $tgt
+. ./utils/parse_options.sh
+
+. ./path.sh
+
+AliMeeting="${PWD}/dataset"
+
+if [ $# -gt 2 ]; then
+ log "${help_message}"
+ exit 2
+fi
+
+
+if [ ! -d "${AliMeeting}" ]; then
+ log "Error: ${AliMeeting} is empty."
+ exit 2
+fi
+
+# To absolute path
+AliMeeting=$(cd ${AliMeeting}; pwd)
+echo $AliMeeting
+far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
+near_raw_dir=${AliMeeting}/${tgt}_Ali_near/
+
+far_dir=data/local/${tgt}_Ali_far
+near_dir=data/local/${tgt}_Ali_near
+far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
+mkdir -p $far_single_speaker_dir
+
+stage=1
+stop_stage=4
+mkdir -p $far_dir
+mkdir -p $near_dir
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ log "stage 1:process alimeeting near dir"
+
+ find -L $near_raw_dir/audio_dir -iname "*.wav" > $near_dir/wavlist
+ awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
+ find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
+ n1_wav=$(wc -l < $near_dir/wavlist)
+ n2_text=$(wc -l < $near_dir/textgrid.flist)
+ log near file found $n1_wav wav and $n2_text text.
+
+ paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
+
+ # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
+ cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
+
+ python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
+ cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
+ utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
+ #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
+ #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
+ utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
+ utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
+ sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
+ sed -e 's/锛�//g' $near_dir/tmp1> $near_dir/tmp2
+ sed -e 's/锛�//g' $near_dir/tmp2> $near_dir/text
+
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ log "stage 2:process alimeeting far dir"
+
+ find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist
+ awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid
+ find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist
+ n1_wav=$(wc -l < $far_dir/wavlist)
+ n2_text=$(wc -l < $far_dir/textgrid.flist)
+ log far file found $n1_wav wav and $n2_text text.
+
+ paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
+
+ cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
+
+ python local/alimeeting_process_overlap_force.py --path $far_dir \
+ --no-overlap false --mars True \
+ --overlap_length 0.8 --max_length 7
+
+ cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
+ #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
+
+ utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
+ sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
+ sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
+ sed -e 's/锛�//g' $far_dir/tmp2> $far_dir/tmp3
+ sed -e 's/锛�//g' $far_dir/tmp3> $far_dir/text
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    log "stage 3: final data process"
+
+ utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
+ utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+
+ sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
+ sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+
+ # remove space in text
+ for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
+ cp -r $far_dir/* $far_single_speaker_dir
+ mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
+ paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
+ python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
+
+ cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
+ utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+
+ ./utils/fix_data_dir.sh $far_single_speaker_dir
+ utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+
+ # remove space in text
+ for x in ${tgt}_Ali_far_single_speaker; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
new file mode 100755
index 0000000..382a056
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
@@ -0,0 +1,129 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+help_message=$(cat << EOF
+Usage: $0
+
+Options:
+ --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
+ --tgt (string): Which set to process, test or train.
+EOF
+)
+
+SECONDS=0
+tgt=Train #Train or Eval
+
+
+log "$0 $*"
+echo $tgt
+. ./utils/parse_options.sh
+
+. ./path.sh
+
+AliMeeting="${PWD}/dataset"
+
+if [ $# -gt 2 ]; then
+ log "${help_message}"
+ exit 2
+fi
+
+
+if [ ! -d "${AliMeeting}" ]; then
+ log "Error: ${AliMeeting} is empty."
+ exit 2
+fi
+
+# To absolute path
+AliMeeting=$(cd ${AliMeeting}; pwd)
+echo $AliMeeting
+far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
+
+far_dir=data/local/${tgt}_Ali_far
+far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
+mkdir -p $far_single_speaker_dir
+
+stage=1
+stop_stage=3
+mkdir -p $far_dir
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    log "stage 1: process alimeeting far dir"
+
+ find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist
+ awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid
+ find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist
+ n1_wav=$(wc -l < $far_dir/wavlist)
+ n2_text=$(wc -l < $far_dir/textgrid.flist)
+ log far file found $n1_wav wav and $n2_text text.
+
+ paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
+
+ cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
+
+ python local/alimeeting_process_overlap_force.py --path $far_dir \
+ --no-overlap false --mars True \
+ --overlap_length 0.8 --max_length 7
+
+ cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
+ #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
+
+ utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
+ sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
+ sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
+ sed -e 's/锛�//g' $far_dir/tmp2> $far_dir/tmp3
+ sed -e 's/锛�//g' $far_dir/tmp3> $far_dir/text
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    log "stage 2: final data process"
+
+ utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+
+ sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
+ sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+
+ # remove space in text
+ for x in ${tgt}_Ali_far; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    log "stage 3: process alimeeting far dir (single speaker by oracle timestamp)"
+ cp -r $far_dir/* $far_single_speaker_dir
+ mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
+ paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
+ python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
+
+ cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
+ utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+
+ ./utils/fix_data_dir.sh $far_single_speaker_dir
+ utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+
+ # remove space in text
+ for x in ${tgt}_Ali_far_single_speaker; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
new file mode 100755
index 0000000..8ece757
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+class Segment(object):
+ def __init__(self, uttid, spkr, stime, etime, text):
+ self.uttid = uttid
+ self.spkr = spkr
+ self.spkr_all = uttid+"-"+spkr
+ self.stime = round(stime, 2)
+ self.etime = round(etime, 2)
+ self.text = text
+ self.spk_text = {uttid+"-"+spkr: text}
+
+ def change_stime(self, time):
+ self.stime = time
+
+ def change_etime(self, time):
+ self.etime = time
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="process the textgrid files")
+ parser.add_argument("--path", type=str, required=True, help="Data path")
+ parser.add_argument(
+ "--no-overlap",
+ type=strtobool,
+ default=False,
+ help="Whether to ignore the overlapping utterances.",
+ )
+ parser.add_argument(
+ "--max_length",
+ default=100000,
+ type=float,
+        help="maximum merged segment duration in seconds; beyond this, merging requires the overlap threshold",
+ )
+ parser.add_argument(
+ "--overlap_length",
+ default=1,
+ type=float,
+        help="minimum overlap in seconds required to keep merging once a segment exceeds max_length",
+ )
+ parser.add_argument(
+ "--mars",
+ type=strtobool,
+ default=False,
+ help="Whether to process mars data set.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def preposs_overlap(segments,max_length,overlap_length):
+ new_segments = []
+ # init a helper list to store all overlap segments
+ tmp_segments = segments[0]
+ min_stime = segments[0].stime
+ max_etime = segments[0].etime
+ overlap_length_big = 1.5
+ max_length_big = 15
+ for i in range(1, len(segments)):
+ if segments[i].stime >= max_etime:
+            # doesn't overlap with previous segments
+ new_segments.append(tmp_segments)
+ tmp_segments = segments[i]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+ else:
+ # overlap with previous segments
+ dur_time = max_etime - min_stime
+ if dur_time < max_length:
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+ tmp_segments.stime = min_stime
+ tmp_segments.etime = max_etime
+ tmp_segments.text = tmp_segments.text + "src" + segments[i].text
+ spk_name =segments[i].uttid +"-" + segments[i].spkr
+ if spk_name in tmp_segments.spk_text:
+ tmp_segments.spk_text[spk_name] += segments[i].text
+ else:
+ tmp_segments.spk_text[spk_name] = segments[i].text
+ tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
+ else:
+ overlap_time = max_etime - segments[i].stime
+ if dur_time < max_length_big:
+ overlap_length_option = overlap_length
+ else:
+ overlap_length_option = overlap_length_big
+ if overlap_time > overlap_length_option:
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+ tmp_segments.stime = min_stime
+ tmp_segments.etime = max_etime
+ tmp_segments.text = tmp_segments.text + "src" + segments[i].text
+ spk_name =segments[i].uttid +"-" + segments[i].spkr
+ if spk_name in tmp_segments.spk_text:
+ tmp_segments.spk_text[spk_name] += segments[i].text
+ else:
+ tmp_segments.spk_text[spk_name] = segments[i].text
+ tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
+ else:
+ new_segments.append(tmp_segments)
+ tmp_segments = segments[i]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+
+ return new_segments
+
+def filter_overlap(segments):
+ new_segments = []
+ # init a helper list to store all overlap segments
+ tmp_segments = [segments[0]]
+ min_stime = segments[0].stime
+ max_etime = segments[0].etime
+
+ for i in range(1, len(segments)):
+ if segments[i].stime >= max_etime:
+            # doesn't overlap with previous segments
+ if len(tmp_segments) == 1:
+ new_segments.append(tmp_segments[0])
+ # TODO: for multi-spkr asr, we can reset the stime/etime to
+            # min_stime/max_etime for generating a max length mixture speech
+ tmp_segments = [segments[i]]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+ else:
+ # overlap with previous segments
+ tmp_segments.append(segments[i])
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+
+ return new_segments
+
+
+def main(args):
+ wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
+ textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+
+ # get the path of textgrid file for each utterance
+ utt2textgrid = {}
+ for line in textgrid_flist:
+ path = Path(line.strip())
+ uttid = path.stem
+ utt2textgrid[uttid] = path
+
+ # parse the textgrid file for each utterance
+ all_segments = []
+ for line in wav_scp:
+ uttid = line.strip().split(" ")[0]
+ uttid_part=uttid
+ if args.mars == True:
+ uttid_list = uttid.split("_")
+ uttid_part= uttid_list[0]+"_"+uttid_list[1]
+ if uttid_part not in utt2textgrid:
+ print("%s doesn't have transcription" % uttid)
+ continue
+
+ segments = []
+ tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
+ for i in range(tg.__len__()):
+ for j in range(tg[i].__len__()):
+ if tg[i][j].mark:
+ segments.append(
+ Segment(
+ uttid,
+ tg[i].name,
+ tg[i][j].minTime,
+ tg[i][j].maxTime,
+ tg[i][j].mark.strip(),
+ )
+ )
+
+ segments = sorted(segments, key=lambda x: x.stime)
+
+ if args.no_overlap:
+ segments = filter_overlap(segments)
+ else:
+ segments = preposs_overlap(segments,args.max_length,args.overlap_length)
+ all_segments += segments
+
+ wav_scp.close()
+ textgrid_flist.close()
+
+ segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
+ utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
+ text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
+ utt2spk_file_fifo = codecs.open(Path(args.path) / "utt2spk_all_fifo", "w", "utf-8")
+
+ for i in range(len(all_segments)):
+ utt_name = "%s-%s-%07d-%07d" % (
+ all_segments[i].uttid,
+ all_segments[i].spkr,
+ all_segments[i].stime * 100,
+ all_segments[i].etime * 100,
+ )
+
+ segments_file.write(
+ "%s %s %.2f %.2f\n"
+ % (
+ utt_name,
+ all_segments[i].uttid,
+ all_segments[i].stime,
+ all_segments[i].etime,
+ )
+ )
+ utt2spk_file.write(
+ "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
+ )
+ utt2spk_file_fifo.write(
+ "%s %s\n" % (utt_name, all_segments[i].spkr_all)
+ )
+ text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
+
+ segments_file.close()
+ utt2spk_file.close()
+ text_file.close()
+ utt2spk_file_fifo.close()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
new file mode 100755
index 0000000..81c1965
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+class Segment(object):
+ def __init__(self, uttid, spkr, stime, etime, text):
+ self.uttid = uttid
+ self.spkr = spkr
+ self.stime = round(stime, 2)
+ self.etime = round(etime, 2)
+ self.text = text
+
+ def change_stime(self, time):
+ self.stime = time
+
+ def change_etime(self, time):
+ self.etime = time
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="process the textgrid files")
+ parser.add_argument("--path", type=str, required=True, help="Data path")
+ parser.add_argument(
+ "--no-overlap",
+ type=strtobool,
+ default=False,
+ help="Whether to ignore the overlapping utterances.",
+ )
+ parser.add_argument(
+ "--mars",
+ type=strtobool,
+ default=False,
+ help="Whether to process mars data set.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def filter_overlap(segments):
+ new_segments = []
+ # init a helper list to store all overlap segments
+ tmp_segments = [segments[0]]
+ min_stime = segments[0].stime
+ max_etime = segments[0].etime
+
+ for i in range(1, len(segments)):
+ if segments[i].stime >= max_etime:
+            # doesn't overlap with previous segments
+ if len(tmp_segments) == 1:
+ new_segments.append(tmp_segments[0])
+ # TODO: for multi-spkr asr, we can reset the stime/etime to
+            # min_stime/max_etime for generating a max length mixture speech
+ tmp_segments = [segments[i]]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+ else:
+ # overlap with previous segments
+ tmp_segments.append(segments[i])
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+
+ return new_segments
+
+
+def main(args):
+ wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
+ textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+
+ # get the path of textgrid file for each utterance
+ utt2textgrid = {}
+ for line in textgrid_flist:
+ path = Path(line.strip())
+ uttid = path.stem
+ utt2textgrid[uttid] = path
+
+ # parse the textgrid file for each utterance
+ all_segments = []
+ for line in wav_scp:
+ uttid = line.strip().split(" ")[0]
+ uttid_part=uttid
+ if args.mars == True:
+ uttid_list = uttid.split("_")
+ uttid_part= uttid_list[0]+"_"+uttid_list[1]
+ if uttid_part not in utt2textgrid:
+ print("%s doesn't have transcription" % uttid)
+ continue
+ #pdb.set_trace()
+ segments = []
+ try:
+ tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
+ except:
+ pdb.set_trace()
+ for i in range(tg.__len__()):
+ for j in range(tg[i].__len__()):
+ if tg[i][j].mark:
+ segments.append(
+ Segment(
+ uttid,
+ tg[i].name,
+ tg[i][j].minTime,
+ tg[i][j].maxTime,
+ tg[i][j].mark.strip(),
+ )
+ )
+
+ segments = sorted(segments, key=lambda x: x.stime)
+
+ if args.no_overlap:
+ segments = filter_overlap(segments)
+
+ all_segments += segments
+
+ wav_scp.close()
+ textgrid_flist.close()
+
+ segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
+ utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
+ text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
+
+ for i in range(len(all_segments)):
+ utt_name = "%s-%s-%07d-%07d" % (
+ all_segments[i].uttid,
+ all_segments[i].spkr,
+ all_segments[i].stime * 100,
+ all_segments[i].etime * 100,
+ )
+
+ segments_file.write(
+ "%s %s %.2f %.2f\n"
+ % (
+ utt_name,
+ all_segments[i].uttid,
+ all_segments[i].stime,
+ all_segments[i].etime,
+ )
+ )
+ utt2spk_file.write(
+ "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
+ )
+ text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
+
+ segments_file.close()
+ utt2spk_file.close()
+ text_file.close()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa-asr/local/compute_cpcer.py
new file mode 100644
index 0000000..f4d4a79
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/compute_cpcer.py
@@ -0,0 +1,91 @@
+import editdistance
+import sys
+import os
+from itertools import permutations
+
+
+def load_transcripts(file_path):
+ trans_list = []
+ for one_line in open(file_path, "rt"):
+        meeting_id, trans = one_line.strip().split(" ", 1)
+ trans_list.append((meeting_id.strip(), trans.strip()))
+
+ return trans_list
+
+def calc_spk_trans(trans):
+ spk_trans_ = [x.strip() for x in trans.split("$")]
+ spk_trans = []
+ for i in range(len(spk_trans_)):
+ spk_trans.append((str(i), spk_trans_[i]))
+ return spk_trans
+
+def calc_cer(ref_trans, hyp_trans):
+ ref_spk_trans = calc_spk_trans(ref_trans)
+ hyp_spk_trans = calc_spk_trans(hyp_trans)
+ ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
+ num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
+ ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
+ hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
+
+ errors, counts, permutes = [], [], []
+ min_error = 0
+ cost_dict = {}
+ for perm in permutations(range(num_spk)):
+ flag = True
+ p_err, p_count = 0, 0
+ for idx, p in enumerate(perm):
+ if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
+ flag = False
+ break
+ cost_key = "{}-{}".format(idx, p)
+ if cost_key in cost_dict:
+ _e = cost_dict[cost_key]
+ else:
+ _e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
+ cost_dict[cost_key] = _e
+ if _e > min_error > 0:
+ flag = False
+ break
+ p_err += _e
+ p_count += len(ref_spk_trans[idx][1])
+
+ if flag:
+ if p_err < min_error or min_error == 0:
+ min_error = p_err
+
+ errors.append(p_err)
+ counts.append(p_count)
+ permutes.append(perm)
+
+ sd_cer = [(err, cnt, err/cnt, permute)
+ for err, cnt, permute in zip(errors, counts, permutes)]
+ # import ipdb;ipdb.set_trace()
+ best_rst = min(sd_cer, key=lambda x: x[2])
+
+ return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
+
+
+def main():
+ ref=sys.argv[1]
+ hyp=sys.argv[2]
+ result_path=sys.argv[3]
+ ref_list = load_transcripts(ref)
+ hyp_list = load_transcripts(hyp)
+ result_file = open(result_path,'w')
+ error, count = 0, 0
+ for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
+ assert ref_id == hyp_id
+ mid = ref_id
+ dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
+ error, count = error + dist, count + length
+ result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
+
+ # print("{} {:.2f} {} {}".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
+
+ result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
+ result_file.close()
+ # print("Sum/Avg: {:.2f}".format(error / count * 100.0))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/egs/alimeeting/sa-asr/local/compute_wer.py b/egs/alimeeting/sa-asr/local/compute_wer.py
new file mode 100755
index 0000000..349a3f6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/compute_wer.py
@@ -0,0 +1,157 @@
+import os
+import numpy as np
+import sys
+
+def compute_wer(ref_file,
+ hyp_file,
+ cer_detail_file):
+ rst = {
+ 'Wrd': 0,
+ 'Corr': 0,
+ 'Ins': 0,
+ 'Del': 0,
+ 'Sub': 0,
+ 'Snt': 0,
+ 'Err': 0.0,
+ 'S.Err': 0.0,
+ 'wrong_words': 0,
+ 'wrong_sentences': 0
+ }
+
+ hyp_dict = {}
+ ref_dict = {}
+ with open(hyp_file, 'r') as hyp_reader:
+ for line in hyp_reader:
+ key = line.strip().split()[0]
+ value = line.strip().split()[1:]
+ hyp_dict[key] = value
+ with open(ref_file, 'r') as ref_reader:
+ for line in ref_reader:
+ key = line.strip().split()[0]
+ value = line.strip().split()[1:]
+ ref_dict[key] = value
+
+ cer_detail_writer = open(cer_detail_file, 'w')
+ for hyp_key in hyp_dict:
+ if hyp_key in ref_dict:
+ out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
+ rst['Wrd'] += out_item['nwords']
+ rst['Corr'] += out_item['cor']
+ rst['wrong_words'] += out_item['wrong']
+ rst['Ins'] += out_item['ins']
+ rst['Del'] += out_item['del']
+ rst['Sub'] += out_item['sub']
+ rst['Snt'] += 1
+ if out_item['wrong'] > 0:
+ rst['wrong_sentences'] += 1
+ cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
+ cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
+ cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
+
+ if rst['Wrd'] > 0:
+ rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
+ if rst['Snt'] > 0:
+ rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
+
+ cer_detail_writer.write('\n')
+ cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
+ ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
+ cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
+ cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
+
+
+def compute_wer_by_line(hyp,
+ ref):
+ hyp = list(map(lambda x: x.lower(), hyp))
+ ref = list(map(lambda x: x.lower(), ref))
+
+ len_hyp = len(hyp)
+ len_ref = len(ref)
+
+ cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+ ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+ for i in range(len_hyp + 1):
+ cost_matrix[i][0] = i
+ for j in range(len_ref + 1):
+ cost_matrix[0][j] = j
+
+ for i in range(1, len_hyp + 1):
+ for j in range(1, len_ref + 1):
+ if hyp[i - 1] == ref[j - 1]:
+ cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+ else:
+ substitution = cost_matrix[i - 1][j - 1] + 1
+ insertion = cost_matrix[i - 1][j] + 1
+ deletion = cost_matrix[i][j - 1] + 1
+
+ compare_val = [substitution, insertion, deletion]
+
+ min_val = min(compare_val)
+ operation_idx = compare_val.index(min_val) + 1
+ cost_matrix[i][j] = min_val
+ ops_matrix[i][j] = operation_idx
+
+ match_idx = []
+ i = len_hyp
+ j = len_ref
+ rst = {
+ 'nwords': len_ref,
+ 'cor': 0,
+ 'wrong': 0,
+ 'ins': 0,
+ 'del': 0,
+ 'sub': 0
+ }
+ while i >= 0 or j >= 0:
+ i_idx = max(0, i)
+ j_idx = max(0, j)
+
+ if ops_matrix[i_idx][j_idx] == 0: # correct
+ if i - 1 >= 0 and j - 1 >= 0:
+ match_idx.append((j - 1, i - 1))
+ rst['cor'] += 1
+
+ i -= 1
+ j -= 1
+
+ elif ops_matrix[i_idx][j_idx] == 2: # insert
+ i -= 1
+ rst['ins'] += 1
+
+ elif ops_matrix[i_idx][j_idx] == 3: # delete
+ j -= 1
+ rst['del'] += 1
+
+ elif ops_matrix[i_idx][j_idx] == 1: # substitute
+ i -= 1
+ j -= 1
+ rst['sub'] += 1
+
+ if i < 0 and j >= 0:
+ rst['del'] += 1
+ elif j < 0 and i >= 0:
+ rst['ins'] += 1
+
+ match_idx.reverse()
+ wrong_cnt = cost_matrix[len_hyp][len_ref]
+ rst['wrong'] = wrong_cnt
+
+ return rst
+
+def print_cer_detail(rst):
+ return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+ + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+ + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+ + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
+
+if __name__ == '__main__':
+ if len(sys.argv) != 4:
+        print("usage : python compute_wer.py test.ref test.hyp test.wer")
+ sys.exit(0)
+
+ ref_file = sys.argv[1]
+ hyp_file = sys.argv[2]
+ cer_detail_file = sys.argv[3]
+ compute_wer(ref_file, hyp_file, cer_detail_file)
diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa-asr/local/download_xvector_model.py
new file mode 100644
index 0000000..7da6559
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/download_xvector_model.py
@@ -0,0 +1,6 @@
+from modelscope.hub.snapshot_download import snapshot_download
+import sys
+
+
+cache_dir = sys.argv[1]
+model_dir = snapshot_download('damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch', cache_dir=cache_dir)
diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
new file mode 100644
index 0000000..e606162
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
@@ -0,0 +1,22 @@
+import sys
+if __name__=="__main__":
+ uttid_path=sys.argv[1]
+ src_path=sys.argv[2]
+ tgt_path=sys.argv[3]
+ uttid_file=open(uttid_path,'r')
+ uttid_line=uttid_file.readlines()
+ uttid_file.close()
+ ori_utt2spk_all_fifo_file=open(src_path+'/utt2spk_all_fifo','r')
+ ori_utt2spk_all_fifo_line=ori_utt2spk_all_fifo_file.readlines()
+ ori_utt2spk_all_fifo_file.close()
+ new_utt2spk_all_fifo_file=open(tgt_path+'/utt2spk_all_fifo','w')
+
+ uttid_list=[]
+ for line in uttid_line:
+ uttid_list.append(line.strip())
+
+ for line in ori_utt2spk_all_fifo_line:
+ if line.strip().split(' ')[0] in uttid_list:
+ new_utt2spk_all_fifo_file.write(line)
+
+ new_utt2spk_all_fifo_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
new file mode 100644
index 0000000..c37abf9
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
@@ -0,0 +1,167 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+import sys
+import os
+import soundfile
+from itertools import permutations
+from sklearn.metrics.pairwise import cosine_similarity
+from sklearn import cluster
+
+
+def custom_spectral_clustering(affinity, min_n_clusters=2, max_n_clusters=4, refine=True,
+ threshold=0.995, laplacian_type="graph_cut"):
+ if refine:
+ # Symmetrization
+ affinity = np.maximum(affinity, np.transpose(affinity))
+ # Diffusion
+ affinity = np.matmul(affinity, np.transpose(affinity))
+ # Row-wise max normalization
+ row_max = affinity.max(axis=1, keepdims=True)
+ affinity = affinity / row_max
+
+ # a) Construct S and set diagonal elements to 0
+ affinity = affinity - np.diag(np.diag(affinity))
+ # b) Compute Laplacian matrix L and perform normalization:
+ degree = np.diag(np.sum(affinity, axis=1))
+ laplacian = degree - affinity
+ if laplacian_type == "random_walk":
+ degree_norm = np.diag(1 / (np.diag(degree) + 1e-10))
+ laplacian_norm = degree_norm.dot(laplacian)
+ else:
+ degree_half = np.diag(degree) ** 0.5 + 1e-15
+ laplacian_norm = laplacian / degree_half[:, np.newaxis] / degree_half
+
+ # c) Compute eigenvalues and eigenvectors of L_norm
+ eigenvalues, eigenvectors = np.linalg.eig(laplacian_norm)
+ eigenvalues = eigenvalues.real
+ eigenvectors = eigenvectors.real
+ index_array = np.argsort(eigenvalues)
+ eigenvalues = eigenvalues[index_array]
+ eigenvectors = eigenvectors[:, index_array]
+
+ # d) Compute the number of clusters k
+ k = min_n_clusters
+ for k in range(min_n_clusters, max_n_clusters + 1):
+ if eigenvalues[k] > threshold:
+ break
+ k = max(k, min_n_clusters)
+ spectral_embeddings = eigenvectors[:, :k]
+ # print(mid, k, eigenvalues[:10])
+
+ spectral_embeddings = spectral_embeddings / np.linalg.norm(spectral_embeddings, axis=1, ord=2, keepdims=True)
+ solver = cluster.KMeans(n_clusters=k, max_iter=1000, random_state=42)
+ solver.fit(spectral_embeddings)
+ return solver.labels_
+
+
+if __name__ == "__main__":
+ path = sys.argv[1] # dump2/raw/Eval_Ali_far
+ raw_path = sys.argv[2] # data/local/Eval_Ali_far
+ threshold = float(sys.argv[3]) # 0.996
+ sv_threshold = float(sys.argv[4]) # 0.815
+ wav_scp_file = open(path+'/wav.scp', 'r')
+ wav_scp = wav_scp_file.readlines()
+ wav_scp_file.close()
+ raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp = raw_meeting_scp_file.readlines()
+ raw_meeting_scp_file.close()
+ segments_scp_file = open(raw_path + '/segments', 'r')
+ segments_scp = segments_scp_file.readlines()
+ segments_scp_file.close()
+
+ segments_map = {}
+ for line in segments_scp:
+ line_list = line.strip().split(' ')
+ meeting = line_list[1]
+ seg = (float(line_list[-2]), float(line_list[-1]))
+ if meeting not in segments_map.keys():
+ segments_map[meeting] = [seg]
+ else:
+ segments_map[meeting].append(seg)
+
+ inference_sv_pipline = pipeline(
+ task=Tasks.speaker_verification,
+ model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+ )
+
+ chunk_len = int(1.5*16000) # 1.5 seconds
+ hop_len = int(0.75*16000) # 0.75 seconds
+
+ os.system("mkdir -p " + path + "/cluster_profile_infer")
+ cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
+ meeting_map = {}
+ for line in raw_meeting_scp:
+ meeting = line.strip().split('\t')[0]
+ wav_path = line.strip().split('\t')[1]
+ wav = soundfile.read(wav_path)[0]
+ # take the first channel
+ if wav.ndim == 2:
+ wav=wav[:, 0]
+ # gen_seg_embedding
+ segments_list = segments_map[meeting]
+
+ # import ipdb;ipdb.set_trace()
+ all_seg_embedding_list = []
+ for seg in segments_list:
+ wav_seg = wav[int(seg[0] * 16000): int(seg[1] * 16000)]
+ wav_seg_len = wav_seg.shape[0]
+ i = 0
+ while i < wav_seg_len:
+ if i + chunk_len < wav_seg_len:
+ cur_wav_chunk = wav_seg[i: i+chunk_len]
+ else:
+ cur_wav_chunk=wav_seg[i: ]
+ # chunks under 0.2s are ignored
+ if cur_wav_chunk.shape[0] >= 0.2 * 16000:
+ cur_chunk_embedding = inference_sv_pipline(audio_in=cur_wav_chunk)["spk_embedding"]
+ all_seg_embedding_list.append(cur_chunk_embedding)
+ i += hop_len
+ all_seg_embedding = np.vstack(all_seg_embedding_list)
+ # all_seg_embedding (n, dim)
+
+ # compute affinity
+ affinity=cosine_similarity(all_seg_embedding)
+
+ affinity = np.maximum(affinity - sv_threshold, 0.0001) / (affinity.max() - sv_threshold)
+
+ # clustering
+ labels = custom_spectral_clustering(
+ affinity=affinity,
+ min_n_clusters=2,
+ max_n_clusters=4,
+ refine=True,
+ threshold=threshold,
+ laplacian_type="graph_cut")
+
+
+ cluster_dict={}
+ for j in range(labels.shape[0]):
+ if labels[j] not in cluster_dict.keys():
+ cluster_dict[labels[j]] = np.atleast_2d(all_seg_embedding[j])
+ else:
+ cluster_dict[labels[j]] = np.concatenate((cluster_dict[labels[j]], np.atleast_2d(all_seg_embedding[j])))
+
+ emb_list = []
+ # get cluster center
+ for k in cluster_dict.keys():
+ cluster_dict[k] = np.mean(cluster_dict[k], axis=0)
+ emb_list.append(cluster_dict[k])
+
+ spk_num = len(emb_list)
+ profile_for_infer = np.vstack(emb_list)
+ # save profile for each meeting
+ np.save(path + '/cluster_profile_infer/' + meeting + '.npy', profile_for_infer)
+ meeting_map[meeting] = (path + '/cluster_profile_infer/' + meeting + '.npy', spk_num)
+ cluster_spk_num_file.write(meeting + ' ' + str(spk_num) + '\n')
+ cluster_spk_num_file.flush()
+
+ cluster_spk_num_file.close()
+
+ profile_scp = open(path + "/cluster_profile_infer.scp", 'w')
+ for line in wav_scp:
+ uttid = line.strip().split(' ')[0]
+ meeting = uttid.split('-')[0]
+ profile_scp.write(uttid + ' ' + meeting_map[meeting][0] + '\n')
+ profile_scp.flush()
+ profile_scp.close()
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
new file mode 100644
index 0000000..18286b4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
@@ -0,0 +1,70 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+import sys
+import os
+import soundfile
+
+
+if __name__=="__main__":
+ path = sys.argv[1] # dump2/raw/Eval_Ali_far
+ raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
+ raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp = raw_meeting_scp_file.readlines()
+ raw_meeting_scp_file.close()
+ segments_scp_file = open(raw_path + '/segments', 'r')
+ segments_scp = segments_scp_file.readlines()
+ segments_scp_file.close()
+
+ oracle_emb_dir = path + '/oracle_embedding/'
+ os.system("mkdir -p " + oracle_emb_dir)
+ oracle_emb_scp_file = open(path+'/oracle_embedding.scp', 'w')
+
+ raw_wav_map = {}
+ for line in raw_meeting_scp:
+ meeting = line.strip().split('\t')[0]
+ wav_path = line.strip().split('\t')[1]
+ raw_wav_map[meeting] = wav_path
+
+ spk_map = {}
+ for line in segments_scp:
+ line_list = line.strip().split(' ')
+ meeting = line_list[1]
+ spk_id = line_list[0].split('_')[3]
+ spk = meeting + '_' + spk_id
+ time_start = float(line_list[-2])
+ time_end = float(line_list[-1])
+ if time_end - time_start > 0.5:
+ if spk not in spk_map.keys():
+ spk_map[spk] = [(int(time_start * 16000), int(time_end * 16000))]
+ else:
+ spk_map[spk].append((int(time_start * 16000), int(time_end * 16000)))
+
+ inference_sv_pipline = pipeline(
+ task=Tasks.speaker_verification,
+ model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+ )
+
+ for spk in spk_map.keys():
+ meeting = spk.split('_SPK')[0]
+ wav_path = raw_wav_map[meeting]
+ wav = soundfile.read(wav_path)[0]
+ # take the first channel
+ if wav.ndim == 2:
+ wav = wav[:, 0]
+ all_seg_embedding_list=[]
+ # import ipdb;ipdb.set_trace()
+ for seg_time in spk_map[spk]:
+ if seg_time[0] < wav.shape[0] - 0.5 * 16000:
+ if seg_time[1] > wav.shape[0]:
+ cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: ])["spk_embedding"]
+ else:
+ cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: seg_time[1]])["spk_embedding"]
+ all_seg_embedding_list.append(cur_seg_embedding)
+ all_seg_embedding = np.vstack(all_seg_embedding_list)
+ spk_embedding = np.mean(all_seg_embedding, axis=0)
+ np.save(oracle_emb_dir + spk + '.npy', spk_embedding)
+ oracle_emb_scp_file.write(spk + ' ' + oracle_emb_dir + spk + '.npy' + '\n')
+ oracle_emb_scp_file.flush()
+
+ oracle_emb_scp_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
new file mode 100644
index 0000000..f44fcd4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
@@ -0,0 +1,59 @@
import random
import numpy as np
import os
import sys


if __name__ == "__main__":
    # Usage: python gen_oracle_profile_nopadding.py <data-dir>
    # (e.g. dump2/raw/Eval_Ali_far)
    #
    # Builds one oracle speaker-profile matrix per meeting by stacking the
    # oracle embeddings of that meeting's speakers, then writes an
    # utterance-level scp mapping every utterance to its meeting's profile
    # plus the ordered speaker list used to build it.
    path = sys.argv[1]

    # Use context managers so the handles are closed even on error.
    with open(path + "/wav.scp", "r") as f:
        wav_scp = f.readlines()
    with open(path + "/spk2id", "r") as f:
        spk2id = f.readlines()
    with open(path + "/oracle_embedding.scp", "r") as f:
        embedding_scp = f.readlines()

    # speaker -> embedding; first occurrence wins (duplicates ignored).
    embedding_map = {}
    for line in embedding_scp:
        parts = line.strip().split(" ")
        if parts[0] not in embedding_map:
            embedding_map[parts[0]] = np.load(parts[1])

    # meeting -> ordered list of its speakers, as "<meeting>_<spkid>".
    meeting_map_tmp = {}
    global_spk_list = []
    for line in spk2id:
        key = line.strip().split(" ")[0]
        meeting = key.split("-")[0]
        spk_id = key.split("-")[-1].split("_")[-1]
        spk = meeting + "_" + spk_id
        global_spk_list.append(spk)
        meeting_map_tmp.setdefault(meeting, []).append(spk)

    # Stack each meeting's speaker embeddings into a (num_speakers x emb_dim)
    # profile matrix and save it as <meeting>.npy.
    meeting_map = {}
    os.makedirs(path + "/oracle_profile_nopadding", exist_ok=True)
    for meeting, spk_list in meeting_map_tmp.items():
        profile = np.vstack([embedding_map[spk] for spk in spk_list])
        npy_path = path + "/oracle_profile_nopadding/" + meeting + ".npy"
        np.save(npy_path, profile)
        meeting_map[meeting] = npy_path

    # One entry per utterance: all utterances of a meeting share its profile.
    with open(path + "/oracle_profile_nopadding.scp", "w") as profile_scp, \
            open(path + "/oracle_profile_nopadding_spk_list", "w") as profile_map_scp:
        for line in wav_scp:
            uttid = line.strip().split(" ")[0]
            meeting = uttid.split("-")[0]
            profile_scp.write(uttid + " " + meeting_map[meeting] + "\n")
            profile_map_scp.write(uttid + " " + "$".join(meeting_map_tmp[meeting]) + "\n")
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
new file mode 100644
index 0000000..b70a32a
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
@@ -0,0 +1,68 @@
import random
import numpy as np
import os
import sys


if __name__ == "__main__":
    # Usage: python gen_oracle_profile_padding.py <data-dir>
    # (e.g. dump2/raw/Train_Ali_far)
    #
    # Like gen_oracle_profile_nopadding.py, but every meeting's profile is
    # padded up to 4 speakers with speakers sampled from other meetings, so
    # all profiles have a fixed number of rows.
    path = sys.argv[1]

    # Use context managers so the handles are closed even on error.
    with open(path + "/wav.scp", "r") as f:
        wav_scp = f.readlines()
    with open(path + "/spk2id", "r") as f:
        spk2id = f.readlines()
    with open(path + "/oracle_embedding.scp", "r") as f:
        embedding_scp = f.readlines()

    # speaker -> embedding; first occurrence wins (duplicates ignored).
    embedding_map = {}
    for line in embedding_scp:
        parts = line.strip().split(" ")
        if parts[0] not in embedding_map:
            embedding_map[parts[0]] = np.load(parts[1])

    # meeting -> ordered list of its speakers, as "<meeting>_<spkid>".
    meeting_map_tmp = {}
    global_spk_list = []
    for line in spk2id:
        key = line.strip().split(" ")[0]
        meeting = key.split("-")[0]
        spk_id = key.split("-")[-1].split("_")[-1]
        spk = meeting + "_" + spk_id
        global_spk_list.append(spk)
        meeting_map_tmp.setdefault(meeting, []).append(spk)

    # Pad every meeting up to 4 speakers with speakers from other meetings.
    # NOTE(review): random.sample is unseeded, so the padded speakers differ
    # between runs; seed the RNG upstream if reproducibility matters.
    for meeting, spk_list in meeting_map_tmp.items():
        num = len(spk_list)
        if num < 4:
            candidates = global_spk_list[:]
            for spk in spk_list:
                candidates.remove(spk)
            meeting_map_tmp[meeting] = spk_list + random.sample(candidates, 4 - num)

    # Stack each meeting's (padded) speaker embeddings into a (4 x emb_dim)
    # profile matrix and save it as <meeting>.npy.
    meeting_map = {}
    os.makedirs(path + "/oracle_profile_padding", exist_ok=True)
    for meeting, spk_list in meeting_map_tmp.items():
        profile = np.vstack([embedding_map[spk] for spk in spk_list])
        npy_path = path + "/oracle_profile_padding/" + meeting + ".npy"
        np.save(npy_path, profile)
        meeting_map[meeting] = npy_path

    # One entry per utterance: all utterances of a meeting share its profile.
    with open(path + "/oracle_profile_padding.scp", "w") as profile_scp, \
            open(path + "/oracle_profile_padding_spk_list", "w") as profile_map_scp:
        for line in wav_scp:
            uttid = line.strip().split(" ")[0]
            meeting = uttid.split("-")[0]
            profile_scp.write(uttid + " " + meeting_map[meeting] + "\n")
            profile_map_scp.write(uttid + " " + "$".join(meeting_map_tmp[meeting]) + "\n")
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/proce_text.py b/egs/alimeeting/sa-asr/local/proce_text.py
new file mode 100755
index 0000000..e56cc0f
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/proce_text.py
@@ -0,0 +1,32 @@
import sys
import re

# Tokens stripped from every transcript before character-splitting.
# Order matters only in that "@@" must be removed before "@"; all other
# tokens are disjoint.  (Plain str.replace is used instead of re.sub: every
# pattern here is a literal, and the original "\$" pattern was a non-raw
# escape that raises a warning on modern Python.)
_DROP = ("</s>", "<s>", "@@", "@", "<unk>", " ", "$")


def normalize_line(line):
    """Return '<idx> <c1> <c2> ...\\n' for one raw transcript line.

    Markup tokens are removed and the text is lower-cased, then split into
    space-separated characters.  A line with no transcript yields a single
    space as its text.
    """
    outs = line.strip().split(" ", 1)
    if len(outs) == 2:
        idx, text = outs
        for tok in _DROP:
            text = text.replace(tok, "")
        text = text.lower()
    else:
        idx = outs[0]
        text = " "
    # Joining a string iterates its characters -> char-level tokens.
    return "{} {}\n".format(idx, " ".join(text))


if __name__ == "__main__":
    in_f = sys.argv[1]
    out_f = sys.argv[2]

    with open(in_f, "r", encoding="utf-8") as f:
        lines = f.readlines()

    with open(out_f, "w", encoding="utf-8") as f:
        for line in lines:
            f.write(normalize_line(line))
diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
new file mode 100755
index 0000000..d900bb1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
def get_args():
    """Parse the command line; only --path (the data directory) is accepted."""
    parser = argparse.ArgumentParser(description="process the textgrid files")
    parser.add_argument("--path", type=str, required=True, help="Data path")
    return parser.parse_args()
+
class Segment(object):
    """A simple (utterance-id, text) record used while relabelling transcripts."""

    def __init__(self, uttid, text):
        # Store both fields verbatim; no normalization is applied here.
        self.uttid, self.text = uttid, text
+
def main(args):
    # Inputs read from args.path:
    #   text             : uttid followed by a '$'-separated multi-speaker transcript
    #   spk2utt          : speaker-id -> utterances (its order defines speaker numbering)
    #   utt2spk_all_fifo : uttid -> '$'-separated speaker ids in first-in-first-out order
    # Outputs written to args.path:
    #   spk2id  : per-meeting 1-based index for every global speaker id
    #   text_id : transcript with every character replaced by its speaker index
    text = codecs.open(Path(args.path) / "text", "r", "utf-8")
    spk2utt = codecs.open(Path(args.path) / "spk2utt", "r", "utf-8")
    utt2spk = codecs.open(Path(args.path) / "utt2spk_all_fifo", "r", "utf-8")
    spk2id = codecs.open(Path(args.path) / "spk2id", "w", "utf-8")

    # Number speakers 1..N within each meeting, in spk2utt order.  The meeting
    # id is assumed to be the first three '_'-separated fields of the speaker id.
    spkid_map = {}
    meetingid_map = {}
    for line in spk2utt:
        spkid = line.strip().split(" ")[0]
        meeting_id_list = spkid.split("_")[:3]
        meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
        if meeting_id not in meetingid_map:
            meetingid_map[meeting_id] = 1
        else:
            meetingid_map[meeting_id] += 1
        spkid_map[spkid] = meetingid_map[meeting_id]
        spk2id.write("%s %s\n" % (spkid, meetingid_map[meeting_id]))

    # uttid -> list of per-meeting speaker indices, in FIFO order.
    utt2spklist = {}
    for line in utt2spk:
        uttid = line.strip().split(" ")[0]
        spkid = line.strip().split(" ")[1]
        spklist = spkid.split("$")
        tmp = []
        for index in range(len(spklist)):
            tmp.append(spkid_map[spklist[index]])
        utt2spklist[uttid] = tmp
    # parse the textgrid file for each utterance
    all_segments = []
    for line in text:
        uttid = line.strip().split(" ")[0]
        context = line.strip().split(" ")[1]
        spklist = utt2spklist[uttid]
        length_text = len(context)
        cnt = 0
        tmp_text = ""
        # Replace each transcript character with the index of the speaker who
        # said it; '$' stays as the change-of-speaker separator and advances
        # to the next speaker in the FIFO list.
        for index in range(length_text):
            if context[index] != "$":
                tmp_text += str(spklist[cnt])
            else:
                tmp_text += "$"
                cnt += 1
        tmp_seg = Segment(uttid,tmp_text)
        all_segments.append(tmp_seg)

    text.close()
    utt2spk.close()
    spk2utt.close()
    spk2id.close()

    text_id = codecs.open(Path(args.path) / "text_id", "w", "utf-8")

    for i in range(len(all_segments)):
        uttid_tmp = all_segments[i].uttid
        text_tmp = all_segments[i].text

        text_id.write("%s %s\n" % (uttid_tmp, text_tmp))

    text_id.close()

if __name__ == "__main__":
    args = get_args()
    main(args)
diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa-asr/local/process_text_id.py
new file mode 100644
index 0000000..0a9506e
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_text_id.py
@@ -0,0 +1,24 @@
import sys


def convert_id_line(line):
    """Convert one 'uttid 1$2...' speaker-id line to 0-based, space-separated form.

    Each digit (a 1-based per-meeting speaker index) becomes the 0-based
    index; every '$' separator is replaced by the previous speaker's
    (converted) id, and the final id is repeated once at the end, matching
    the original script's one-extra-label behavior.
    """
    parts = line.strip().split(" ")
    uttid = parts[0]
    old_id = parts[1]
    pre_id = "0"
    new_id_list = []
    for ch in old_id:
        if ch == "$":
            new_id_list.append(pre_id)
        else:
            pre_id = str(int(ch) - 1)
            new_id_list.append(pre_id)
    new_id_list.append(pre_id)
    return uttid + " " + " ".join(new_id_list) + "\n"


if __name__ == "__main__":
    path = sys.argv[1]

    # Context managers close the handles even if a line is malformed.
    with open(path + "/text_id", "r") as f:
        text_id_old = f.readlines()

    with open(path + "/text_id_train", "w") as f:
        for line in text_id_old:
            f.write(convert_id_line(line))
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py
new file mode 100644
index 0000000..f15d509
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py
@@ -0,0 +1,55 @@
import sys


if __name__ == "__main__":
    # Usage: python process_text_spk_merge.py <data-dir>
    #
    # Merges per-utterance multi-speaker transcripts ('text', char-aligned
    # with per-character speaker labels in 'text_id') into one line per
    # meeting: utterances are ordered by start time, each '$'-separated
    # sub-utterance is appended to its speaker's running text, and the
    # speakers' texts are finally joined with '$' (sorted by speaker index).
    path=sys.argv[1]
    text_scp_file = open(path + '/text', 'r')
    text_scp = text_scp_file.readlines()
    text_scp_file.close()
    text_id_scp_file = open(path + '/text_id', 'r')
    text_id_scp = text_id_scp_file.readlines()
    text_id_scp_file.close()
    text_spk_merge_file = open(path + '/text_spk_merge', 'w')
    assert len(text_scp) == len(text_id_scp)

    meeting_map = {} # {meeting_id: [(start_time, text, text_id), (start_time, text, text_id), ...]}
    for i in range(len(text_scp)):
        text_line = text_scp[i].strip().split(' ')
        text_id_line = text_id_scp[i].strip().split(' ')
        # 'text' and 'text_id' must be parallel files keyed by the same uttid.
        assert text_line[0] == text_id_line[0]
        if len(text_line) > 1:
            uttid = text_line[0]
            text = text_line[1]
            text_id = text_id_line[1]
            # uttid layout: <meeting>-...-<start_time>-<...>; start time is
            # the second-to-last '-'-separated field.
            meeting_id = uttid.split('-')[0]
            start_time = int(uttid.split('-')[-2])
            if meeting_id not in meeting_map:
                meeting_map[meeting_id] = [(start_time,text,text_id)]
            else:
                meeting_map[meeting_id].append((start_time,text,text_id))

    for meeting_id in sorted(meeting_map.keys()):
        # Process utterances of a meeting in chronological order.
        cur_meeting_list = sorted(meeting_map[meeting_id], key=lambda x: x[0])
        text_spk_merge_map = {} #{1: text1, 2: text2, ...}
        for cur_utt in cur_meeting_list:
            cur_text = cur_utt[1]
            cur_text_id = cur_utt[2]
            # Character-level alignment between text and speaker labels.
            assert len(cur_text)==len(cur_text_id)
            if len(cur_text) != 0:
                cur_text_split = cur_text.split('$')
                cur_text_id_split = cur_text_id.split('$')
                assert len(cur_text_split) == len(cur_text_id_split)
                for i in range(len(cur_text_split)):
                    if len(cur_text_split[i]) != 0:
                        # All chars of a sub-utterance share one speaker, so
                        # the first label identifies the speaker.
                        spk_id = int(cur_text_id_split[i][0])
                        if spk_id not in text_spk_merge_map.keys():
                            text_spk_merge_map[spk_id] = cur_text_split[i]
                        else:
                            text_spk_merge_map[spk_id] += cur_text_split[i]
        text_spk_merge_list = []
        for spk_id in sorted(text_spk_merge_map.keys()):
            text_spk_merge_list.append(text_spk_merge_map[spk_id])
        text_spk_merge_file.write(meeting_id + ' ' + '$'.join(text_spk_merge_list) + '\n')
        text_spk_merge_file.flush()

    text_spk_merge_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
new file mode 100755
index 0000000..fdf2460
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+import numpy as np
+import sys
+import math
+
+
class Segment(object):
    """One labelled interval of a TextGrid tier.

    Start and end times are rounded to two decimals (10 ms) when stored.
    """

    def __init__(self, uttid, spkr, stime, etime, text):
        self.uttid = uttid
        self.spkr = spkr
        # Round to centisecond resolution, matching the 0.01 s frame grid
        # used downstream.
        self.stime = round(stime, 2)
        self.etime = round(etime, 2)
        self.text = text

    def change_stime(self, time):
        """Overwrite the start time (caller supplies an already-rounded value)."""
        self.stime = time

    def change_etime(self, time):
        """Overwrite the end time (caller supplies an already-rounded value)."""
        self.etime = time
+
+
def get_args():
    """Parse the command line; only --path (the data directory) is accepted."""
    parser = argparse.ArgumentParser(description="process the textgrid files")
    parser.add_argument("--path", type=str, required=True, help="Data path")
    return parser.parse_args()
+
+
+
def main(args):
    # For every recording listed in textgrid.flist, label each 10 ms frame
    # with a per-speaker power-of-two weight (2, 4, 8, ...).  Because distinct
    # powers of two sum uniquely, a frame's summed label equals a speaker's
    # weight iff exactly that one speaker is active, so runs of single-speaker
    # frames can be cut into per-speaker segments (overlap regions are skipped).
    # Outputs: Kaldi-style 'segments' and 'utt2spk' files under args.path.
    textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
    segment_file = codecs.open(Path(args.path)/"segments", "w", "utf-8")
    utt2spk = codecs.open(Path(args.path)/"utt2spk", "w", "utf-8")

    # get the path of textgrid file for each utterance
    for line in textgrid_flist:
        line_array = line.strip().split(" ")
        path = Path(line_array[1])
        uttid = line_array[0]

        try:
            tg = textgrid.TextGrid.fromFile(path)
        except:
            # NOTE(review): drops into the debugger on any parse failure;
            # fine for data preparation, unsuitable for unattended runs.
            pdb.set_trace()
        num_spk = tg.__len__()
        spk2textgrid = {}
        spk2weight = {}
        weight2spk = {}
        cnt = 2
        xmax = 0
        for i in range(tg.__len__()):
            # One tier per speaker; assign the next power-of-two weight.
            spk_name = tg[i].name
            if spk_name not in spk2weight:
                spk2weight[spk_name] = cnt
                weight2spk[cnt] = spk_name
                cnt = cnt * 2
            segments = []
            for j in range(tg[i].__len__()):
                if tg[i][j].mark:
                    # Track the latest end time to size the frame matrix.
                    if xmax < tg[i][j].maxTime:
                        xmax = tg[i][j].maxTime
                    segments.append(
                        Segment(
                            uttid,
                            tg[i].name,
                            tg[i][j].minTime,
                            tg[i][j].maxTime,
                            tg[i][j].mark.strip(),
                        )
                    )
            segments = sorted(segments, key=lambda x: x.stime)
            spk2textgrid[spk_name] = segments
        # One row per speaker, one column per 10 ms frame.
        olp_label = np.zeros((num_spk, int(xmax/0.01)), dtype=np.int32)
        for spkid in spk2weight.keys():
            weight = spk2weight[spkid]
            segments = spk2textgrid[spkid]
            # Row index recovered from the weight (2 -> 0, 4 -> 1, ...).
            idx = int(math.log2(weight) )- 1
            for i in range(len(segments)):
                stime = segments[i].stime
                etime = segments[i].etime
                olp_label[idx, int(stime/0.01): int(etime/0.01)] = weight
        sum_label = olp_label.sum(axis=0)
        stime = 0
        pre_value = 0
        # Scan frames; emit a segment whenever a single-speaker run ends.
        for pos in range(sum_label.shape[0]):
            if sum_label[pos] in weight2spk:
                if pre_value in weight2spk:
                    if sum_label[pos] != pre_value:
                        spkids = weight2spk[pre_value]
                        spkid_array = spkids.split("_")
                        spkid = spkid_array[-1]
                        #spkid = uttid+spkid
                        if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
                            segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
                            utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
                        stime = pos
                        pre_value = sum_label[pos]
                else:
                    # Entering a single-speaker run from silence/overlap.
                    stime = pos
                    pre_value = sum_label[pos]
            else:
                # Current frame is silence or overlapped speech; close any
                # pending single-speaker run.
                if pre_value in weight2spk:
                    spkids = weight2spk[pre_value]
                    spkid_array = spkids.split("_")
                    spkid = spkid_array[-1]
                    #spkid = uttid+spkid
                    if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
                        segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
                        utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
                    stime = pos
                    pre_value = sum_label[pos]
    textgrid_flist.close()
    segment_file.close()


if __name__ == "__main__":
    args = get_args()
    main(args)
diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa-asr/local/text_format.pl
new file mode 100755
index 0000000..45f1f64
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/text_format.pl
@@ -0,0 +1,14 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright Chao Weng

# Filter a Kaldi 'text' file on stdin: drop lines that contain only an
# utterance id (i.e. an empty transcript) and pass all other lines through
# unchanged.  (Header adapted from the HKUST recipe; see
# docs/trans-guidelines.pdf.)

while (<STDIN>) {
  @A = split(" ", $_);
  if (@A == 1) {
    next;
  }
  print $_
}
diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa-asr/local/text_normalize.pl
new file mode 100755
index 0000000..ac301d4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/text_normalize.pl
@@ -0,0 +1,38 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright Chao Weng

# Token-level normalization for transcripts on stdin (adapted from the HKUST
# recipe; see docs/trans-guidelines.pdf):
#   - strips noise/markup tokens such as <sil>, <%>, <->, <$>, <#>, <_>, <space>
#   - removes stray punctuation (backtick, ampersand, comma)
#   - upper-cases Latin letters and maps some full-width letters to ASCII
# NOTE(review): the multi-byte literals below appear mojibake-encoded
# (full-width characters read with a mismatched encoding); verify this
# file's encoding before changing any of them, since they are runtime
# match patterns.

while (<STDIN>) {
  @A = split(" ", $_);
  print "$A[0] ";
  for ($n = 1; $n < @A; $n++) {
    $tmp = $A[$n];
    if ($tmp =~ /<sil>/) {$tmp =~ s:<sil>::g;}
    if ($tmp =~ /<%>/) {$tmp =~ s:<%>::g;}
    if ($tmp =~ /<->/) {$tmp =~ s:<->::g;}
    if ($tmp =~ /<\$>/) {$tmp =~ s:<\$>::g;}
    if ($tmp =~ /<#>/) {$tmp =~ s:<#>::g;}
    if ($tmp =~ /<_>/) {$tmp =~ s:<_>::g;}
    if ($tmp =~ /<space>/) {$tmp =~ s:<space>::g;}
    if ($tmp =~ /`/) {$tmp =~ s:`::g;}
    if ($tmp =~ /&/) {$tmp =~ s:&::g;}
    if ($tmp =~ /,/) {$tmp =~ s:,::g;}
    if ($tmp =~ /[a-zA-Z]/) {$tmp=uc($tmp);}
    if ($tmp =~ /锛�/) {$tmp =~ s:锛�:A:g;}
    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:A:g;}
    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:B:g;}
    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:C:g;}
    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:K:g;}
    if ($tmp =~ /锝�/) {$tmp =~ s:锝�:T:g;}
    if ($tmp =~ /锛�/) {$tmp =~ s:锛�::g;}
    if ($tmp =~ /涓�/) {$tmp =~ s:涓�::g;}
    if ($tmp =~ /銆�/) {$tmp =~ s:銆�::g;}
    if ($tmp =~ /銆�/) {$tmp =~ s:銆�::g;}
    if ($tmp =~ /锛�/) {$tmp =~ s:锛�::g;}
    print "$tmp ";
  }
  print "\n";
}
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh
new file mode 100755
index 0000000..3aa13d0
--- /dev/null
+++ b/egs/alimeeting/sa-asr/path.sh
@@ -0,0 +1,6 @@
# Environment setup for the AliMeeting SA-ASR recipe.
export FUNASR_DIR=$PWD/../../..

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# Put the FunASR entry points and this recipe's Kaldi-style utils/ on PATH.
export PATH=$FUNASR_DIR/funasr/bin:$PATH
export PATH=$PWD/utils/:$PATH
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
new file mode 100755
index 0000000..1fd63d6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from io import BytesIO
+from pathlib import Path
+from typing import Tuple, Optional
+
+import kaldiio
+import humanfriendly
+import numpy as np
+import resampy
+import soundfile
+from tqdm import tqdm
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.read_text import read_2column_text
+from funasr.fileio.sound_scp import SoundScpWriter
+
+
def humanfriendly_or_none(value: str):
    """Parse a human-friendly size string (e.g. "16k" -> 16000).

    The literal spellings "none", "None" and "NONE" map to None so that the
    option can be explicitly disabled on the command line.
    """
    none_spellings = ("none", "None", "NONE")
    if value in none_spellings:
        return None
    return humanfriendly.parse_size(value)
+
+
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    """

    >>> str2int_tuple('3,4,5')
    (3, 4, 5)

    """
    assert check_argument_types()
    stripped = integers.strip()
    # Null-like spellings (three casings each of "none"/"null") disable the option.
    if stripped in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(int(token) for token in stripped.split(","))
+
+
def main():
    """Convert a Kaldi 'wav.scp' (optionally with unix pipes and a 'segments'
    file) into real audio files or an extended-ark archive.

    Optionally selects reference channels and resamples, then writes the new
    '<name>.scp' plus 'utt2num_samples' under the output directory.
    """
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name " 'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, " "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    out_num_samples = Path(args.outdir) / f"utt2num_samples"

    # Build a uttid -> channel-index-tuple resolver (or None to keep all
    # channels): either a fixed tuple, or a per-utterance lookup table.
    if args.ref_channels is not None:

        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    Path(args.outdir).mkdir(parents=True, exist_ok=True)
    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
    if args.segments is not None:
        # Note: kaldiio supports only wav-pcm-int16le file.
        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
            fscp = out_wavscp.open("w")
        else:
            writer = SoundScpWriter(
                args.outdir,
                out_wavscp,
                format=args.audio_format,
            )

        with out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]

                if args.fs is not None and args.fs != rate:
                    # FIXME(kamo): To use sox?
                    wave = resampy.resample(
                        wave.astype(np.float64), rate, args.fs, axis=0
                    )
                    wave = wave.astype(np.int16)
                    rate = args.fs
                if args.audio_format.endswith("ark"):
                    if "flac" in args.audio_format:
                        suf = "flac"
                    elif "wav" in args.audio_format:
                        suf = "wav"
                    else:
                        raise RuntimeError("wav.ark or flac")

                    # NOTE(kamo): Using extended ark format style here.
                    # This format is incompatible with Kaldi
                    kaldiio.save_ark(
                        fark,
                        {uttid: (wave, rate)},
                        scp=fscp,
                        append=True,
                        write_function=f"soundfile_{suf}",
                    )

                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    else:
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
        else:
            wavdir = Path(args.outdir) / f"data_{args.name}"
            wavdir.mkdir(parents=True, exist_ok=True)

        with Path(args.scp).open("r") as fscp, out_wavscp.open(
            "w"
        ) as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)

                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                            if wave.ndim == 2 and utt2ref_channels is not None:
                                wave = wave[:, utt2ref_channels(uttid)]

                    if args.fs is not None and args.fs != rate:
                        # FIXME(kamo): To use sox?
                        wave = resampy.resample(
                            wave.astype(np.float64), rate, args.fs, axis=0
                        )
                        wave = wave.astype(np.int16)
                        rate = args.fs

                    if args.audio_format.endswith("ark"):
                        if "flac" in args.audio_format:
                            suf = "flac"
                        elif "wav" in args.audio_format:
                            suf = "wav"
                        else:
                            raise RuntimeError("wav.ark or flac")

                        # NOTE(kamo): Using extended ark format style here.
                        # This format is incompatible with Kaldi
                        kaldiio.save_ark(
                            fark,
                            {uttid: (wave, rate)},
                            scp=fout,
                            append=True,
                            write_function=f"soundfile_{suf}",
                        )
                    else:
                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                else:
                    # Regular file path: decide whether the original file can
                    # be referenced as-is or must be rewritten.
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False

                    elif args.audio_format.endswith("ark"):
                        save_asis = False

                    elif Path(wavpath).suffix == "." + args.audio_format and (
                        args.fs is None or args.fs == rate
                    ):
                        save_asis = True

                    else:
                        save_asis = False

                    if save_asis:
                        # Neither --segments nor --fs are specified and
                        # the line doesn't end with "|",
                        # i.e. not using unix-pipe,
                        # only in this case,
                        # just using the original file as is.
                        fout.write(f"{uttid} {wavpath}\n")
                    else:
                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): To use sox?
                            wave = resampy.resample(
                                wave.astype(np.float64), rate, args.fs, axis=0
                            )
                            wave = wave.astype(np.int16)
                            rate = args.fs

                        if args.audio_format.endswith("ark"):
                            if "flac" in args.audio_format:
                                suf = "flac"
                            elif "wav" in args.audio_format:
                                suf = "wav"
                            else:
                                raise RuntimeError("wav.ark or flac")

                            # NOTE(kamo): Using extended ark format style here.
                            # This format is not supported in Kaldi.
                            kaldiio.save_ark(
                                fark,
                                {uttid: (wave, rate)},
                                scp=fout,
                                append=True,
                                write_function=f"soundfile_{suf}",
                            )
                        else:
                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                            soundfile.write(owavpath, wave, rate)
                            fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")


if __name__ == "__main__":
    main()
diff --git a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
new file mode 100755
index 0000000..b0c61e5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+
+
def get_commandline_args(no_executable=True):
    """Reconstruct the current command line with shell-safe quoting.

    Arguments containing shell metacharacters are wrapped in single quotes;
    embedded single quotes are escaped as '\\'' in either case.  With
    no_executable=True (default) the interpreter path is omitted.
    """
    special = set(" ;&|<>?*~`\"'\\{}()")

    quoted = []
    for arg in sys.argv:
        escaped = arg.replace("'", "'\\''")
        if any(ch in special for ch in arg):
            quoted.append("'" + escaped + "'")
        else:
            quoted.append(escaped)

    if no_executable:
        return " ".join(quoted[1:])
    return sys.executable + " " + " ".join(quoted)
+
+
def main():
    # Print the shell-quoted reconstruction of this command line to stdout.
    print(get_commandline_args())


if __name__ == "__main__":
    main()
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023.sh b/egs/alimeeting/sa-asr/run_m2met_2023.sh
new file mode 100755
index 0000000..807e499
--- /dev/null
+++ b/egs/alimeeting/sa-asr/run_m2met_2023.sh
@@ -0,0 +1,51 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+ngpu=4
+device="0,1,2,3"
+
+#stage 1 creat both near and far
+stage=1
+stop_stage=18
+
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_Ali_far"
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+inference_config=conf/decode_asr_rnn.yaml
+
+lm_config=conf/train_lm_transformer.yaml
+use_lm=false
+use_wordlm=false
+./asr_local.sh \
+ --device ${device} \
+ --ngpu ${ngpu} \
+ --stage ${stage} \
+ --stop_stage ${stop_stage} \
+ --gpu_inference true \
+ --njob_infer 4 \
+ --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
+ --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
+ --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
+ --lang zh \
+ --audio_format wav \
+ --feats_type raw \
+ --token_type char \
+ --use_lm ${use_lm} \
+ --use_word_lm ${use_wordlm} \
+ --lm_config "${lm_config}" \
+ --asr_config "${asr_config}" \
+ --sa_asr_config "${sa_asr_config}" \
+ --inference_config "${inference_config}" \
+ --train_set "${train_set}" \
+ --valid_set "${valid_set}" \
+ --test_sets "${test_sets}" \
+ --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
new file mode 100755
index 0000000..d35e6a6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
@@ -0,0 +1,50 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+ngpu=4
+device="0,1,2,3"
+
+stage=1
+stop_stage=4
+
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_2023_Ali_far"
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+inference_config=conf/decode_asr_rnn.yaml
+
+lm_config=conf/train_lm_transformer.yaml
+use_lm=false
+use_wordlm=false
+./asr_local_infer.sh \
+ --device ${device} \
+ --ngpu ${ngpu} \
+ --stage ${stage} \
+ --stop_stage ${stop_stage} \
+ --gpu_inference true \
+ --njob_infer 4 \
+ --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
+ --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
+ --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
+ --lang zh \
+ --audio_format wav \
+ --feats_type raw \
+ --token_type char \
+ --use_lm ${use_lm} \
+ --use_word_lm ${use_wordlm} \
+ --lm_config "${lm_config}" \
+ --asr_config "${asr_config}" \
+ --sa_asr_config "${sa_asr_config}" \
+ --inference_config "${inference_config}" \
+ --train_set "${train_set}" \
+ --valid_set "${valid_set}" \
+ --test_sets "${test_sets}" \
+ --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
new file mode 100755
index 0000000..15e4563
--- /dev/null
+++ b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
@@ -0,0 +1,142 @@
#!/usr/bin/env bash
# Wrapper around pyscripts/audio/format_wav_scp.py: splits the input scp (or
# segments file) into ${nj} pieces, formats each piece in parallel, then
# concatenates the per-job outputs into ${dir}/${out_filename}.
set -euo pipefail
SECONDS=0
log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
help_message=$(cat << EOF
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
e.g.
$0 data/test/wav.scp data/test_format/

Format 'wav.scp': In short words,
changing "kaldi-datadir" to "modified-kaldi-datadir"

The 'wav.scp' format in kaldi is very flexible,
e.g. It can use unix-pipe as describing that wav file,
but it sometime looks confusing and make scripts more complex.
This tools creates actual wav files from 'wav.scp'
and also segments wav files using 'segments'.

Options
  --fs <fs>
  --segments <segments>
  --nj <nj>
  --cmd <cmd>
EOF
)

out_filename=wav.scp
cmd=utils/run.pl
nj=30
fs=none
segments=

ref_channels=
utt2ref_channels=

audio_format=wav
write_utt2num_samples=true

log "$0 $*"
. utils/parse_options.sh

if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
    log "${help_message}"
    log "Error: invalid command line arguments"
    exit 1
fi

. ./path.sh  # Setup the environment

scp=$1
if [ ! -f "${scp}" ]; then
    log "${help_message}"
    echo "$0: Error: No such file: ${scp}"
    exit 1
fi
dir=$2


# Optional 3rd/4th positional args override the default log/output dirs.
if [ $# -eq 2 ]; then
    logdir=${dir}/logs
    outdir=${dir}/data

elif [ $# -eq 3 ]; then
    logdir=$3
    outdir=${dir}/data

elif [ $# -eq 4 ]; then
    logdir=$3
    outdir=$4
fi


mkdir -p ${logdir}

# Remove a stale output scp so a failed run cannot leave a partial file behind.
rm -f "${dir}/${out_filename}"


opts=
if [ -n "${utt2ref_channels}" ]; then
    opts="--utt2ref-channels ${utt2ref_channels} "
elif [ -n "${ref_channels}" ]; then
    opts="--ref-channels ${ref_channels} "
fi


if [ -n "${segments}" ]; then
    log "[info]: using ${segments}"
    # Never launch more jobs than there are lines to process.
    nutt=$(<${segments} wc -l)
    nj=$((nj<nutt?nj:nutt))

    split_segments=""
    for n in $(seq ${nj}); do
        split_segments="${split_segments} ${logdir}/segments.${n}"
    done

    utils/split_scp.pl "${segments}" ${split_segments}

    # Fix: spell the option out as --segments (the Python script defines
    # "--segments"; "--segment" only worked via argparse prefix abbreviation).
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        pyscripts/audio/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "--segments=${logdir}/segments.JOB" \
            "${scp}" "${outdir}/format.JOB"

else
    log "[info]: without segments"
    nutt=$(<${scp} wc -l)
    nj=$((nj<nutt?nj:nutt))

    split_scps=""
    for n in $(seq ${nj}); do
        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
    done

    utils/split_scp.pl "${scp}" ${split_scps}
    # Fix: quote the output directory argument (a stray "" was previously
    # appended to it).
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        pyscripts/audio/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "${logdir}/wav.JOB.scp" "${outdir}/format.JOB"
fi

# Workaround for the NFS problem
ls ${outdir}/format.* > /dev/null

# concatenate the .scp files together.
for n in $(seq ${nj}); do
    cat "${outdir}/format.${n}/wav.scp" || exit 1;
done > "${dir}/${out_filename}" || exit 1

if "${write_utt2num_samples}"; then
    for n in $(seq ${nj}); do
        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
    done > "${dir}/utt2num_samples" || exit 1
fi

log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
new file mode 100755
index 0000000..9e08dba
--- /dev/null
+++ b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
@@ -0,0 +1,116 @@
+#!/usr/bin/env bash
+
+# 2020 @kamo-naoyuki
+# This file was copied from Kaldi and
+# I deleted parts related to wav duration
+# because we shouldn't use kaldi's command here
+# and we don't need the files actually.
+
+# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
+# 2014 Tom Ko
+# 2018 Emotech LTD (author: Pawel Swietojanski)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+# wav.scp
+# spk2utt
+# utt2spk
+# text
+#
+# It generates the files which are used for perturbing the speed of the original data.
+
+export LC_ALL=C
+set -euo pipefail
+
+if [[ $# != 3 ]]; then
+ echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
+ echo "e.g.:"
+ echo " $0 0.9 data/train_si284 data/train_si284p"
+ exit 1
+fi
+
+factor=$1
+srcdir=$2
+destdir=$3
+label="sp"
+spk_prefix="${label}${factor}-"
+utt_prefix="${label}${factor}-"
+
+# check that sox is on the path
+
+! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
+
+if [[ ! -f ${srcdir}/utt2spk ]]; then
+ echo "$0: no such file ${srcdir}/utt2spk"
+ exit 1;
+fi
+
+if [[ ${destdir} == "${srcdir}" ]]; then
+ echo "$0: this script requires <srcdir> and <destdir> to be different."
+ exit 1
+fi
+
+mkdir -p "${destdir}"
+
+<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
+<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
+<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
+if [[ ! -f ${srcdir}/utt2uniq ]]; then
+ <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
+else
+ <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
+fi
+
+
+<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
+ utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+
+utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+
+if [[ -f ${srcdir}/segments ]]; then
+
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+ utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+ awk -v factor="${factor}" \
+            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, s, e);} }' \
+ >"${destdir}"/segments
+
+ utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+ awk -v factor="${factor}" \
+ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+ else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+ else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+ > "${destdir}"/wav.scp
+ if [[ -f ${srcdir}/reco2file_and_channel ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+ <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
+ fi
+
+else # no segments->wav indexed by utterance.
+ if [[ -f ${srcdir}/wav.scp ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+ awk -v factor="${factor}" \
+ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+ else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+ else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+ > "${destdir}"/wav.scp
+ fi
+fi
+
+if [[ -f ${srcdir}/text ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+fi
+if [[ -f ${srcdir}/spk2gender ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+fi
+if [[ -f ${srcdir}/utt2lang ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+fi
+
+rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
+echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
+
+utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
diff --git a/egs/alimeeting/sa-asr/utils/apply_map.pl b/egs/alimeeting/sa-asr/utils/apply_map.pl
new file mode 100755
index 0000000..725d346
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/apply_map.pl
@@ -0,0 +1,97 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
+# Apache 2.0.
+
+# This program is a bit like ./sym2int.pl in that it applies a map
+# to things in a file, but it's a bit more general in that it doesn't
+# assume the things being mapped to are single tokens, they could
+# be sequences of tokens. See the usage message.
+
+
+$permissive = 0;
+
+for ($x = 0; $x <= 2; $x++) {
+
+ if (@ARGV > 0 && $ARGV[0] eq "-f") {
+ shift @ARGV;
+ $field_spec = shift @ARGV;
+ if ($field_spec =~ m/^\d+$/) {
+ $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
+ }
+    if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10)
+ if ($1 ne "") {
+ $field_begin = $1 - 1; # Change to zero-based indexing.
+ }
+ if ($2 ne "") {
+ $field_end = $2 - 1; # Change to zero-based indexing.
+ }
+ }
+ if (!defined $field_begin && !defined $field_end) {
+ die "Bad argument to -f option: $field_spec";
+ }
+ }
+
+ if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
+ shift @ARGV;
+ # Mapping is optional (missing key is printed to output)
+ $permissive = 1;
+ }
+}
+
+if(@ARGV != 1) {
+ print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n";
+ print STDERR <<'EOF';
+Usage: apply_map.pl [options] map <input >output
+ options: [-f <field-range> ] [--permissive]
+ This applies a map to some specified fields of some input text:
+ For each line in the map file: the first field is the thing we
+ map from, and the remaining fields are the sequence we map it to.
+  The -f (field-range) option says which fields of the input file the
+  map should apply to.
+ If the --permissive option is supplied, fields which are not present
+ in the map will be left as they were.
+ Applies the map 'map' to all input text, where each line of the map
+ is interpreted as a map from the first field to the list of the other fields
+ Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field
+ range in the input to apply the map to.
+ e.g.: echo A B | apply_map.pl a.txt
+ where a.txt is:
+ A a1 a2
+ B b
+ will produce:
+ a1 a2 b
+EOF
+ exit(1);
+}
+
+($map_file) = @ARGV;
+open(M, "<$map_file") || die "Error opening map file $map_file: $!";
+
+while (<M>) {
+ @A = split(" ", $_);
+ @A >= 1 || die "apply_map.pl: empty line.";
+ $i = shift @A;
+ $o = join(" ", @A);
+ $map{$i} = $o;
+}
+
+while(<STDIN>) {
+ @A = split(" ", $_);
+ for ($x = 0; $x < @A; $x++) {
+ if ( (!defined $field_begin || $x >= $field_begin)
+ && (!defined $field_end || $x <= $field_end)) {
+ $a = $A[$x];
+ if (!defined $map{$a}) {
+ if (!$permissive) {
+ die "apply_map.pl: undefined key $a in $map_file\n";
+ } else {
+ print STDERR "apply_map.pl: warning! missing key $a in $map_file\n";
+ }
+ } else {
+ $A[$x] = $map{$a};
+ }
+ }
+ }
+ print join(" ", @A) . "\n";
+}
diff --git a/egs/alimeeting/sa-asr/utils/combine_data.sh b/egs/alimeeting/sa-asr/utils/combine_data.sh
new file mode 100755
index 0000000..e1eba85
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/combine_data.sh
@@ -0,0 +1,146 @@
+#!/usr/bin/env bash
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
+# 2014 David Snyder
+
+# This script combines the data from multiple source directories into
+# a single destination directory.
+
+# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
+# about what these directories contain.
+
+# Begin configuration section.
+extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
+skip_fix=false # skip the fix_data_dir.sh in the end
+# End configuration section.
+
+echo "$0 $@" # Print the command line for logging
+
+if [ -f path.sh ]; then . ./path.sh; fi
+. parse_options.sh || exit 1;
+
+if [ $# -lt 2 ]; then
+ echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
+ echo "Note, files that don't appear in all source dirs will not be combined,"
+ echo "with the exception of utt2uniq and segments, which are created where necessary."
+ exit 1
+fi
+
+dest=$1;
+shift;
+
+first_src=$1;
+
+rm -r $dest 2>/dev/null || true
+mkdir -p $dest;
+
+export LC_ALL=C
+
+for dir in $*; do
+ if [ ! -f $dir/utt2spk ]; then
+ echo "$0: no such file $dir/utt2spk"
+ exit 1;
+ fi
+done
+
+# Check that frame_shift are compatible, where present together with features.
+dir_with_frame_shift=
+for dir in $*; do
+ if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
+ if [[ $dir_with_frame_shift ]] &&
+ ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
+ echo "$0:error: different frame_shift in directories $dir and " \
+ "$dir_with_frame_shift. Cannot combine features."
+ exit 1;
+ fi
+ dir_with_frame_shift=$dir
+ fi
+done
+
+# W.r.t. utt2uniq file the script has different behavior compared to other files
+# it is not compulsary for it to exist in src directories, but if it exists in
+# even one it should exist in all. We will create the files where necessary
+has_utt2uniq=false
+for in_dir in $*; do
+ if [ -f $in_dir/utt2uniq ]; then
+ has_utt2uniq=true
+ break
+ fi
+done
+
+if $has_utt2uniq; then
+ # we are going to create an utt2uniq file in the destdir
+ for in_dir in $*; do
+ if [ ! -f $in_dir/utt2uniq ]; then
+ # we assume that utt2uniq is a one to one mapping
+ cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
+ else
+ cat $in_dir/utt2uniq
+ fi
+ done | sort -k1 > $dest/utt2uniq
+ echo "$0: combined utt2uniq"
+else
+ echo "$0 [info]: not combining utt2uniq as it does not exist"
+fi
+# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
+extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
+
+# segments are treated similarly to utt2uniq. If it exists in some, but not all
+# src directories, then we generate segments where necessary.
+has_segments=false
+for in_dir in $*; do
+ if [ -f $in_dir/segments ]; then
+ has_segments=true
+ break
+ fi
+done
+
+if $has_segments; then
+ for in_dir in $*; do
+ if [ ! -f $in_dir/segments ]; then
+ echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
+ utils/data/get_segments_for_data.sh $in_dir
+ else
+ cat $in_dir/segments
+ fi
+ done | sort -k1 > $dest/segments
+ echo "$0: combined segments"
+else
+ echo "$0 [info]: not combining segments as it does not exist"
+fi
+
+for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
+ exists_somewhere=false
+ absent_somewhere=false
+ for d in $*; do
+ if [ -f $d/$file ]; then
+ exists_somewhere=true
+ else
+ absent_somewhere=true
+ fi
+ done
+
+ if ! $absent_somewhere; then
+ set -o pipefail
+ ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
+ set +o pipefail
+ echo "$0: combined $file"
+ else
+ if ! $exists_somewhere; then
+ echo "$0 [info]: not combining $file as it does not exist"
+ else
+ echo "$0 [info]: **not combining $file as it does not exist everywhere**"
+ fi
+ fi
+done
+
+utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
+
+if [[ $dir_with_frame_shift ]]; then
+ cp $dir_with_frame_shift/frame_shift $dest
+fi
+
+if ! $skip_fix ; then
+ utils/fix_data_dir.sh $dest || exit 1;
+fi
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh
new file mode 100755
index 0000000..9fd420c
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh
@@ -0,0 +1,145 @@
+#!/usr/bin/env bash
+
+# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+# feats.scp
+# wav.scp
+# vad.scp
+# spk2utt
+# utt2spk
+# text
+#
+# It copies to another directory, possibly adding a specified prefix or a suffix
+# to the utterance and/or speaker names. Note, the recording-ids stay the same.
+#
+
+
+# begin configuration section
+spk_prefix=
+utt_prefix=
+spk_suffix=
+utt_suffix=
+validate_opts= # should rarely be needed.
+# end configuration section
+
+. utils/parse_options.sh
+
+if [ $# != 2 ]; then
+ echo "Usage: "
+ echo " $0 [options] <srcdir> <destdir>"
+ echo "e.g.:"
+ echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1"
+ echo "Options"
+ echo " --spk-prefix=<prefix> # Prefix for speaker ids, default empty"
+ echo " --utt-prefix=<prefix> # Prefix for utterance ids, default empty"
+ echo " --spk-suffix=<suffix> # Suffix for speaker ids, default empty"
+ echo " --utt-suffix=<suffix> # Suffix for utterance ids, default empty"
+ exit 1;
+fi
+
+
+export LC_ALL=C
+
+srcdir=$1
+destdir=$2
+
+if [ ! -f $srcdir/utt2spk ]; then
+ echo "copy_data_dir.sh: no such file $srcdir/utt2spk"
+ exit 1;
+fi
+
+if [ "$destdir" == "$srcdir" ]; then
+ echo "$0: this script requires <srcdir> and <destdir> to be different."
+ exit 1
+fi
+
+set -e;
+
+mkdir -p $destdir
+
+cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map
+cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map
+
+if [ ! -f $srcdir/utt2uniq ]; then
+ if [[ ! -z $utt_prefix || ! -z $utt_suffix ]]; then
+ cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq
+ fi
+else
+ cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
+fi
+
+cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \
+ utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
+
+utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
+
+if [ -f $srcdir/feats.scp ]; then
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
+fi
+
+if [ -f $srcdir/vad.scp ]; then
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
+fi
+
+if [ -f $srcdir/segments ]; then
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
+ cp $srcdir/wav.scp $destdir
+else # no segments->wav indexed by utt.
+ if [ -f $srcdir/wav.scp ]; then
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
+ fi
+fi
+
+if [ -f $srcdir/reco2file_and_channel ]; then
+ cp $srcdir/reco2file_and_channel $destdir/
+fi
+
+if [ -f $srcdir/text ]; then
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
+fi
+if [ -f $srcdir/utt2dur ]; then
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
+fi
+if [ -f $srcdir/utt2num_frames ]; then
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
+fi
+if [ -f $srcdir/reco2dur ]; then
+ if [ -f $srcdir/segments ]; then
+ cp $srcdir/reco2dur $destdir/reco2dur
+ else
+ utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
+ fi
+fi
+if [ -f $srcdir/spk2gender ]; then
+ utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
+fi
+if [ -f $srcdir/cmvn.scp ]; then
+ utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
+fi
+for f in frame_shift stm glm ctm; do
+ if [ -f $srcdir/$f ]; then
+ cp $srcdir/$f $destdir
+ fi
+done
+
+rm $destdir/spk_map $destdir/utt_map
+
+echo "$0: copied data from $srcdir to $destdir"
+
+for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do
+ if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then
+ echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to"
+ echo " ... $destdir/.backup/$f"
+ mkdir -p $destdir/.backup
+ mv $destdir/$f $destdir/.backup/
+ fi
+done
+
+
+[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
+[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
+
+utils/validate_data_dir.sh $validate_opts $destdir
diff --git a/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh
new file mode 100755
index 0000000..24f51e7
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh
@@ -0,0 +1,143 @@
+#!/usr/bin/env bash
+
+# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
+# 2018 Andrea Carmantini
+# Apache 2.0
+
+# This script operates on a data directory, such as in data/train/, and adds the
+# reco2dur file if it does not already exist. The file 'reco2dur' maps from
+# recording to the duration of the recording in seconds. This script works it
+# out from the 'wav.scp' file, or, if utterance-ids are the same as recording-ids, from the
+# utt2dur file (it first tries interrogating the headers, and if this fails, it reads the wave
+# files in entirely.)
+# We could use durations from segments file, but that's not the duration of the recordings
+# but the sum of utterance lengths (silence in between could be excluded from segments)
+# For sum of utterance lengths:
+# awk 'FNR==NR{uttdur[$1]=$2;next}
+# { for(i=2;i<=NF;i++){dur+=uttdur[$i];}
+# print $1 FS dur; dur=0 }' $data/utt2dur $data/reco2utt
+
+
+frame_shift=0.01
+cmd=run.pl
+nj=4
+
+. utils/parse_options.sh
+. ./path.sh
+
+if [ $# != 1 ]; then
+ echo "Usage: $0 [options] <datadir>"
+ echo "e.g.:"
+ echo " $0 data/train"
+ echo " Options:"
+ echo " --frame-shift # frame shift in seconds. Only relevant when we are"
+ echo " # getting duration from feats.scp (default: 0.01). "
+ exit 1
+fi
+
+export LC_ALL=C
+
+data=$1
+
+
+if [ -s $data/reco2dur ] && \
+ [ $(wc -l < $data/wav.scp) -eq $(wc -l < $data/reco2dur) ]; then
+ echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it."
+ exit 0;
+fi
+
+if [ -s $data/utt2dur ] && \
+ [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ] && \
+ [ ! -s $data/segments ]; then
+
+ echo "$0: $data/wav.scp indexed by utt-id; copying utt2dur to reco2dur"
+ cp $data/utt2dur $data/reco2dur && exit 0;
+
+elif [ -f $data/wav.scp ]; then
+ echo "$0: obtaining durations from recordings"
+
+ # if the wav.scp contains only lines of the form
+ # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
+ if cat $data/wav.scp | perl -e '
+ while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
+ @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
+ $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
+ $reco = $A[0]; $sphere_file = $A[4];
+
+ if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
+ $sample_rate = -1; $sample_count = -1;
+ for ($n = 0; $n <= 30; $n++) {
+ $line = <F>;
+ if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
+ if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
+        if ($line =~ m/end_head/) { last; }
+ }
+ close(F);
+ if ($sample_rate == -1 || $sample_count == -1) {
+ die "could not parse sphere header from $sphere_file";
+ }
+ $duration = $sample_count * 1.0 / $sample_rate;
+ print "$reco $duration\n";
+ } ' > $data/reco2dur; then
+ echo "$0: successfully obtained recording lengths from sphere-file headers"
+ else
+ echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration"
+ if ! command -v wav-to-duration >/dev/null; then
+ echo "$0: wav-to-duration is not on your path"
+ exit 1;
+ fi
+
+ read_entire_file=false
+ if grep -q 'sox.*speed' $data/wav.scp; then
+ read_entire_file=true
+ echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
+ echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
+ echo "... perturb_data_dir_speed_3way.sh."
+ fi
+
+ num_recos=$(wc -l <$data/wav.scp)
+ if [ $nj -gt $num_recos ]; then
+ nj=$num_recos
+ fi
+
+ temp_data_dir=$data/wav${nj}split
+ wavscps=$(for n in `seq $nj`; do echo $temp_data_dir/$n/wav.scp; done)
+ subdirs=$(for n in `seq $nj`; do echo $temp_data_dir/$n; done)
+
+ if ! mkdir -p $subdirs >&/dev/null; then
+ for n in `seq $nj`; do
+ mkdir -p $temp_data_dir/$n
+ done
+ fi
+
+ utils/split_scp.pl $data/wav.scp $wavscps
+
+
+ $cmd JOB=1:$nj $data/log/get_reco_durations.JOB.log \
+ wav-to-duration --read-entire-file=$read_entire_file \
+ scp:$temp_data_dir/JOB/wav.scp ark,t:$temp_data_dir/JOB/reco2dur || \
+ { echo "$0: there was a problem getting the durations"; exit 1; } # This could
+
+ for n in `seq $nj`; do
+ cat $temp_data_dir/$n/reco2dur
+ done > $data/reco2dur
+ fi
+  [ -n "$temp_data_dir" ] && rm -r $temp_data_dir
+else
+ echo "$0: Expected $data/wav.scp to exist"
+ exit 1
+fi
+
+len1=$(wc -l < $data/wav.scp)
+len2=$(wc -l < $data/reco2dur)
+if [ "$len1" != "$len2" ]; then
+ echo "$0: warning: length of reco2dur does not equal that of wav.scp, $len2 != $len1"
+  if [ $len1 -gt $((len2*2)) ]; then
+ echo "$0: less than half of recordings got a duration: failing."
+ exit 1
+ fi
+fi
+
+echo "$0: computed $data/reco2dur"
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
new file mode 100755
index 0000000..6b161b3
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
+# This script operates on a data directory, such as in data/train/,
+# and writes new segments to stdout. The file 'segments' maps from
+# utterance to time offsets into a recording, with the format:
+# <utterance-id> <recording-id> <segment-begin> <segment-end>
+# This script assumes utterance and recording ids are the same (i.e., that
+# wav.scp is indexed by utterance), and uses durations from 'utt2dur',
+# created if necessary by get_utt2dur.sh.
+
+. ./path.sh
+
+if [ $# != 1 ]; then
+ echo "Usage: $0 [options] <datadir>"
+ echo "e.g.:"
+ echo " $0 data/train > data/train/segments"
+ exit 1
+fi
+
+data=$1
+
+if [ ! -s $data/utt2dur ]; then
+ utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
+fi
+
+# <utt-id> <utt-id> 0 <utt-dur>
+awk '{ print $1, $1, 0, $2 }' $data/utt2dur
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
new file mode 100755
index 0000000..5ee7ea3
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
@@ -0,0 +1,135 @@
+#!/usr/bin/env bash
+
+# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+# This script operates on a data directory, such as in data/train/, and adds the
+# utt2dur file if it does not already exist. The file 'utt2dur' maps from
+# utterance to the duration of the utterance in seconds. This script works it
+# out from the 'segments' file, or, if not present, from the wav.scp file (it
+# first tries interrogating the headers, and if this fails, it reads the wave
+# files in entirely.)
+
+frame_shift=0.01
+cmd=run.pl
+nj=4
+read_entire_file=false
+
+. utils/parse_options.sh
+. ./path.sh
+
+if [ $# != 1 ]; then
+ echo "Usage: $0 [options] <datadir>"
+ echo "e.g.:"
+ echo " $0 data/train"
+ echo " Options:"
+ echo " --frame-shift # frame shift in seconds. Only relevant when we are"
+ echo " # getting duration from feats.scp, and only if the "
+ echo " # file frame_shift does not exist (default: 0.01). "
+ exit 1
+fi
+
+export LC_ALL=C
+
+data=$1
+
+if [ -s $data/utt2dur ] && \
+ [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ]; then
+ echo "$0: $data/utt2dur already exists with the expected length. We won't recompute it."
+ exit 0;
+fi
+
+if [ -s $data/segments ]; then
+ echo "$0: working out $data/utt2dur from $data/segments"
+ awk '{len=$4-$3; print $1, len;}' < $data/segments > $data/utt2dur
+elif [[ -s $data/frame_shift && -f $data/utt2num_frames ]]; then
+ echo "$0: computing $data/utt2dur from $data/{frame_shift,utt2num_frames}."
+ frame_shift=$(cat $data/frame_shift) || exit 1
+ # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
+ awk -v fs=$frame_shift '{ $2=($2+1.5)*fs; print }' <$data/utt2num_frames >$data/utt2dur
+elif [ -f $data/wav.scp ]; then
+ echo "$0: segments file does not exist so getting durations from wave files"
+
+ # if the wav.scp contains only lines of the form
+ # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
+ if perl <$data/wav.scp -e '
+ while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
+ @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
+ $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
+ $utt = $A[0]; $sphere_file = $A[4];
+
+ if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
+ $sample_rate = -1; $sample_count = -1;
+ for ($n = 0; $n <= 30; $n++) {
+ $line = <F>;
+ if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
+ if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
+        if ($line =~ m/end_head/) { last; }
+ }
+ close(F);
+ if ($sample_rate == -1 || $sample_count == -1) {
+ die "could not parse sphere header from $sphere_file";
+ }
+ $duration = $sample_count * 1.0 / $sample_rate;
+ print "$utt $duration\n";
+ } ' > $data/utt2dur; then
+ echo "$0: successfully obtained utterance lengths from sphere-file headers"
+ else
+ echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration"
+ if ! command -v wav-to-duration >/dev/null; then
+ echo "$0: wav-to-duration is not on your path"
+ exit 1;
+ fi
+
+ if grep -q 'sox.*speed' $data/wav.scp; then
+ read_entire_file=true
+ echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
+ echo "... It is much faster if you call get_utt2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
+ echo "... perturb_data_dir_speed_3way.sh."
+ fi
+
+
+ num_utts=$(wc -l <$data/utt2spk)
+ if [ $nj -gt $num_utts ]; then
+ nj=$num_utts
+ fi
+
+ utils/data/split_data.sh --per-utt $data $nj
+ sdata=$data/split${nj}utt
+
+ $cmd JOB=1:$nj $data/log/get_durations.JOB.log \
+ wav-to-duration --read-entire-file=$read_entire_file \
+ scp:$sdata/JOB/wav.scp ark,t:$sdata/JOB/utt2dur || \
+ { echo "$0: there was a problem getting the durations"; exit 1; }
+
+ for n in `seq $nj`; do
+ cat $sdata/$n/utt2dur
+ done > $data/utt2dur
+ fi
+elif [ -f $data/feats.scp ]; then
+ echo "$0: wave file does not exist so getting durations from feats files"
+ if [[ -s $data/frame_shift ]]; then
+ frame_shift=$(cat $data/frame_shift) || exit 1
+ echo "$0: using frame_shift=$frame_shift from file $data/frame_shift"
+ fi
+ # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
+ feat-to-len scp:$data/feats.scp ark,t:- |
+ awk -v frame_shift=$frame_shift '{print $1, ($2+1.5)*frame_shift}' >$data/utt2dur
+else
+ echo "$0: Expected $data/wav.scp, $data/segments or $data/feats.scp to exist"
+ exit 1
+fi
+
+len1=$(wc -l < $data/utt2spk)
+len2=$(wc -l < $data/utt2dur)
+if [ "$len1" != "$len2" ]; then
+ echo "$0: warning: length of utt2dur does not equal that of utt2spk, $len2 != $len1"
+  if [ $len1 -gt $((len2*2)) ]; then
+ echo "$0: less than half of utterances got a duration: failing."
+ exit 1
+ fi
+fi
+
+echo "$0: computed $data/utt2dur"
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/data/split_data.sh b/egs/alimeeting/sa-asr/utils/data/split_data.sh
new file mode 100755
index 0000000..8aa71a1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/data/split_data.sh
@@ -0,0 +1,160 @@
+#!/usr/bin/env bash
+# Copyright 2010-2013 Microsoft Corporation
+# Johns Hopkins University (Author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+split_per_spk=true
+if [ "$1" == "--per-utt" ]; then
+ split_per_spk=false
+ shift
+fi
+
+if [ $# != 2 ]; then
+ echo "Usage: $0 [--per-utt] <data-dir> <num-to-split>"
+ echo "E.g.: $0 data/train 50"
+ echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the "
+ echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}."
+ echo ""
+ echo "This script will not split the data-dir if it detects that the output is newer than the input."
+ echo "By default it splits per speaker (so each speaker is in only one split dir),"
+ echo "but with the --per-utt option it will ignore the speaker information while splitting."
+ exit 1
+fi
+
+data=$1
+numsplit=$2
+
+if ! [ "$numsplit" -gt 0 ]; then
+ echo "Invalid num-split argument $numsplit";
+ exit 1;
+fi
+
+if $split_per_spk; then
+ warning_opt=
+else
+ # suppress warnings from filter_scps.pl about 'some input lines were output
+ # to multiple files'.
+ warning_opt="--no-warn"
+fi
+
+n=0;
+feats=""
+wavs=""
+utt2spks=""
+texts=""
+
+nu=`cat $data/utt2spk | wc -l`
+nf=`cat $data/feats.scp 2>/dev/null | wc -l`
+nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
+if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
+ echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
+ echo "** use utils/fix_data_dir.sh $data to fix this."
+fi
+if [ -f $data/text ] && [ $nu -ne $nt ]; then
+ echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
+ echo "** use utils/fix_data_dir.sh to fix this."
+fi
+
+
+if $split_per_spk; then
+ utt2spk_opt="--utt2spk=$data/utt2spk"
+ utt=""
+else
+ utt2spk_opt=
+ utt="utt"
+fi
+
+s1=$data/split${numsplit}${utt}/1
+if [ ! -d $s1 ]; then
+ need_to_split=true
+else
+ need_to_split=false
+ for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \
+ vad.scp segments reco2file_and_channel utt2lang; do
+ if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then
+ need_to_split=true
+ fi
+ done
+fi
+
+if ! $need_to_split; then
+ exit 0;
+fi
+
+utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done)
+
+directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done)
+
+# if this mkdir fails due to argument-list being too long, iterate.
+if ! mkdir -p $directories >&/dev/null; then
+ for n in `seq $numsplit`; do
+ mkdir -p $data/split${numsplit}${utt}/$n
+ done
+fi
+
+# If lockfile is not installed, just don't lock it. It's not a big deal.
+which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock
+trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM
+
+utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
+
+for n in `seq $numsplit`; do
+ dsn=$data/split${numsplit}${utt}/$n
+ utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
+done
+
+maybe_wav_scp=
+if [ ! -f $data/segments ]; then
+ maybe_wav_scp=wav.scp # If there is no segments file, then wav file is
+ # indexed per utt.
+fi
+
+# split some things that are indexed by utterance.
+for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do
+ if [ -f $data/$f ]; then
+ utils/filter_scps.pl JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
+ fi
+done
+
+# split some things that are indexed by speaker
+for f in spk2gender spk2warp cmvn.scp; do
+ if [ -f $data/$f ]; then
+ utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/spk2utt $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
+ fi
+done
+
+if [ -f $data/segments ]; then
+ utils/filter_scps.pl JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/utt2spk $data/segments $data/split${numsplit}${utt}/JOB/segments || exit 1
+ for n in `seq $numsplit`; do
+ dsn=$data/split${numsplit}${utt}/$n
+ awk '{print $2;}' $dsn/segments | sort | uniq > $dsn/tmp.reco # recording-ids.
+ done
+ if [ -f $data/reco2file_and_channel ]; then
+ utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2file_and_channel \
+ $data/split${numsplit}${utt}/JOB/reco2file_and_channel || exit 1
+ fi
+ if [ -f $data/wav.scp ]; then
+ utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \
+ $data/split${numsplit}${utt}/JOB/wav.scp || exit 1
+ fi
+ for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done
+fi
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/utils/filter_scp.pl b/egs/alimeeting/sa-asr/utils/filter_scp.pl
new file mode 100755
index 0000000..b76d37f
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/filter_scp.pl
@@ -0,0 +1,87 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation
+# Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids or any file whose first field
+# of each line is an utterance-id, and filters an scp
+# file (or any file whose "n-th" field is an utterance id), printing
+# out only those lines whose "n-th" field is in id_list. The index of
+# the "n-th" field is 1, by default, but can be changed by using
+# the -f <n> switch
+
+$exclude = 0;  # with --exclude, print lines whose id is NOT in id_list
+$field = 1;    # 1-based index of the field to filter on (-f <n>)
+$shifted = 0;
+
+do {
+  $shifted=0;
+  if ($ARGV[0] eq "--exclude") {
+    $exclude = 1;
+    shift @ARGV;
+    $shifted=1;
+  }
+  if ($ARGV[0] eq "-f") {
+    $field = $ARGV[1];
+    shift @ARGV; shift @ARGV;
+    $shifted=1
+  }
+} while ($shifted);
+
+if(@ARGV < 1 || @ARGV > 2) {
+  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
+  "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
+  "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
+  "only the lines that were *not* in id_list.\n" .
+  "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
+  "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
+  "-f option, add 1 to the argument.\n" .
+  "See also: utils/filter_scps.pl, which can apply this filtering in parallel.\n";
+}
+
+
+$idlist = shift @ARGV;  # file whose first field on each line is an id to keep (or exclude)
+open(F, "<$idlist") || die "Could not open id-list file $idlist";
+while(<F>) {
+  @A = split;
+  @A>=1 || die "Invalid id-list file line $_";
+  $seen{$A[0]} = 1;  # only the first field of each id-list line matters
+}
+
+if ($field == 1) { # Treat this as special case, since it is common.
+  while(<>) {
+    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
+    # $1 is what we filter on.
+    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
+      print $_;
+    }
+  }
+} else {
+  while(<>) {
+    @A = split;
+    @A > 0 || die "Invalid scp file line $_";
+    @A >= $field || die "Invalid scp file line $_";
+    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
+      print $_;
+    }
+  }
+}
+
+# tests:
+# the following should print "foo 1"
+# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
+# the following should print "bar 2".
+# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
diff --git a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh
new file mode 100755
index 0000000..ed4710d
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh
@@ -0,0 +1,215 @@
+#!/usr/bin/env bash
+
+# This script makes sure that only the segments present in
+# all of "feats.scp", "wav.scp" [if present], segments [if present]
+# text, and utt2spk are present in any of them.
+# It puts the original contents of data-dir into
+# data-dir/.backup
+
+cmd="$@"  # remember the original arguments (forwarded to image/fix_data_dir.sh below)
+
+utt_extra_files=  # extra per-utterance files to filter, settable via --utt-extra-files
+spk_extra_files=  # extra per-speaker files to filter, settable via --spk-extra-files
+
+. utils/parse_options.sh
+
+if [ $# != 1 ]; then
+  echo "Usage: utils/data/fix_data_dir.sh <data-dir>"
+  echo "e.g.: utils/data/fix_data_dir.sh data/train"
+  echo "This script helps ensure that the various files in a data directory"
+  echo "are correctly sorted and filtered, for example removing utterances"
+  echo "that have no features (if feats.scp is present)"
+  exit 1
+fi
+
+data=$1
+
+if [ -f $data/images.scp ]; then
+  image/fix_data_dir.sh $cmd  # image data dirs are handled by their own script
+  exit $?
+fi
+
+mkdir -p $data/.backup
+
+[ ! -d $data ] && echo "$0: no such directory $data" && exit 1;
+
+[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1;
+
+set -e -o pipefail -u
+
+tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
+trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM  # clean up temp dir on exit or interruption
+
+export LC_ALL=C  # byte-wise sorting, required for Kaldi data files
+
+function check_sorted {
+  file=$1
+  sort -k1,1 -u <$file >$file.tmp  # sorted and de-duplicated on the first field
+  if ! cmp -s $file $file.tmp; then
+    echo "$0: file $1 is not in sorted order or not unique, sorting it"
+    mv $file.tmp $file  # replace with the sorted/unique version
+  else
+    rm $file.tmp  # already sorted and unique; nothing to do
+  fi
+}
+
+for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \
+    reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do
+  if [ -f $data/$x ]; then
+    cp $data/$x $data/.backup/$x  # back up the original before touching it
+    check_sorted $data/$x
+  fi
+done
+
+
+function filter_file {  # filter_file <filter> <file_to_filter>: keep only lines of $2 whose first field occurs in $1
+  filter=$1
+  file_to_filter=$2
+  cp $file_to_filter ${file_to_filter}.tmp
+  utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter
+  if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then
+    length1=$(cat ${file_to_filter}.tmp | wc -l)
+    length2=$(cat ${file_to_filter} | wc -l)
+    if [ $length1 -ne $length2 ]; then
+      echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter."
+    fi
+  fi
+  rm $file_to_filter.tmp
+}
+
+function filter_recordings {
+  # We call this once before the stage when we filter on utterance-id, and once
+  # after.
+
+  if [ -f $data/segments ]; then
+    # We have a segments file -> we need to filter this and the file wav.scp, and
+    # reco2file_and_channel, if it exists, to make sure they have the same list of
+    # recording-ids.
+
+    if [ ! -f $data/wav.scp ]; then
+      echo "$0: $data/segments exists but not $data/wav.scp"
+      exit 1;
+    fi
+    awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings
+    n1=$(cat $tmpdir/recordings | wc -l)
+    [ ! -s $tmpdir/recordings ] && \
+      echo "Empty list of recordings (bad file $data/segments)?" && exit 1;
+    utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp
+    mv $tmpdir/recordings.tmp $tmpdir/recordings
+
+
+    cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments  # put recording-id first so filter_file can filter on it
+    filter_file $tmpdir/recordings $data/segments
+    cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments  # swap fields back to <utt> <reco> <start> <end>
+    rm $data/segments.tmp
+
+    filter_file $tmpdir/recordings $data/wav.scp
+    [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel
+    [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur
+    true
+  fi
+}
+
+function filter_speakers {
+  # throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
+  utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  for s in cmvn.scp spk2gender; do
+    f=$data/$s
+    if [ -f $f ]; then
+      filter_file $f $tmpdir/speakers  # args reversed on purpose: restrict the speaker list to speakers present in $f
+    fi
+  done
+
+  filter_file $tmpdir/speakers $data/spk2utt
+  utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
+
+  for s in cmvn.scp spk2gender $spk_extra_files; do
+    f=$data/$s
+    if [ -f $f ]; then
+      filter_file $tmpdir/speakers $f
+    fi
+  done
+}
+
+function filter_utts {
+  cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
+
+  ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \
+    echo "utt2spk is not in sorted order (fix this yourself)" && exit 1;
+
+  ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \
+    echo "utt2spk is not in sorted order when sorted first on speaker-id " && \
+    echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
+
+  ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \
+    echo "spk2utt is not in sorted order (fix this yourself)" && exit 1;
+
+  if [ -f $data/utt2uniq ]; then
+    ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \
+      echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1;
+  fi
+
+  maybe_wav=
+  maybe_reco2dur=
+  [ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist.
+  [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts
+
+  maybe_utt2dur=
+  if [ -f $data/utt2dur ]; then
+    cat $data/utt2dur | \
+      awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2dur.ok || exit 1  # keep only well-formed, positive-duration entries
+    maybe_utt2dur=utt2dur.ok
+  fi
+
+  maybe_utt2num_frames=
+  if [ -f $data/utt2num_frames ]; then
+    cat $data/utt2num_frames | \
+      awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2num_frames.ok || exit 1  # keep only well-formed, positive-count entries
+    maybe_utt2num_frames=utt2num_frames.ok
+  fi
+
+  for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do
+    if [ -f $data/$x ]; then
+      utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp  # keep only utts that appear in every per-utt file
+      mv $tmpdir/utts.tmp $tmpdir/utts
+    fi
+  done
+  rm $data/utt2dur.ok 2>/dev/null || true
+  rm $data/utt2num_frames.ok 2>/dev/null || true
+
+  [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \
+    rm $tmpdir/utts && exit 1;
+
+
+  if [ -f $data/utt2spk ]; then
+    new_nutts=$(cat $tmpdir/utts | wc -l)
+    old_nutts=$(cat $data/utt2spk | wc -l)
+    if [ $new_nutts -ne $old_nutts ]; then
+      echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts"
+    else
+      echo "fix_data_dir.sh: kept all $old_nutts utterances."
+    fi
+  fi
+
+  for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do
+    if [ -f $data/$x ]; then
+      cp $data/$x $data/.backup/$x
+      if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then
+        utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x  # rewrite only if filtering changes the file
+      fi
+    fi
+  done
+
+}
+
+filter_recordings
+filter_speakers
+filter_utts
+filter_speakers  # re-run: filter_utts may have removed all utts of some speaker
+filter_recordings  # re-run: utt removal may have emptied some recordings
+
+utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt  # regenerate spk2utt from the (primary) utt2spk
+
+echo "fix_data_dir.sh: old files are kept in $data/.backup"
diff --git a/egs/alimeeting/sa-asr/utils/parse_options.sh b/egs/alimeeting/sa-asr/utils/parse_options.sh
new file mode 100755
index 0000000..71fb9e5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/parse_options.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
+# Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+ if [ "${!argpos}" == "--config" ]; then
+ argpos_plus1=$((argpos+1))
+ config=${!argpos_plus1}
+ [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+ . $config # source the config file.
+ fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+  [ -z "${1:-}" ] && break;  # break if there are no arguments
+  case "$1" in
+    # If the enclosing script is called with --help option, print the help
+    # message and exit.  Scripts should put help messages in $help_message
+    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+               else printf "$help_message\n" 1>&2 ; fi;
+               exit 0 ;;
+    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+           exit 1 ;;
+    # If the first command-line argument begins with "--" (e.g. --foo-bar),
+    # then work out the variable name as $name, which will equal "foo_bar".
+    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+      # Next we test whether the variable in question is undefined-- if so it's
+      # an invalid option and we die.  Note: $0 evaluates to the name of the
+      # enclosing script.
+      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+      # is undefined.  We then have to wrap this test inside "eval" because
+      # foo_bar is itself inside a variable ($name).
+      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+      oldval="`eval echo \\$$name`";
+      # Work out whether we seem to be expecting a Boolean argument.
+      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+        was_bool=true;
+      else
+        was_bool=false;
+      fi
+
+      # Set the variable to the right value-- the escaped quotes make it work if
+      # the option had spaces, like --cmd "queue.pl -sync y"
+      eval $name=\"$2\";
+
+      # Check that Boolean-valued arguments are really Boolean.
+      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+        exit 1;
+      fi
+      shift 2;
+      ;;
+    *) break;
+  esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
diff --git a/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl
new file mode 100755
index 0000000..23992f2
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl
@@ -0,0 +1,34 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# converts a spk2utt file to an utt2spk file.
+# Takes input from stdin or from a file argument;
+# output goes to the standard out.
+
+if ( @ARGV > 1 ) {
+  die "Usage: spk2utt_to_utt2spk.pl [ spk2utt ] > utt2spk";
+}
+
+while(<>){
+  @A = split(" ", $_);
+  @A > 1 || die "Invalid line in spk2utt file: $_";
+  $s = shift @A;
+  foreach $u ( @A ) {
+    print "$u $s\n";
+  }
+}
+
+
diff --git a/egs/alimeeting/sa-asr/utils/split_scp.pl b/egs/alimeeting/sa-asr/utils/split_scp.pl
new file mode 100755
index 0000000..0876dcb
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/split_scp.pl
@@ -0,0 +1,246 @@
+#!/usr/bin/env perl
+
+# Copyright 2010-2011 Microsoft Corporation
+
+# See ../../COPYING for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This program splits up any kind of .scp or archive-type file.
+# If there is no utt2spk option it will work on any text file and
+# will split it up with an approximately equal number of lines in
+# each one.
+# With the --utt2spk option it will work on anything that has the
+# utterance-id as the first entry on each line; the utt2spk file is
+# of the form "utterance speaker" (on each line).
+# It splits it into equal size chunks as far as it can. If you use the utt2spk
+# option it will make sure these chunks coincide with speaker boundaries. In
+# this case, if there are more chunks than speakers (and in some other
+# circumstances), some of the resulting chunks will be empty and it will print
+# an error message and exit with nonzero status.
+# You will normally call this like:
+# split_scp.pl scp scp.1 scp.2 scp.3 ...
+# or
+# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
+# Note that you can use this script to split the utt2spk file itself,
+# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
+
+# You can also call the scripts like:
+# split_scp.pl -j 3 0 scp scp.0
+# [note: with this option, it assumes zero-based indexing of the split parts,
+# i.e. the second number must be 0 <= n < num-jobs.]
+
+use warnings;
+
+$num_jobs = 0;
+$job_id = 0;
+$utt2spk_file = "";
+$one_based = 0;
+
+for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {  # at most 3 leading options, accepted in any order
+  if ($ARGV[0] eq "-j") {
+    shift @ARGV;
+    $num_jobs = shift @ARGV;
+    $job_id = shift @ARGV;
+  }
+  if ($ARGV[0] =~ /--utt2spk=(.+)/) {
+    $utt2spk_file=$1;
+    shift;
+  }
+  if ($ARGV[0] eq '--one-based') {
+    $one_based = 1;
+    shift @ARGV;
+  }
+}
+
+if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
+                       $job_id - $one_based >= $num_jobs)) {
+  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
+      ($one_based ? " --one-based" : "") . "'\n"
+}
+
+$one_based
+  and $job_id--;  # convert to the zero-based job id used internally
+
+if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
+  die
+"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
+ or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
+ ... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";
+}
+
+$error = 0;
+$inscp = shift @ARGV;
+if ($num_jobs == 0) { # without -j option
+  @OUTPUTS = @ARGV;
+} else {
+  for ($j = 0; $j < $num_jobs; $j++) {
+    if ($j == $job_id) {  # only this job's shard goes to a real output
+      if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
+      else { push @OUTPUTS, "-"; }
+    } else {
+      push @OUTPUTS, "/dev/null";  # other jobs' shards are discarded
+    }
+  }
+}
+
+if ($utt2spk_file ne "") { # We have the --utt2spk option...
+ open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
+ while(<$u_fh>) {
+ @A = split;
+ @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
+ ($u,$s) = @A;
+ $utt2spk{$u} = $s;
+ }
+ close $u_fh;
+ open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+ @spkrs = ();
+ while(<$i_fh>) {
+ @A = split;
+ if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
+ $u = $A[0];
+ $s = $utt2spk{$u};
+ defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
+ if(!defined $spk_count{$s}) {
+ push @spkrs, $s;
+ $spk_count{$s} = 0;
+ $spk_data{$s} = []; # ref to new empty array.
+ }
+ $spk_count{$s}++;
+ push @{$spk_data{$s}}, $_;
+ }
+ # Now split as equally as possible ..
+ # First allocate spks to files by allocating an approximately
+ # equal number of speakers.
+ $numspks = @spkrs; # number of speakers.
+ $numscps = @OUTPUTS; # number of output files.
+ if ($numspks < $numscps) {
+ die "$0: Refusing to split data because number of speakers $numspks " .
+ "is less than the number of output .scp files $numscps\n";
+ }
+ for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+ $scparray[$scpidx] = []; # [] is array reference.
+ }
+ for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
+ $scpidx = int(($spkidx*$numscps) / $numspks);
+ $spk = $spkrs[$spkidx];
+ push @{$scparray[$scpidx]}, $spk;
+ $scpcount[$scpidx] += $spk_count{$spk};
+ }
+
+ # Now will try to reassign beginning + ending speakers
+ # to different scp's and see if it gets more balanced.
+ # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
+ # We can show that if considering changing just 2 scp's, we minimize
+ # this by minimizing the squared difference in sizes. This is
+ # equivalent to minimizing the absolute difference in sizes. This
+ # shows this method is bound to converge.
+
+ $changed = 1;
+ while($changed) {
+ $changed = 0;
+ for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+ # First try to reassign ending spk of this scp.
+ if($scpidx < $numscps-1) {
+ $sz = @{$scparray[$scpidx]};
+ if($sz > 0) {
+ $spk = $scparray[$scpidx]->[$sz-1];
+ $count = $spk_count{$spk};
+ $nutt1 = $scpcount[$scpidx];
+ $nutt2 = $scpcount[$scpidx+1];
+ if( abs( ($nutt2+$count) - ($nutt1-$count))
+ < abs($nutt2 - $nutt1)) { # Would decrease
+ # size-diff by reassigning spk...
+ $scpcount[$scpidx+1] += $count;
+ $scpcount[$scpidx] -= $count;
+ pop @{$scparray[$scpidx]};
+ unshift @{$scparray[$scpidx+1]}, $spk;
+ $changed = 1;
+ }
+ }
+ }
+ if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
+ $spk = $scparray[$scpidx]->[0];
+ $count = $spk_count{$spk};
+ $nutt1 = $scpcount[$scpidx-1];
+ $nutt2 = $scpcount[$scpidx];
+ if( abs( ($nutt2-$count) - ($nutt1+$count))
+ < abs($nutt2 - $nutt1)) { # Would decrease
+ # size-diff by reassigning spk...
+ $scpcount[$scpidx-1] += $count;
+ $scpcount[$scpidx] -= $count;
+ shift @{$scparray[$scpidx]};
+ push @{$scparray[$scpidx-1]}, $spk;
+ $changed = 1;
+ }
+ }
+ }
+ }
+  # Now print out the files...
+  for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+    $scpfile = $OUTPUTS[$scpidx];
+    ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
+                     : open($f_fh, '>&', \*STDOUT)) ||
+        die "$0: Could not open scp file $scpfile for writing: $!\n";
+    $count = 0;
+    if(@{$scparray[$scpidx]} == 0) {
+      print STDERR "$0: error: split_scp.pl producing empty .scp file " .
+                   "$scpfile (too many splits and too few speakers?)\n";
+      $error = 1;
+    } else {
+      foreach $spk ( @{$scparray[$scpidx]} ) {
+        print $f_fh @{$spk_data{$spk}};
+        $count += $spk_count{$spk};
+      }
+      $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
+    }
+    close($f_fh);
+  }
+} else {
+ # This block is the "normal" case where there is no --utt2spk
+ # option and we just break into equal size chunks.
+
+ open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+
+ $numscps = @OUTPUTS; # size of array.
+ @F = ();
+ while(<$i_fh>) {
+ push @F, $_;
+ }
+ $numlines = @F;
+ if($numlines == 0) {
+ print STDERR "$0: error: empty input scp file $inscp\n";
+ $error = 1;
+ }
+ $linesperscp = int( $numlines / $numscps); # the "whole part"..
+ $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
+ $remainder = $numlines - ($linesperscp * $numscps);
+ ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
+ # [just doing int() rounds down].
+ $n = 0;
+  for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
+    $scpfile = $OUTPUTS[$scpidx];
+    ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
+                     : open($o_fh, '>&', \*STDOUT)) ||
+        die "$0: Could not open scp file $scpfile for writing: $!\n";
+    for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {  # first $remainder chunks get one extra line
+      print $o_fh $F[$n++];
+    }
+    close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
+  }
+ $n == $numlines || die "$n != $numlines [code error]";
+}
+
+exit ($error);
diff --git a/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl
new file mode 100755
index 0000000..6e0e438
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl
@@ -0,0 +1,38 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# converts an utt2spk file to a spk2utt file.
+# Takes input from the stdin or from a file argument;
+# output goes to the standard out.
+
+if ( @ARGV > 1 ) {
+ die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
+}
+
+while(<>){
+  @A = split(" ", $_);
+  @A == 2 || die "Invalid line in utt2spk file: $_";
+  ($u,$s) = @A;
+  if(!$seen_spk{$s}) {
+    $seen_spk{$s} = 1;
+    push @spklist, $s;  # remember speakers in order of first appearance
+  }
+  push (@{$spk_hash{$s}}, "$u");  # collect this speaker's utterances
+}
+foreach $s (@spklist) {
+  $l = join(' ',@{$spk_hash{$s}});
+  print "$s $l\n";
+}
diff --git a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh
new file mode 100755
index 0000000..3eec443
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh
@@ -0,0 +1,404 @@
+#!/usr/bin/env bash
+
+cmd="$@"
+
+no_feats=false
+no_wav=false
+no_text=false
+no_spk_sort=false
+non_print=false
+
+
+function show_help
+{  # print usage and a summary of the accepted --no-* / --non-print flags
+  echo "Usage: $0 [--no-feats] [--no-text] [--non-print] [--no-wav] [--no-spk-sort] <data-dir>"
+  echo "The --no-xxx options mean that the script does not require "
+  echo "xxx.scp to be present, but it will check it if it is present."
+  echo "--no-spk-sort means that the script does not require the utt2spk to be "
+  echo "sorted by the speaker-id in addition to being sorted by utterance-id."
+  echo "--non-print ignore the presence of non-printable characters."
+  echo "By default, utt2spk is expected to be sorted by both, which can be "
+  echo "achieved by making the speaker-id prefixes of the utterance-ids"
+  echo "e.g.: $0 data/train"
+}
+
+while [ $# -ne 0 ] ; do
+ case "$1" in
+ "--no-feats")
+ no_feats=true;
+ ;;
+ "--no-text")
+ no_text=true;
+ ;;
+ "--non-print")
+ non_print=true;
+ ;;
+ "--no-wav")
+ no_wav=true;
+ ;;
+ "--no-spk-sort")
+ no_spk_sort=true;
+ ;;
+ *)
+ if ! [ -z "$data" ] ; then
+ show_help;
+ exit 1
+ fi
+ data=$1
+ ;;
+ esac
+ shift
+done
+
+
+
+if [ ! -d $data ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+if [ -f $data/images.scp ]; then
+ cmd=${cmd/--no-wav/} # remove --no-wav if supplied
+ image/validate_data_dir.sh $cmd
+ exit $?
+fi
+
+for f in spk2utt utt2spk; do
+ if [ ! -f $data/$f ]; then
+ echo "$0: no such file $f"
+ exit 1;
+ fi
+ if [ ! -s $data/$f ]; then
+ echo "$0: empty file $f"
+ exit 1;
+ fi
+done
+
+! cat $data/utt2spk | awk '{if (NF != 2) exit(1); }' && \
+ echo "$0: $data/utt2spk has wrong format." && exit;
+
+ns=$(wc -l < $data/spk2utt)
+if [ "$ns" == 1 ]; then
+ echo "$0: WARNING: you have only one speaker. This probably a bad idea."
+ echo " Search for the word 'bold' in http://kaldi-asr.org/doc/data_prep.html"
+ echo " for more information."
+fi
+
+
+tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
+trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
+
+export LC_ALL=C
+
+function check_sorted_and_uniq {
+  ! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' $1 && exit 1;  # every line must end in a newline
+  ! awk '{print $1}' < $1 | sort -uC && echo "$0: file $1 is not sorted or has duplicates" && exit 1;  # first fields must be sorted and unique
+}
+
+function partial_diff {
+  diff -U1 $1 $2 | (head -n 6; echo "..."; tail -n 6)  # show only head and tail of the diff to keep output short
+  n1=`cat $1 | wc -l`
+  n2=`cat $2 | wc -l`
+  echo "[Lengths are $1=$n1 versus $2=$n2]"
+}
+
+check_sorted_and_uniq $data/utt2spk
+
+if ! $no_spk_sort; then
+ ! sort -k2 -C $data/utt2spk && \
+ echo "$0: utt2spk is not in sorted order when sorted first on speaker-id " && \
+ echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
+fi
+
+check_sorted_and_uniq $data/spk2utt
+
+! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
+ <(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \
+ echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
+
+cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
+
+if [ ! -f $data/text ] && ! $no_text; then
+ echo "$0: no such file $data/text (if this is by design, specify --no-text)"
+ exit 1;
+fi
+
+num_utts=`cat $tmpdir/utts | wc -l`
+if ! $no_text; then
+ if ! $non_print; then
+ if locale -a | grep "C.UTF-8" >/dev/null; then
+ L=C.UTF-8
+ else
+ L=en_US.UTF-8
+ fi
+ n_non_print=$(LC_ALL="$L" grep -c '[^[:print:][:space:]]' $data/text) && \
+ echo "$0: text contains $n_non_print lines with non-printable characters" &&\
+ exit 1;
+ fi
+ utils/validate_text.pl $data/text || exit 1;
+ check_sorted_and_uniq $data/text
+ text_len=`cat $data/text | wc -l`
+ illegal_sym_list="<s> </s> #0"
+ for x in $illegal_sym_list; do
+ if grep -w "$x" $data/text > /dev/null; then
+ echo "$0: Error: in $data, text contains illegal symbol $x"
+ exit 1;
+ fi
+ done
+ awk '{print $1}' < $data/text > $tmpdir/utts.txt
+ if ! cmp -s $tmpdir/utts{,.txt}; then
+ echo "$0: Error: in $data, utterance lists extracted from utt2spk and text"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.txt}
+ exit 1;
+ fi
+fi
+
+if [ -f $data/segments ] && [ ! -f $data/wav.scp ]; then
+ echo "$0: in directory $data, segments file exists but no wav.scp"
+ exit 1;
+fi
+
+
+if [ ! -f $data/wav.scp ] && ! $no_wav; then
+ echo "$0: no such file $data/wav.scp (if this is by design, specify --no-wav)"
+ exit 1;
+fi
+
+if [ -f $data/wav.scp ]; then
+ check_sorted_and_uniq $data/wav.scp
+
+ if grep -E -q '^\S+\s+~' $data/wav.scp; then
+ # note: it's not a good idea to have any kind of tilde in wav.scp, even if
+ # part of a command, as it would cause compatibility problems if run by
+ # other users, but this used to be not checked for so we let it slide unless
+ # it's something of the form "foo ~/foo.wav" (i.e. a plain file name) which
+ # would definitely cause problems as the fopen system call does not do
+ # tilde expansion.
+ echo "$0: Please do not use tilde (~) in your wav.scp."
+ exit 1;
+ fi
+
+ if [ -f $data/segments ]; then
+
+ check_sorted_and_uniq $data/segments
+ # We have a segments file -> interpret wav file as "recording-ids" not utterance-ids.
+ ! cat $data/segments | \
+ awk '{if (NF != 4 || $4 <= $3) { print "Bad line in segments file", $0; exit(1); }}' && \
+ echo "$0: badly formatted segments file" && exit 1;
+
+ segments_len=`cat $data/segments | wc -l`
+ if [ -f $data/text ]; then
+ ! cmp -s $tmpdir/utts <(awk '{print $1}' <$data/segments) && \
+ echo "$0: Utterance list differs between $data/utt2spk and $data/segments " && \
+ echo "$0: Lengths are $segments_len vs $num_utts" && \
+ exit 1
+ fi
+
+ cat $data/segments | awk '{print $2}' | sort | uniq > $tmpdir/recordings
+ awk '{print $1}' $data/wav.scp > $tmpdir/recordings.wav
+ if ! cmp -s $tmpdir/recordings{,.wav}; then
+ echo "$0: Error: in $data, recording-ids extracted from segments and wav.scp"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/recordings{,.wav}
+ exit 1;
+ fi
+ if [ -f $data/reco2file_and_channel ]; then
+ # this file is needed only for ctm scoring; it's indexed by recording-id.
+ check_sorted_and_uniq $data/reco2file_and_channel
+ ! cat $data/reco2file_and_channel | \
+ awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
+ if ( NF == 3 && $3 == "1" ) {
+ warning_issued = 1;
+ } else {
+ print "Bad line ", $0; exit 1;
+ }
+ }
+ }
+ END {
+ if (warning_issued == 1) {
+ print "The channel should be marked as A or B, not 1! You should change it ASAP! "
+ }
+ }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
+ cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/recordings.r2fc
+ if ! cmp -s $tmpdir/recordings{,.r2fc}; then
+ echo "$0: Error: in $data, recording-ids extracted from segments and reco2file_and_channel"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/recordings{,.r2fc}
+ exit 1;
+ fi
+ fi
+ else
+ # No segments file -> assume wav.scp indexed by utterance.
+ cat $data/wav.scp | awk '{print $1}' > $tmpdir/utts.wav
+ if ! cmp -s $tmpdir/utts{,.wav}; then
+ echo "$0: Error: in $data, utterance lists extracted from utt2spk and wav.scp"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.wav}
+ exit 1;
+ fi
+
+ if [ -f $data/reco2file_and_channel ]; then
+ # this file is needed only for ctm scoring; it's indexed by recording-id.
+ check_sorted_and_uniq $data/reco2file_and_channel
+ ! cat $data/reco2file_and_channel | \
+ awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
+ if ( NF == 3 && $3 == "1" ) {
+ warning_issued = 1;
+ } else {
+ print "Bad line ", $0; exit 1;
+ }
+ }
+ }
+ END {
+ if (warning_issued == 1) {
+ print "The channel should be marked as A or B, not 1! You should change it ASAP! "
+ }
+ }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
+ cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/utts.r2fc
+ if ! cmp -s $tmpdir/utts{,.r2fc}; then
+ echo "$0: Error: in $data, utterance-ids extracted from segments and reco2file_and_channel"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.r2fc}
+ exit 1;
+ fi
+ fi
+ fi
+fi
+
+# ---------------------------------------------------------------------------
+# Feature / auxiliary-file checks.
+# $tmpdir/utts holds the sorted utterance-ids from utt2spk, and
+# $tmpdir/recordings (when a segments file exists) the recording-ids from
+# segments; both were produced earlier in this script.
+# ---------------------------------------------------------------------------
+
+# feats.scp is required unless the caller explicitly passed --no-feats.
+if [ ! -f $data/feats.scp ] && ! $no_feats; then
+  echo "$0: no such file $data/feats.scp (if this is by design, specify --no-feats)"
+  exit 1;
+fi
+
+# When present, feats.scp must cover exactly the utterances in utt2spk.
+if [ -f $data/feats.scp ]; then
+  check_sorted_and_uniq $data/feats.scp
+  cat $data/feats.scp | awk '{print $1}' > $tmpdir/utts.feats
+  if ! cmp -s $tmpdir/utts{,.feats}; then
+    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and features"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.feats}
+    exit 1;
+  fi
+fi
+
+
+# cmvn.scp is indexed per *speaker*: its keys must match spk2utt's.
+if [ -f $data/cmvn.scp ]; then
+  check_sorted_and_uniq $data/cmvn.scp
+  cat $data/cmvn.scp | awk '{print $1}' > $tmpdir/speakers.cmvn
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  if ! cmp -s $tmpdir/speakers{,.cmvn}; then
+    echo "$0: Error: in $data, speaker lists extracted from spk2utt and cmvn"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/speakers{,.cmvn}
+    exit 1;
+  fi
+fi
+
+# spk2gender lines must be "<speaker> m|f" and cover the same speakers.
+if [ -f $data/spk2gender ]; then
+  check_sorted_and_uniq $data/spk2gender
+  ! cat $data/spk2gender | awk '{if (!((NF == 2 && ($2 == "m" || $2 == "f")))) exit 1; }' && \
+    echo "$0: Mal-formed spk2gender file" && exit 1;
+  cat $data/spk2gender | awk '{print $1}' > $tmpdir/speakers.spk2gender
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  if ! cmp -s $tmpdir/speakers{,.spk2gender}; then
+    echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2gender"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/speakers{,.spk2gender}
+    exit 1;
+  fi
+fi
+
+# spk2warp holds per-speaker VTLN warp factors, expected in (0.5, 1.5).
+if [ -f $data/spk2warp ]; then
+  check_sorted_and_uniq $data/spk2warp
+  ! cat $data/spk2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
+    echo "$0: Mal-formed spk2warp file" && exit 1;
+  cat $data/spk2warp | awk '{print $1}' > $tmpdir/speakers.spk2warp
+  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+  if ! cmp -s $tmpdir/speakers{,.spk2warp}; then
+    echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2warp"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/speakers{,.spk2warp}
+    exit 1;
+  fi
+fi
+
+# utt2warp is the per-utterance variant of the same check.
+if [ -f $data/utt2warp ]; then
+  check_sorted_and_uniq $data/utt2warp
+  ! cat $data/utt2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
+    echo "$0: Mal-formed utt2warp file" && exit 1;
+  cat $data/utt2warp | awk '{print $1}' > $tmpdir/utts.utt2warp
+  cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
+  if ! cmp -s $tmpdir/utts{,.utt2warp}; then
+    echo "$0: Error: in $data, utterance lists extracted from utt2spk and utt2warp"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.utt2warp}
+    exit 1;
+  fi
+fi
+
+# check some optionally-required things
+# These files are optional, but when present their utterance-id lists must
+# be identical to utt2spk's.
+for f in vad.scp utt2lang utt2uniq; do
+  if [ -f $data/$f ]; then
+    check_sorted_and_uniq $data/$f
+    if ! cmp -s <( awk '{print $1}' $data/utt2spk ) \
+        <( awk '{print $1}' $data/$f ); then
+      echo "$0: error: in $data, $f and utt2spk do not have identical utterance-id list"
+      exit 1;
+    fi
+  fi
+done
+
+
+# utt2dur: "<utt> <duration>", duration strictly positive.
+if [ -f $data/utt2dur ]; then
+  check_sorted_and_uniq $data/utt2dur
+  cat $data/utt2dur | awk '{print $1}' > $tmpdir/utts.utt2dur
+  if ! cmp -s $tmpdir/utts{,.utt2dur}; then
+    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2dur file"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.utt2dur}
+    exit 1;
+  fi
+  cat $data/utt2dur | \
+    awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line utt2dur:" NR ":" $0; exit(1) }}' || exit 1
+fi
+
+# utt2num_frames: "<utt> <frames>", frames a strictly positive integer.
+if [ -f $data/utt2num_frames ]; then
+  check_sorted_and_uniq $data/utt2num_frames
+  cat $data/utt2num_frames | awk '{print $1}' > $tmpdir/utts.utt2num_frames
+  if ! cmp -s $tmpdir/utts{,.utt2num_frames}; then
+    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2num_frames file"
+    echo "$0: differ, partial diff is:"
+    partial_diff $tmpdir/utts{,.utt2num_frames}
+    exit 1
+  fi
+  awk <$data/utt2num_frames '{
+    if (NF != 2 || !($2 > 0) || $2 != int($2)) {
+      print "Bad line utt2num_frames:" NR ":" $0
+      exit 1 } }' || exit 1
+fi
+
+# reco2dur is keyed by recording-id when a segments file exists; otherwise
+# wav.scp is indexed by utterance, so the ids are compared against utt2spk's.
+if [ -f $data/reco2dur ]; then
+  check_sorted_and_uniq $data/reco2dur
+  cat $data/reco2dur | awk '{print $1}' > $tmpdir/recordings.reco2dur
+  if [ -f $tmpdir/recordings ]; then
+    if ! cmp -s $tmpdir/recordings{,.reco2dur}; then
+      echo "$0: Error: in $data, recording-ids extracted from segments and reco2dur file"
+      echo "$0: differ, partial diff is:"
+      partial_diff $tmpdir/recordings{,.reco2dur}
+      exit 1;
+    fi
+  else
+    if ! cmp -s $tmpdir/{utts,recordings.reco2dur}; then
+      echo "$0: Error: in $data, recording-ids extracted from wav.scp and reco2dur file"
+      echo "$0: differ, partial diff is:"
+      partial_diff $tmpdir/{utts,recordings.reco2dur}
+      exit 1;
+    fi
+  fi
+  cat $data/reco2dur | \
+    awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line : " $0; exit(1) }}' || exit 1
+fi
+
+
+echo "$0: Successfully validated data-directory $data"
diff --git a/egs/alimeeting/sa-asr/utils/validate_text.pl b/egs/alimeeting/sa-asr/utils/validate_text.pl
new file mode 100755
index 0000000..7f75cf1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils/validate_text.pl
@@ -0,0 +1,136 @@
+#!/usr/bin/env perl
+#
+#===============================================================================
+# Copyright 2017 Johns Hopkins University (author: Yenda Trmal <jtrmal@gmail.com>)
+# Johns Hopkins University (author: Daniel Povey)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+#===============================================================================
+
+# validation script for data/<dataset>/text
+# to be called (preferably) from utils/validate_data_dir.sh
+use strict;
+use warnings;
+use utf8;
+use Fcntl qw< SEEK_SET >;
+
+# this function reads the opened file (supplied as a first
+# parameter) into an array of lines. For each
+# line, it tests whether it's a valid utf-8 compatible
+# line. If all lines are valid utf-8, it returns the lines
+# decoded as utf-8, otherwise it assumes the file's encoding
+# is one of those 1-byte encodings, such as ISO-8859-x
+# or Windows CP-X.
+# Please recall we do not really care about
+# the actually encoding, we just need to
+# make sure the length of the (decoded) string
+# is correct (to make the output formatting looking right).
+# Read every line of the already-opened file handle (first argument) and try
+# to decode each one strictly as UTF-8.  Returns a flag plus the lines:
+#   (1, @decoded_lines)  when the whole file is valid UTF-8
+#   (0, @raw_lines)      otherwise; the caller then treats the bytes as some
+#                        1-byte encoding such as ISO-8859-x or Windows CP-X.
+sub get_utf8_or_bytestream {
+  use Encode qw(decode encode);
+  my $is_utf_compatible = 1;
+  my @unicode_lines;
+  my @raw_lines;
+  my $raw_text;
+  my $lineno = 0;
+  my $file = shift;
+
+  while (<$file>) {
+    $raw_text = $_;
+    last unless $raw_text;
+    if ($is_utf_compatible) {
+      # FB_CROAK makes decode() die on malformed input; the eval turns that
+      # into undef so we can just clear the flag instead of aborting.
+      my $decoded_text = eval { decode("UTF-8", $raw_text, Encode::FB_CROAK) } ;
+      $is_utf_compatible = $is_utf_compatible && defined($decoded_text);
+      # NOTE: on a failing line this pushes undef; harmless, because
+      # @unicode_lines is discarded whenever $is_utf_compatible ends up 0.
+      push @unicode_lines, $decoded_text;
+    } else {
+      # Once any line has failed to decode, stop decoding the rest.
+      ;
+    }
+    push @raw_lines, $raw_text;
+    $lineno += 1;
+  }
+
+  if (!$is_utf_compatible) {
+    return (0, @raw_lines);
+  } else {
+    return (1, @unicode_lines);
+  }
+}
+
+# check if the given unicode string contain unicode whitespaces
+# other than the usual four: TAB, LF, CR and SPACE
+# Check whether any of the given unicode lines (passed as an array ref)
+# contains a whitespace character other than the allowed TAB, LF, CR, SPACE.
+# Also insists every line ends with "\n", and rejects CR with a dedicated
+# message.  Returns 1 on the first offending line, 0 when everything is ok.
+sub validate_utf8_whitespaces {
+  my $unicode_lines = shift;
+  use feature 'unicode_strings';
+  for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) {
+    my $current_line = $unicode_lines->[$i];
+    if ((substr $current_line, -1) ne "\n"){
+      print STDERR "$0: The current line (nr. $i) has invalid newline\n";
+      return 1;
+    }
+    my @A = split(" ", $current_line);
+    my $utt_id = $A[0];
+    # CR is reported separately so the user gets a more specific hint
+    # (typically a DOS line-ending problem).
+    if ($current_line =~ /\x{000d}/) {
+      print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n";
+      return 1;
+    }
+    # Replace the allowed whitespaces (TAB, LF, SPACE) with '.'; anything
+    # that still matches \s afterwards is a disallowed Unicode whitespace.
+    $current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g;
+    if ($current_line =~/\s/) {
+      print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n";
+      return 1;
+    }
+  }
+  return 0;
+}
+
+# checks if the text in the file (supplied as the argument) is utf-8 compatible
+# if yes, checks if it contains only allowed whitespaces. If no, then does not
+# do anything. The function seeks to the original position in the file after
+# reading the text.
+# Check that the text readable from the (already-opened) file handle contains
+# only allowed whitespace characters, provided it decodes as UTF-8; files
+# that are not valid UTF-8 are accepted as-is.  The original file position is
+# restored before returning.  Returns 1 when ok, 0 when a disallowed
+# whitespace was found (the error is printed to STDERR).
+sub check_allowed_whitespace {
+  my $file = shift;
+  my $filename = shift;
+  my $pos = tell($file);
+  (my $is_utf, my @lines) = get_utf8_or_bytestream($file);
+  seek($file, $pos, SEEK_SET);
+  if ($is_utf) {
+    my $has_invalid_whitespaces = validate_utf8_whitespaces(\@lines);
+    if ($has_invalid_whitespaces) {
+      print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n";
+      return 0;
+    }
+  }
+  return 1;
+}
+
+# ---- main -----------------------------------------------------------------
+if(@ARGV != 1) {
+  die "Usage: validate_text.pl <text-file>\n" .
+      "e.g.: validate_text.pl data/train/text\n";
+}
+
+my $text = shift @ARGV;
+
+# -z is true for an empty *or* missing file, hence the combined message.
+if (-z "$text") {
+  print STDERR "$0: ERROR: file '$text' is empty or does not exist\n";
+  exit 1;
+}
+
+if(!open(FILE, "<$text")) {
+  print STDERR "$0: ERROR: failed to open $text\n";
+  exit 1;
+}
+
+# Exit non-zero on disallowed whitespace so validate_data_dir.sh fails fast.
+check_allowed_whitespace(\*FILE, $text) or exit 1;
+close(FILE);
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 4722602..c18472f 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -40,7 +40,6 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
@@ -91,8 +90,6 @@
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@@ -111,7 +108,7 @@
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm_train_config, lm_file, None, device
)
scorers["lm"] = lm.lm
@@ -141,6 +138,13 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
+
+ beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+ for scorer in scorers.values():
+ if isinstance(scorer, torch.nn.Module):
+ scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+ logging.info(f"Beam_search: {beam_search}")
+ logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
@@ -198,16 +202,7 @@
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
# a. To device
batch = to_device(batch, device=self.device)
@@ -355,6 +350,9 @@
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -408,6 +406,7 @@
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
+ mc=True,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
@@ -452,7 +451,7 @@
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
@@ -463,6 +462,9 @@
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}\n".format(text))
return asr_result_list
return _forward
@@ -637,4 +639,4 @@
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e10ebf4..e165531 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -288,6 +288,9 @@
if mode == "asr":
from funasr.bin.asr_inference import inference
return inference(**kwargs)
+ elif mode == "sa_asr":
+ from funasr.bin.sa_asr_inference import inference
+ return inference(**kwargs)
elif mode == "uniasr":
from funasr.bin.asr_inference_uniasr import inference
return inference(**kwargs)
@@ -342,4 +345,4 @@
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
index bba50da..c1e2cb2 100755
--- a/funasr/bin/asr_train.py
+++ b/funasr/bin/asr_train.py
@@ -2,6 +2,14 @@
import os
+import logging
+
+logging.basicConfig(
+ level='INFO',
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+)
+
from funasr.tasks.asr import ASRTask
@@ -27,7 +35,8 @@
args = parse_args()
# setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+ if args.ngpu > 0:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
@@ -38,9 +47,9 @@
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
- if args.batch_size is not None:
+ if args.batch_size is not None and args.ngpu > 0:
args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
+ if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
new file mode 100644
index 0000000..be63af1
--- /dev/null
+++ b/funasr/bin/sa_asr_inference.py
@@ -0,0 +1,674 @@
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
+from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.sa_asr import ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+
+class Speech2Text:
+    """Speaker-attributed Speech2Text.
+
+    Wraps a trained SA-ASR model, an optional language model and a joint
+    CTC/attention beam search; calling the instance decodes one utterance and
+    returns n-best hypotheses together with per-character speaker labels.
+
+    Examples:
+        >>> import soundfile
+        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(audio)
+        [(text, token, token_int, hypothesis object), ...]
+
+    """
+
+    def __init__(
+        self,
+        asr_train_config: Union[Path, str] = None,
+        asr_model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
+        lm_train_config: Union[Path, str] = None,
+        lm_file: Union[Path, str] = None,
+        token_type: str = None,
+        bpemodel: str = None,
+        device: str = "cpu",
+        maxlenratio: float = 0.0,
+        minlenratio: float = 0.0,
+        batch_size: int = 1,
+        dtype: str = "float32",
+        beam_size: int = 20,
+        ctc_weight: float = 0.5,
+        lm_weight: float = 1.0,
+        ngram_weight: float = 0.9,
+        penalty: float = 0.0,
+        nbest: int = 1,
+        streaming: bool = False,
+        frontend_conf: dict = None,
+        **kwargs,
+    ):
+        assert check_argument_types()
+
+        # 1. Build ASR model
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+        # Feature extraction is handled inside the model; no separate frontend.
+        frontend = None
+
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        decoder = asr_model.decoder
+
+        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        token_list = asr_model.token_list
+        scorers.update(
+            decoder=decoder,
+            ctc=ctc,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        # 2. Build Language model
+        if lm_train_config is not None:
+            lm, lm_train_args = LMTask.build_model_from_file(
+                lm_train_config, lm_file, None, device
+            )
+            scorers["lm"] = lm.lm
+
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+
+        # 4. Build BeamSearch object
+        # transducer is not supported now
+        beam_search_transducer = None
+
+        weights = dict(
+            decoder=1.0 - ctc_weight,
+            ctc=ctc_weight,
+            lm=lm_weight,
+            ngram=ngram_weight,
+            length_bonus=penalty,
+        )
+        beam_search = BeamSearch(
+            beam_size=beam_size,
+            weights=weights,
+            scorers=scorers,
+            sos=asr_model.sos,
+            eos=asr_model.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+        )
+
+        # Move search and all torch-module scorers to the decoding device.
+        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+        for scorer in scorers.values():
+            if isinstance(scorer, torch.nn.Module):
+                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+        logging.info(f"Beam_search: {beam_search}")
+        logging.info(f"Decoding device={device}, dtype={dtype}")
+
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        # Fall back to the training-time token type / bpemodel when not given.
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        converter = TokenIDConverter(token_list=token_list)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.converter = converter
+        self.tokenizer = tokenizer
+        self.beam_search = beam_search
+        self.beam_search_transducer = beam_search_transducer
+        self.maxlenratio = maxlenratio
+        self.minlenratio = minlenratio
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+        self.frontend = frontend
+
+    @torch.no_grad()
+    def __call__(
+        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+    ) -> List[
+        Tuple[
+            Optional[str],
+            Optional[str],
+            List[str],
+            List[int],
+            Union[Hypothesis],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+            speech_lengths: Valid length of each speech item
+            profile: Speaker-profile embeddings used by the SA beam search
+                (assumed (batch, n_speakers, emb_dim) -- TODO confirm)
+            profile_lengths: Valid length of each profile (unused here)
+        Returns:
+            text, text_id, token, token_int, hyp
+
+        """
+        assert check_argument_types()
+
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        if isinstance(profile, np.ndarray):
+            profile = torch.tensor(profile)
+
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+
+        # a. To device
+        # NOTE(review): only the speech batch is moved here; profile stays on
+        # its original device -- presumably the beam search handles placement,
+        # confirm when decoding on GPU.
+        batch = to_device(batch, device=self.device)
+
+        # b. Forward Encoder
+        # encode() returns both the ASR encoding and a speaker encoding.
+        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
+        if isinstance(asr_enc, tuple):
+            asr_enc = asr_enc[0]
+        if isinstance(spk_enc, tuple):
+            spk_enc = spk_enc[0]
+        # Only batch_size == 1 is supported.
+        assert len(asr_enc) == 1, len(asr_enc)
+        assert len(spk_enc) == 1, len(spk_enc)
+
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        for hyp in nbest_hyps:
+            assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1: last_pos]
+            else:
+                token_int = hyp.yseq[1: last_pos].tolist()
+
+            # Per-token speaker attention weights ("spk_weigths" is the field
+            # name -- typo included -- of the Hypothesis from
+            # beam_search_sa_asr); one weight row per decoded token.
+            spk_weigths=torch.stack(hyp.spk_weigths, dim=0)
+
+            token_ori = self.converter.ids2tokens(token_int)
+            text_ori = self.tokenizer.tokens2text(token_ori)
+
+            # '$' is used as the speaker-change delimiter in the token list
+            # (presumably -- see the token list of the SA-ASR model).  For
+            # each segment, average its token-level speaker weights and pick
+            # the argmax profile as that segment's speaker (1-based).
+            text_ori_spklist = text_ori.split('$')
+            cur_index = 0
+            spk_choose = []
+            for i in range(len(text_ori_spklist)):
+                text_ori_split = text_ori_spklist[i]
+                n = len(text_ori_split)
+                spk_weights_local = spk_weigths[cur_index: cur_index + n]
+                cur_index = cur_index + n + 1  # +1 skips the '$' token itself
+                spk_weights_local = spk_weights_local.mean(dim=0)
+                spk_choose_local = spk_weights_local.argmax(-1)
+                spk_choose.append(spk_choose_local.item() + 1)
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+            # Change integer-ids to tokens
+            token = self.converter.ids2tokens(token_int)
+
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+
+            text_spklist = text.split('$')
+            assert len(spk_choose) == len(text_spklist)
+
+            # Build a parallel string of speaker ids, one digit per character.
+            # NOTE(review): str(spk) * n assumes single-digit speaker indices
+            # (< 10), otherwise the len(text) == len(text_id) assert below
+            # would fire -- confirm against the max profile count.
+            spk_list=[]
+            for i in range(len(text_spklist)):
+                text_split = text_spklist[i]
+                n = len(text_split)
+                spk_list.append(str(spk_choose[i]) * n)
+
+            text_id = '$'.join(spk_list)
+
+            assert len(text) == len(text_id)
+
+            results.append((text, text_id, token, token_int, hyp))
+
+        assert check_return_type(results)
+        return results
+
+def inference(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ **kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ penalty=penalty,
+ log_level=log_level,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ raw_inputs=raw_inputs,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ key_file=key_file,
+ word_lm_train_config=word_lm_train_config,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ streaming=streaming,
+ output_dir=output_dir,
+ dtype=dtype,
+ seed=seed,
+ ngram_weight=ngram_weight,
+ nbest=nbest,
+ num_workers=num_workers,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+def inference_modelscope(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=True,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+ ibest_writer["text_id"][key] = text_id
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}".format(text))
+ logging.info("text_id predictions: {}\n".format(text_id))
+ return asr_result_list
+
+ return _forward
+
+def get_parser():
+    """Build the command-line argument parser for SA-ASR decoding.
+
+    Uses config_argparse, so every option here may also be supplied through
+    a yaml config file.  Returns the configured ArgumentParser.
+    """
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=False,
+        action="append",
+    )
+    # raw_inputs allows passing in-memory audio instead of scp files,
+    # e.g. [{'key': 'utt1', 'file': '/path/to/utt1.wav'}].
+    group.add_argument("--raw_inputs", type=list, default=None)
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If specify this option, *_train_config and "
+        "*_file will be overwritten",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain max output length. "
+        "If maxlenratio=0.0 (default), it uses a end-detect "
+        "function "
+        "to automatically find maximum hypothesis lengths."
+        "If maxlenratio<0.0, its absolute value is interpreted"
+        "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.5,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
+        help="The token type for ASR model. "
+        "If not given, refers from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model path of sentencepiece. "
+        "If not given, refers from the training args",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+ inference(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
new file mode 100755
index 0000000..c7c7c42
--- /dev/null
+++ b/funasr/bin/sa_asr_train.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+
+import os
+
+import logging
+
+# Prefix every log record with the short hostname so logs from multi-node
+# runs can be told apart.
+logging.basicConfig(
+    level='INFO',
+    format=f"[{os.uname()[1].split('.')[0]}]"
+    f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+)
+
+from funasr.tasks.sa_asr import ASRTask
+
+
+# for ASR Training
+def parse_args():
+    """Build the SA-ASR training argument parser and parse ``sys.argv``.
+
+    Extends the task-level parser with ``--gpu_id``, which the launcher below
+    uses to pin this process to a single visible CUDA device.
+    """
+    parser = ASRTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    """Launch SA-ASR training through the task's standard entry point."""
+    # for ASR Training
+    ASRTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    # Mask all but the requested GPU before any CUDA context is created;
+    # ngpu == 0 means CPU-only, so no masking is needed.
+    if args.ngpu > 0:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+        # NOTE(review): the non-distributed path assumes exactly one worker
+        # process; multiple worker counts are only valid with DDP.
+        assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    # Scale the global batch by the GPU count so each device keeps the
+    # configured per-GPU load.
+    if args.dataset_type == "small":
+        if args.batch_size is not None and args.ngpu > 0:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None and args.ngpu > 0:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index dc872b0..d757f7f 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -46,13 +46,15 @@
if self.normalize:
# soundfile.read normalizes data to [-1,1] if dtype is not given
array, rate = librosa.load(
- wav, sr=self.dest_sample_rate, mono=not self.always_2d
+ wav, sr=self.dest_sample_rate, mono=self.always_2d
)
else:
array, rate = librosa.load(
- wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype
+ wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
)
+ if array.ndim==2:
+ array=array.transpose((1, 0))
return rate, array
def get_path(self, key):
diff --git a/funasr/losses/nll_loss.py b/funasr/losses/nll_loss.py
new file mode 100644
index 0000000..7e4e294
--- /dev/null
+++ b/funasr/losses/nll_loss.py
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+
+class NllLoss(nn.Module):
+ """Nll loss.
+
+ :param int size: the number of class
+ :param int padding_idx: ignored class id
+ :param bool normalize_length: normalize loss by sequence length if True
+ :param torch.nn.Module criterion: loss function
+ """
+
+ def __init__(
+ self,
+ size,
+ padding_idx,
+ normalize_length=False,
+ criterion=nn.NLLLoss(reduction='none'),
+ ):
+ """Construct an LabelSmoothingLoss object."""
+ super(NllLoss, self).__init__()
+ self.criterion = criterion
+ self.padding_idx = padding_idx
+ self.size = size
+ self.true_dist = None
+ self.normalize_length = normalize_length
+
+ def forward(self, x, target):
+ """Compute loss between x and target.
+
+ :param torch.Tensor x: prediction (batch, seqlen, class)
+ :param torch.Tensor target:
+ target signal masked with self.padding_id (batch, seqlen)
+ :return: scalar float value
+ :rtype torch.Tensor
+ """
+ assert x.size(2) == self.size
+ batch_size = x.size(0)
+ x = x.view(-1, self.size)
+ target = target.view(-1)
+ with torch.no_grad():
+ ignore = target == self.padding_idx # (B,)
+ total = len(target) - ignore.sum().item()
+ target = target.masked_fill(ignore, 0) # avoid -1 index
+ kl = self.criterion(x , target)
+ denom = total if self.normalize_length else batch_size
+ return kl.masked_fill(ignore, 0).sum() / denom
diff --git a/funasr/models/decoder/decoder_layer_sa_asr.py b/funasr/models/decoder/decoder_layer_sa_asr.py
new file mode 100644
index 0000000..80afc51
--- /dev/null
+++ b/funasr/models/decoder/decoder_layer_sa_asr.py
@@ -0,0 +1,169 @@
+import torch
+from torch import nn
+
+from funasr.modules.layer_norm import LayerNorm
+
+
+class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
+    """First layer of the SA-ASR speaker decoder.
+
+    Like a standard decoder layer, but its cross-attention uses the ASR
+    encoder memory as keys and the speaker encoder memory as values, and it
+    additionally returns the post-self-attention hidden state ``z`` so the
+    ASR decoder branch can consume it.
+
+    Args:
+        size: attention (model) dimension.
+        self_attn: self-attention module over the target sequence.
+        src_attn: cross-attention module (query=x, key=asr_memory,
+            value=spk_memory).
+        feed_forward: position-wise feed-forward module.
+        dropout_rate: dropout probability.
+        normalize_before: pre-norm if True.
+        concat_after: concat attention input/output and project, instead of
+            the residual+dropout path.
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
+        """Run self-attn, (asr,spk) cross-attn and FFN over ``tgt``.
+
+        Returns:
+            (x, tgt_mask, asr_memory, spk_memory, memory_mask, z) where
+            ``z`` is the hidden state right after the self-attention
+            sub-block (consumed by the ASR decoder's first layer).
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == (
+                tgt.shape[0],
+                tgt.shape[1] - 1,
+                self.size,
+            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = torch.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+            )
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+        # z branches off to the ASR decoder before the cross-attention.
+        z = x
+
+        residual = x
+        if self.normalize_before:
+            # NOTE(review): norm1 is reused here for the cross-attention
+            # pre-norm (norm2 is reserved for the FFN) — confirm this weight
+            # sharing with the self-attention pre-norm is intended.
+            x = self.norm1(x)
+
+        # Keys come from the ASR memory, values from the speaker memory.
+        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
+
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, skip), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(skip)
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if cache is not None:
+            # Incremental decoding: re-append the cached prefix.
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
+
+class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
+
+ def __init__(
+ self,
+ size,
+ d_size,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an DecoderLayer object."""
+ super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
+ self.size = size
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ self.spk_linear = nn.Linear(d_size, size, bias=False)
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
+
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = None
+ if tgt_mask is not None:
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ x = tgt_q
+ if self.normalize_before:
+ x = self.norm2(x)
+ if self.concat_after:
+ x_concat = torch.cat(
+ (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+ )
+ x = residual + self.concat_linear2(x_concat)
+ else:
+ x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+ if not self.normalize_before:
+ x = self.norm2(x)
+ residual = x
+
+ if dn!=None:
+ x = x + self.spk_linear(dn)
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
+
+
+
diff --git a/funasr/models/decoder/transformer_decoder_sa_asr.py b/funasr/models/decoder/transformer_decoder_sa_asr.py
new file mode 100644
index 0000000..949f9c8
--- /dev/null
+++ b/funasr/models/decoder/transformer_decoder_sa_asr.py
@@ -0,0 +1,291 @@
+from typing import Any
+from typing import List
+from typing import Sequence
+from typing import Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.attention import CosineDistanceAttention
+from funasr.models.decoder.transformer_decoder import DecoderLayer
+from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
+from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
+from funasr.modules.dynamic_conv import DynamicConvolution
+from funasr.modules.dynamic_conv2d import DynamicConvolution2D
+from funasr.modules.embedding import PositionalEncoding
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.lightconv import LightweightConvolution
+from funasr.modules.lightconv2d import LightweightConvolution2D
+from funasr.modules.mask import subsequent_mask
+from funasr.modules.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.models.decoder.abs_decoder import AbsDecoder
+
+class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
+    """Base class for the joint ASR + speaker transformer decoder (SA-ASR).
+
+    The concrete subclass wires four sub-decoders:
+      * decoder1 / decoder2: speaker branch (first layer + remaining layers)
+      * decoder3 / decoder4: ASR branch (first layer + remaining layers)
+    The speaker branch predicts a per-token speaker embedding which is
+    matched against the speaker ``profile`` via cosine-distance attention;
+    the attended profile vector ``dn`` then conditions the ASR branch.
+
+    Args:
+        vocab_size: output vocabulary size.
+        encoder_output_size: encoder hidden dim (= decoder attention dim).
+        spker_embedding_dim: dimension of the speaker profile embeddings.
+        dropout_rate: dropout probability.
+        positional_dropout_rate: dropout after position encoding.
+        input_layer: "embed" (token embedding) or "linear" projection.
+        use_asr_output_layer: attach the final token-logit linear.
+        use_spk_output_layer: attach the final speaker-embedding linear.
+        pos_enc_class: positional encoding class.
+        normalize_before: pre-norm if True (adds a final LayerNorm).
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_asr_output_layer:
+            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.asr_output_layer = None
+
+        if use_spk_output_layer:
+            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
+        else:
+            self.spk_output_layer = None
+
+        # Matches predicted speaker embeddings against the profile table.
+        self.cos_distance_att = CosineDistanceAttention()
+
+        # Filled in by the concrete subclass.
+        self.decoder1 = None
+        self.decoder2 = None
+        self.decoder3 = None
+        self.decoder4 = None
+
+    def forward(
+        self,
+        asr_hs_pad: torch.Tensor,
+        spk_hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward both branches for training.
+
+        Returns:
+            x: token logits (B, L, vocab) if the ASR output layer exists.
+            weights: speaker attention weights over the profile entries.
+            olens: output lengths derived from the target mask.
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        asr_memory = asr_hs_pad
+        spk_memory = spk_hs_pad
+        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
+        # Spk decoder
+        x = self.embed(tgt)
+
+        # decoder1 also yields z, the post-self-attention state reused as the
+        # input of the ASR branch below.
+        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, memory_mask
+        )
+        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
+            x, tgt_mask, spk_memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        # dn: attended profile embedding per token; weights: attention over
+        # the profile entries (used as speaker posteriors).
+        dn, weights = self.cos_distance_att(x, profile, profile_lens)
+        # Asr decoder
+        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
+            z, tgt_mask, asr_memory, memory_mask, dn
+        )
+        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
+            x, tgt_mask, asr_memory, memory_mask
+        )
+
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.asr_output_layer is not None:
+            x = self.asr_output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, weights, olens
+
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        asr_memory: torch.Tensor,
+        spk_memory: torch.Tensor,
+        profile: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """One incremental decoding step.
+
+        ``cache`` layout: [decoder1, each decoder2 layer, decoder3, each
+        decoder4 layer] — one cached hidden state per layer.
+
+        Returns:
+            y: log-probabilities for the newest position.
+            weights: speaker attention weights.
+            new_cache: updated per-layer caches.
+        """
+        x = self.embed(tgt)
+
+        if cache is None:
+            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
+        new_cache = []
+        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
+        )
+        new_cache.append(x)
+        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
+            x, tgt_mask, spk_memory, _ = decoder(
+                x, tgt_mask, spk_memory, None, cache=c
+            )
+            new_cache.append(x)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        else:
+            x = x
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        dn, weights = self.cos_distance_att(x, profile, None)
+
+        x, tgt_mask, asr_memory, _ = self.decoder3(
+            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
+        )
+        new_cache.append(x)
+
+        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
+            x, tgt_mask, asr_memory, _ = decoder(
+                x, tgt_mask, asr_memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.asr_output_layer is not None:
+            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
+
+        return y, weights, new_cache
+
+    def score(self, ys, state, asr_enc, spk_enc, profile):
+        """Score."""
+        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
+        logp, weights, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), weights.squeeze(), state
+
+class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
+    """Concrete SA-ASR decoder: builds the four sub-decoder stacks.
+
+    decoder1 (speaker, first layer) and decoder3 (ASR, first layer) are the
+    special SA-ASR layers; decoder2 / decoder4 are standard DecoderLayer
+    stacks of ``spk_num_blocks - 1`` and ``asr_num_blocks - 1`` layers.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        asr_num_blocks: int = 6,
+        spk_num_blocks: int = 3,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+    ):
+        assert check_argument_types()
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            spker_embedding_dim=spker_embedding_dim,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_asr_output_layer=use_asr_output_layer,
+            use_spk_output_layer=use_spk_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+
+        # Speaker branch, first layer (cross-attends asr->spk memories).
+        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
+            attention_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, self_attention_dropout_rate
+            ),
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        # Remaining speaker-branch layers.
+        self.decoder2 = repeat(
+            spk_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        # ASR branch, first layer (no self-attn; speaker-conditioned).
+        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
+            attention_dim,
+            spker_embedding_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        # Remaining ASR-branch layers.
+        self.decoder4 = repeat(
+            asr_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
new file mode 100644
index 0000000..0d4097e
--- /dev/null
+++ b/funasr/models/e2e_sa_asr.py
@@ -0,0 +1,521 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
+from funasr.losses.nll_loss import NllLoss
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.e2e_asr_common import ErrorCalculator
+from funasr.modules.nets_utils import th_accuracy
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        # No-op stand-in so `with autocast(False):` works on old torch.
+        yield
+
+
+class ESPnetASRModel(AbsESPnetModel):
+ """CTC-attention hybrid Encoder-Decoder model"""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ max_spk_num: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ asr_encoder: AbsEncoder,
+ spk_encoder: torch.nn.Module,
+ postencoder: Optional[AbsPostEncoder],
+ decoder: AbsDecoder,
+ ctc: CTC,
+ spk_weight: float = 0.5,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ ):
+ assert check_argument_types()
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+ assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+ super().__init__()
+ # note that eos is the same as sos (equivalent ID)
+ self.blank_id = 0
+ self.sos = 1
+ self.eos = 2
+ self.vocab_size = vocab_size
+ self.max_spk_num=max_spk_num
+ self.ignore_id = ignore_id
+ self.spk_weight = spk_weight
+ self.ctc_weight = ctc_weight
+ self.interctc_weight = interctc_weight
+ self.token_list = token_list.copy()
+
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.preencoder = preencoder
+ self.postencoder = postencoder
+ self.asr_encoder = asr_encoder
+ self.spk_encoder = spk_encoder
+
+ if not hasattr(self.asr_encoder, "interctc_use_conditioning"):
+ self.asr_encoder.interctc_use_conditioning = False
+ if self.asr_encoder.interctc_use_conditioning:
+ self.asr_encoder.conditioning_layer = torch.nn.Linear(
+ vocab_size, self.asr_encoder.output_size()
+ )
+
+ self.error_calculator = None
+
+
+ # we set self.decoder = None in the CTC mode since
+ # self.decoder parameters were never used and PyTorch complained
+ # and threw an Exception in the multi-GPU experiment.
+ # thanks Jeff Farris for pointing out the issue.
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+
+ self.criterion_spk = NllLoss(
+ size=max_spk_num,
+ padding_idx=ignore_id,
+ normalize_length=length_normalized_loss,
+ )
+
+ if report_cer or report_wer:
+ self.error_calculator = ErrorCalculator(
+ token_list, sym_space, sym_blank, report_cer, report_wer
+ )
+
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+
+ self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ profile: torch.Tensor,
+ profile_lengths: torch.Tensor,
+ text_id: torch.Tensor,
+ text_id_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ profile: (Batch, Length, Dim)
+ profile_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+
+ # 1. Encoder
+ asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(asr_encoder_out, tuple):
+ intermediate_outs = asr_encoder_out[1]
+ asr_encoder_out = asr_encoder_out[0]
+
+ loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ asr_encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+ # Collect Intermedaite CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+
+ # 2b. Attention decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
+ asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss_asr = loss_att
+ elif self.ctc_weight == 1.0:
+ loss_asr = loss_ctc
+ else:
+ loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+ if self.spk_weight == 0.0:
+ loss = loss_asr
+ else:
+ loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr
+
+
+ stats = dict(
+ loss=loss.detach(),
+ loss_asr=loss_asr.detach(),
+ loss_att=loss_att.detach() if loss_att is not None else None,
+ loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
+ loss_spk=loss_spk.detach() if loss_spk is not None else None,
+ acc=acc_att,
+ acc_spk=acc_spk,
+ cer=cer_att,
+ wer=wer_att,
+ cer_ctc=cer_ctc,
+ )
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Dict[str, torch.Tensor]:
+ if self.extract_feats_in_collect_stats:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ else:
+ # Generate dummy stats if extract_feats_in_collect_stats is False
+ logging.warning(
+ "Generating dummy stats for feats and feats_lengths, "
+ "because encoder_conf.extract_feats_in_collect_stats is "
+ f"{self.extract_feats_in_collect_stats}"
+ )
+ feats, feats_lengths = speech, speech_lengths
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ feats_raw = feats.clone()
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.asr_encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.asr_encoder(
+ feats, feats_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
+ # import ipdb;ipdb.set_trace()
+ if encoder_out_spk_ori.size(1)!=encoder_out.size(1):
+ encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
+ else:
+ encoder_out_spk=encoder_out_spk_ori
+ # Post-encoder, e.g. NLU
+ if self.postencoder is not None:
+ encoder_out, encoder_out_lens = self.postencoder(
+ encoder_out, encoder_out_lens
+ )
+
+ assert encoder_out.size(0) == speech.size(0), (
+ encoder_out.size(),
+ speech.size(0),
+ )
+ assert encoder_out.size(1) <= encoder_out_lens.max(), (
+ encoder_out.size(),
+ encoder_out_lens.max(),
+ )
+ assert encoder_out_spk.size(0) == speech.size(0), (
+ encoder_out_spk.size(),
+ speech.size(0),
+ )
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens, encoder_out_spk
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+
+ if self.frontend is not None:
+ # Frontend
+ # e.g. STFT and Feature extract
+ # data_loader may send time-domain signal in this case
+ # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ # No frontend and no feature extract
+ feats, feats_lengths = speech, speech_lengths
+ return feats, feats_lengths
+
+ def nll(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute negative log likelihood(nll) from transformer-decoder
+
+ Normally, this function is called in batchify_nll.
+
+ Args:
+ encoder_out: (Batch, Length, Dim)
+ encoder_out_lens: (Batch,)
+ ys_pad: (Batch, Length)
+ ys_pad_lens: (Batch,)
+ """
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, _ = self.decoder(
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+ ) # [batch, seqlen, dim]
+ batch_size = decoder_out.size(0)
+ decoder_num_class = decoder_out.size(2)
+ # nll: negative log-likelihood
+ nll = torch.nn.functional.cross_entropy(
+ decoder_out.view(-1, decoder_num_class),
+ ys_out_pad.view(-1),
+ ignore_index=self.ignore_id,
+ reduction="none",
+ )
+ nll = nll.view(batch_size, -1)
+ nll = nll.sum(dim=1)
+ assert nll.size(0) == batch_size
+ return nll
+
+ def batchify_nll(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ batch_size: int = 100,
+ ):
+ """Compute negative log likelihood(nll) from transformer-decoder
+
+ To avoid OOM, this fuction seperate the input into batches.
+ Then call nll for each batch and combine and return results.
+ Args:
+ encoder_out: (Batch, Length, Dim)
+ encoder_out_lens: (Batch,)
+ ys_pad: (Batch, Length)
+ ys_pad_lens: (Batch,)
+ batch_size: int, samples each batch contain when computing nll,
+ you may change this to avoid OOM or increase
+ GPU memory usage
+ """
+ total_num = encoder_out.size(0)
+ if total_num <= batch_size:
+ nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+ else:
+ nll = []
+ start_idx = 0
+ while True:
+ end_idx = min(start_idx + batch_size, total_num)
+ batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
+ batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
+ batch_ys_pad = ys_pad[start_idx:end_idx, :]
+ batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
+ batch_nll = self.nll(
+ batch_encoder_out,
+ batch_encoder_out_lens,
+ batch_ys_pad,
+ batch_ys_pad_lens,
+ )
+ nll.append(batch_nll)
+ start_idx = end_idx
+ if start_idx == total_num:
+ break
+ nll = torch.cat(nll)
+ assert nll.size(0) == total_num
+ return nll
+
+ def _calc_att_loss(
+ self,
+ asr_encoder_out: torch.Tensor,
+ spk_encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ profile: torch.Tensor,
+ profile_lens: torch.Tensor,
+ text_id: torch.Tensor,
+ text_id_lengths: torch.Tensor
+ ):
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, weights_no_pad, _ = self.decoder(
+ asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
+ )
+
+ spk_num_no_pad=weights_no_pad.size(-1)
+ pad=(0,self.max_spk_num-spk_num_no_pad)
+ weights=F.pad(weights_no_pad, pad, mode='constant', value=0)
+
+ # pre_id=weights.argmax(-1)
+ # pre_text=decoder_out.argmax(-1)
+ # id_mask=(pre_id==text_id).to(dtype=text_id.dtype)
+ # pre_text_mask=pre_text*id_mask+1-id_mask #鐩稿悓鐨勫湴鏂逛笉鍙橈紝涓嶅悓鐨勫湴鏂硅涓�1(<unk>)
+ # padding_mask= ys_out_pad != self.ignore_id
+ # numerator = torch.sum(pre_text_mask.masked_select(padding_mask) == ys_out_pad.masked_select(padding_mask))
+ # denominator = torch.sum(padding_mask)
+ # sd_acc = float(numerator) / float(denominator)
+
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
+ loss_spk = self.criterion_spk(torch.log(weights), text_id)
+
+ acc_spk= th_accuracy(
+ weights.view(-1, self.max_spk_num),
+ text_id,
+ ignore_label=self.ignore_id,
+ )
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_out_pad,
+ ignore_label=self.ignore_id,
+ )
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 9671fe9..2e1b0c4 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -38,6 +38,7 @@
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
+ use_channel: int = None,
):
assert check_argument_types()
super().__init__()
@@ -77,6 +78,7 @@
)
self.n_mels = n_mels
self.frontend_type = "default"
+ self.use_channel = use_channel
def output_size(self) -> int:
return self.n_mels
@@ -100,9 +102,12 @@
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
- # Select 1ch randomly
- ch = np.random.randint(input_stft.size(2))
- input_stft = input_stft[:, :, ch, :]
+ if self.use_channel == None:
+ input_stft = input_stft[:, :, 0, :]
+ else:
+ # Select 1ch randomly
+ ch = np.random.randint(input_stft.size(2))
+ input_stft = input_stft[:, :, ch, :]
else:
# Use the first channel
input_stft = input_stft[:, :, 0, :]
diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py
index 8f85de9..39d94be 100644
--- a/funasr/models/pooling/statistic_pooling.py
+++ b/funasr/models/pooling/statistic_pooling.py
@@ -83,9 +83,9 @@
num_chunk = int(math.ceil(tt / pooling_stride))
pad = pooling_size // 2
if len(xs_pad.shape) == 4:
- features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
+ features = F.pad(xs_pad, (0, 0, pad, pad), "replicate")
else:
- features = F.pad(xs_pad, (pad, pad), "reflect")
+ features = F.pad(xs_pad, (pad, pad), "replicate")
stat_list = []
for i in range(num_chunk):
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 6202079..fcb3ed4 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -13,6 +13,9 @@
from torch import nn
from typing import Optional, Tuple
+import torch.nn.functional as F
+from funasr.modules.nets_utils import make_pad_mask
+
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -959,3 +962,37 @@
q, k, v = self.forward_qkv(query, key, value)
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
+
+class CosineDistanceAttention(nn.Module):
+    """Attend over speaker profiles by cosine similarity.
+
+    Computes the cosine similarity between the speaker-decoder output and
+    each entry of a speaker-profile matrix, softmax-normalizes it into
+    attention weights, and returns the weighted sum of the profiles.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # Normalize similarities over the profile axis (last dim).
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, spk_decoder_out, profile, profile_lens=None):
+        """
+        Args:
+            spk_decoder_out (torch.Tensor): (B, L, D) speaker decoder output.
+            profile (torch.Tensor): (B, N, D) speaker profile embeddings.
+            profile_lens (torch.Tensor, optional): valid number of profiles
+                per batch element. When None, only the last decoder step is
+                scored (streaming/inference path).
+
+        Returns:
+            Tuple of (spk_embedding, weights): the profile-weighted speaker
+            embedding and the attention weights over profiles.
+        """
+        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
+        if profile_lens is not None:
+
+            # Mask out padded profile slots before the softmax.
+            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+            # NOTE(review): relies on a module-level `import numpy` in this
+            # file (not visible in this hunk) — verify it exists.
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+            )
+            # Fill masked slots with -inf-like values so softmax assigns ~0,
+            # then zero them explicitly to be exact.
+            weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
+        else:
+            # Inference: score only the most recent decoder position.
+            x = x[:, -1:, :, :]
+            weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
+        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
+
+        return spk_embedding, weights
diff --git a/funasr/modules/beam_search/beam_search_sa_asr.py b/funasr/modules/beam_search/beam_search_sa_asr.py
new file mode 100755
index 0000000..b2b6833
--- /dev/null
+++ b/funasr/modules/beam_search/beam_search_sa_asr.py
@@ -0,0 +1,525 @@
+"""Beam search module."""
+
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.modules.e2e_asr_common import end_detect
+from funasr.modules.scorers.scorer_interface import PartialScorerInterface
+from funasr.modules.scorers.scorer_interface import ScorerInterface
+from funasr.models.decoder.abs_decoder import AbsDecoder
+
+
+class Hypothesis(NamedTuple):
+    """Hypothesis data type for SA-ASR beam search.
+
+    Extends the standard ESPnet hypothesis with per-step speaker weights.
+    """
+
+    # Token id sequence of the hypothesis (starts with <sos>).
+    yseq: torch.Tensor
+    # Per-decoded-token speaker weights accumulated during search.
+    # NOTE(review): "spk_weigths" is a typo for "spk_weights", but it is
+    # part of the field interface used throughout this module — keep as-is.
+    spk_weigths : List
+    # Total (weighted) log-probability of the hypothesis.
+    score: Union[float, torch.Tensor] = 0
+    # Per-scorer accumulated scores, keyed by scorer name.
+    scores: Dict[str, Union[float, torch.Tensor]] = dict()
+    # Per-scorer decoding states, keyed by scorer name.
+    states: Dict[str, Any] = dict()
+
+    def asdict(self) -> dict:
+        """Convert data to JSON-friendly dict.
+
+        NOTE(review): only yseq/score/scores are converted; ``spk_weigths``
+        is left as-is and may not be JSON-serializable — confirm callers
+        only need the converted fields.
+        """
+        return self._replace(
+            yseq=self.yseq.tolist(),
+            score=float(self.score),
+            scores={k: float(v) for k, v in self.scores.items()},
+        )._asdict()
+
+
+class BeamSearch(torch.nn.Module):
+ """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos: int,
+        eos: int,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM
+                The scorer will be ignored if it is `None`
+            weights (dict[str, float]): Dict of weights for each scorers
+                The scorer will be ignored if its weight is 0
+            beam_size (int): The number of hypotheses kept during search
+            vocab_size (int): The number of vocabulary
+            sos (int): Start of sequence id
+            eos (int): End of sequence id
+            token_list (list[str]): List of tokens for debug log
+            pre_beam_score_key (str): key of scores to perform pre-beam search
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`
+
+        """
+        super().__init__()
+        # set scorers: partition into full scorers (score the whole vocab)
+        # and partial scorers (score only candidate ids, e.g. CTC prefix).
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            # Skip disabled scorers (zero weight or missing module).
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+            pre_beam_score_key is not None
+            and pre_beam_score_key != "full"
+            and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        # Pre-beam pruning only makes sense with partial scorers and a
+        # pre-beam smaller than the vocabulary.
+        self.do_pre_beam = (
+            self.pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        # A single hypothesis containing only <sos>; speaker weights
+        # (spk_weigths, sic) start empty and grow one entry per step.
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+                spk_weigths=[],
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+        """
+        # Returns a new tensor; the input prefix is not modified in place.
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+ def score_full(
+ self, hyp: Hypothesis, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.full_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.full_scorers`
+ and tensor score values of shape: `(self.n_vocab,)`,
+ and state dict that has string keys
+ and state values of `self.full_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.full_scorers.items():
+ if isinstance(d, AbsDecoder):
+ scores[k], spk_weigths, states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc, spk_enc, profile)
+ else:
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc)
+ return scores, spk_weigths, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            asr_enc (torch.Tensor): ASR encoder output feature
+            spk_enc (torch.Tensor): Speaker encoder output feature
+            profile (torch.Tensor): Speaker profile embeddings
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            # SA-ASR decoders take the extra speaker inputs; other partial
+            # scorers (e.g. CTC prefix) only need the ASR encoder output.
+            if isinstance(d, AbsDecoder):
+                scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc, spk_enc, profile)
+            else:
+                scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc)
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+                Its shape is `(self.n_vocab,)`. NOTE: modified in place when
+                pre-beam pruning is active.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        # local_ids index into `ids`, matching the partial-score tensors.
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        # Full scorers are indexed by vocabulary id; partial scorers by the
+        # position of the token within the pre-beam candidate list.
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        # Full-scorer states are taken as-is; partial-scorer states must be
+        # narrowed to the selected candidate via select_state().
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+        self, running_hyps: List[Hypothesis], asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            asr_enc (torch.Tensor): Encoded speech feature (T, D)
+            spk_enc (torch.Tensor): Speaker encoder output
+            profile (torch.Tensor): Speaker profile embeddings
+
+        Returns:
+            List[Hypotheses]: Best sorted hypotheses
+
+        """
+        # import ipdb;ipdb.set_trace()
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=asr_enc.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=asr_enc.dtype, device=asr_enc.device)
+            scores, spk_weigths, states = self.score_full(hyp, asr_enc, spk_enc, profile)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, asr_enc, spk_enc, profile)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(
+                            hyp.scores, scores, j, part_scores, part_j
+                        ),
+                        states=self.merge_states(states, part_states, part_j),
+                        # Accumulate this step's speaker weights onto the
+                        # hypothesis history (spk_weigths, sic).
+                        spk_weigths=hyp.spk_weigths+[spk_weigths],
+                    )
+                )
+
+        # sort and prune 2 x beam -> beam
+        best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+            : min(len(best_hyps), self.beam_size)
+        ]
+        return best_hyps
+
+    def forward(
+        self, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            asr_enc (torch.Tensor): Encoded speech feature (T, D)
+            spk_enc (torch.Tensor): Speaker encoder output
+            profile (torch.Tensor): Speaker profile embeddings
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses a end-detect function
+                to automatically find maximum hypothesis lengths
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        # import ipdb;ipdb.set_trace()
+        # set length bounds
+        if maxlenratio == 0:
+            maxlen = asr_enc.shape[0]
+        else:
+            maxlen = max(1, int(maxlenratio * asr_enc.size(0)))
+        minlen = int(minlenratio * asr_enc.size(0))
+        logging.info("decoder input length: " + str(asr_enc.shape[0]))
+        logging.info("max output length: " + str(maxlen))
+        logging.info("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        running_hyps = self.init_hyp(asr_enc)
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            best = self.search(running_hyps, asr_enc, spk_enc, profile)
+            #import pdb;pdb.set_trace()
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching to eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there is no N-best results, perform recognition "
+                "again with smaller minlenratio."
+            )
+            # Retry recursively with a relaxed minimum-length constraint.
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(asr_enc, spk_enc, profile, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (int): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+                NOTE: mutated in place (ended hypotheses are appended).
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos))
+                for h in running_hyps
+            ]
+
+        # add ended hypotheses to a final list, and removed them from current hypotheses
+        # (this will be a problem, number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                # NOTE(review): hyp.scores is mutated in place here while
+                # Hypothesis is otherwise treated as immutable — intentional
+                # in upstream ESPnet, but worth being aware of.
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps
+
+
+def beam_search(
+ x: torch.Tensor,
+ sos: int,
+ eos: int,
+ beam_size: int,
+ vocab_size: int,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ token_list: List[str] = None,
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ pre_beam_ratio: float = 1.5,
+ pre_beam_score_key: str = "full",
+) -> list:
+ """Perform beam search with scorers.
+
+ Args:
+ x (torch.Tensor): Encoded speech feature (T, D)
+ sos (int): Start of sequence id
+ eos (int): End of sequence id
+ beam_size (int): The number of hypotheses kept during search
+ vocab_size (int): The number of vocabulary
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
+ e.g., Decoder, CTCPrefixScorer, LM
+ The scorer will be ignored if it is `None`
+ weights (dict[str, float]): Dict of weights for each scorers
+ The scorer will be ignored if its weight is 0
+ token_list (list[str]): List of tokens for debug log
+ maxlenratio (float): Input length ratio to obtain max output length.
+ If maxlenratio=0.0 (default), it uses a end-detect function
+ to automatically find maximum hypothesis lengths
+ minlenratio (float): Input length ratio to obtain min output length.
+ pre_beam_score_key (str): key of scores to perform pre-beam search
+ pre_beam_ratio (float): beam size in the pre-beam search
+ will be `int(pre_beam_ratio * beam_size)`
+
+ Returns:
+ list: N-best decoding results
+
+ """
+ ret = BeamSearch(
+ scorers,
+ weights,
+ beam_size=beam_size,
+ vocab_size=vocab_size,
+ pre_beam_ratio=pre_beam_ratio,
+ pre_beam_score_key=pre_beam_score_key,
+ sos=sos,
+ eos=eos,
+ token_list=token_list,
+ ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
+ return [h.asdict() for h in ret]
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 3d2004c..f8c1009 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -445,6 +445,12 @@
help='Perform on "collect stats" mode',
)
group.add_argument(
+ "--mc",
+ type=bool,
+ default=False,
+ help="MultiChannel input",
+ )
+ group.add_argument(
"--write_collected_feats",
type=str2bool,
default=False,
@@ -635,8 +641,8 @@
group.add_argument(
"--init_param",
type=str,
+ action="append",
default=[],
- nargs="*",
help="Specify the file path used for initialization of parameters. "
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
"where file_path is the model file path, "
@@ -662,7 +668,7 @@
"--freeze_param",
type=str,
default=[],
- nargs="*",
+ action="append",
help="Freeze parameters",
)
@@ -1153,10 +1159,10 @@
elif args.distributed and args.simple_ddp:
distributed_option.init_torch_distributed_pai(args)
args.ngpu = dist.get_world_size()
- if args.dataset_type == "small":
+ if args.dataset_type == "small" and args.ngpu > 0:
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
+ if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
# filter samples if wav.scp and text are mismatch
@@ -1316,6 +1322,7 @@
data_path_and_name_and_type=args.train_data_path_and_name_and_type,
key_file=train_key_file,
batch_size=args.batch_size,
+ mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,
@@ -1327,6 +1334,7 @@
data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
key_file=valid_key_file,
batch_size=args.valid_batch_size,
+ mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
new file mode 100644
index 0000000..738ec52
--- /dev/null
+++ b/funasr/tasks/sa_asr.py
@@ -0,0 +1,623 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+import yaml
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
+from funasr.models.decoder.transformer_decoder import (
+ DynamicConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
+from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolutionTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
+from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34,ResNet34Diar
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+ HuggingFaceTransformersPostEncoder, # noqa: H301
+)
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+# --frontend: feature-extraction frontends selectable from the config/CLI.
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+        multichannelfrontend=MultiChannelFrontend,
+    ),
+    type_check=AbsFrontend,
+    default="default",
+)
+# --specaug: spectral augmentation variants (optional).
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(
+        specaug=SpecAug,
+        specaug_lfr=SpecAugLFR,
+    ),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+# --normalize: feature normalization (optional).
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+# --model: top-level E2E models. Note "asr" maps to the SA-ASR
+# ESPnetASRModel imported from funasr.models.e2e_sa_asr in this task.
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        asr=ESPnetASRModel,
+        uniasr=UniASR,
+        paraformer=Paraformer,
+        paraformer_bert=ParaformerBert,
+        bicif_paraformer=BiCifParaformer,
+        contextual_paraformer=ContextualParaformer,
+        mfcca=MFCCA,
+        timestamp_prediction=TimestampPredictor,
+    ),
+    type_check=AbsESPnetModel,
+    default="asr",
+)
+# --preencoder: optional module applied before the encoder.
+preencoder_choices = ClassChoices(
+    name="preencoder",
+    classes=dict(
+        sinc=LightweightSincConvs,
+        linear=LinearProjection,
+    ),
+    type_check=AbsPreEncoder,
+    default=None,
+    optional=True,
+)
+# --asr_encoder: speech (ASR) encoder architectures.
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+
+# --spk_encoder: speaker-embedding encoder (no type_check on purpose —
+# ResNet34Diar is presumably not an AbsEncoder subclass; verify).
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
+
+# --encoder2: secondary encoder (NOTE(review): not registered in
+# ASRTask.class_choices_list below — confirm whether this is used).
+encoder_choices2 = ClassChoices(
+    "encoder2",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+# --postencoder: optional module applied after the encoder.
+postencoder_choices = ClassChoices(
+    name="postencoder",
+    classes=dict(
+        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+    ),
+    type_check=AbsPostEncoder,
+    default=None,
+    optional=True,
+)
+# --decoder: decoder architectures; default is the SA-ASR decoder.
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        transformer=TransformerDecoder,
+        lightweight_conv=LightweightConvolutionTransformerDecoder,
+        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+        dynamic_conv=DynamicConvolutionTransformerDecoder,
+        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+        rnn=RNNDecoder,
+        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+        paraformer_decoder_sanm=ParaformerSANMDecoder,
+        paraformer_decoder_san=ParaformerDecoderSAN,
+        contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
+    ),
+    type_check=AbsDecoder,
+    default="sa_decoder",
+)
+# --decoder2: secondary decoder (also not in class_choices_list below).
+decoder_choices2 = ClassChoices(
+    "decoder2",
+    classes=dict(
+        transformer=TransformerDecoder,
+        lightweight_conv=LightweightConvolutionTransformerDecoder,
+        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+        dynamic_conv=DynamicConvolutionTransformerDecoder,
+        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+        rnn=RNNDecoder,
+        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+        paraformer_decoder_sanm=ParaformerSANMDecoder,
+    ),
+    type_check=AbsDecoder,
+    default="rnn",
+)
+# --predictor: CIF-style length/alignment predictors (optional).
+predictor_choices = ClassChoices(
+    name="predictor",
+    classes=dict(
+        cif_predictor=CifPredictor,
+        ctc_predictor=None,
+        cif_predictor_v2=CifPredictorV2,
+        cif_predictor_v3=CifPredictorV3,
+    ),
+    type_check=None,
+    default="cif_predictor",
+    optional=True,
+)
+# --predictor2: secondary predictor (optional).
+predictor_choices2 = ClassChoices(
+    name="predictor2",
+    classes=dict(
+        cif_predictor=CifPredictor,
+        ctc_predictor=None,
+        cif_predictor_v2=CifPredictorV2,
+    ),
+    type_check=None,
+    default="cif_predictor",
+    optional=True,
+)
+# --stride_conv: subsampling convolution used between encoders (optional).
+stride_conv_choices = ClassChoices(
+    name="stride_conv",
+    classes=dict(
+        stride_conv1d=Conv1dSubsampling
+    ),
+    type_check=None,
+    default="stride_conv1d",
+    optional=True,
+)
+
+
class ASRTask(AbsTask):
    """Task definition for speaker-attributed ASR (SA-ASR).

    Builds a model that pairs a conventional ASR encoder/decoder with a
    separate speaker encoder so that decoding can attribute tokens to
    speakers.  Every component (frontend, specaug, normalization, encoders,
    decoder, model class) is selected via ``ClassChoices`` command-line
    options registered in :meth:`add_task_arguments`.
    """

    # If you need more than one optimizers, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --asr_encoder and --asr_encoder_conf
        asr_encoder_choices,
        # --spk_encoder and --spk_encoder_conf
        spk_encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register task-specific and preprocessing command-line arguments.

        Args:
            parser: The parser that ``--print_config`` and training scripts share.
        """
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--max_spk_num",
            type=int_or_none,
            default=None,
            # Fixed help text: it was a copy-paste of --token_list's help.
            help="Maximum number of speakers the model can attribute",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )

        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )

        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for CTC class.",
        )
        group.add_argument(
            "--joint_net_conf",
            action=NestedDictAction,
            default=None,
            help="The keyword arguments for joint network class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized " "in the specified level token",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        # NOTE(review): the options below were originally registered through
        # ``parser.add_argument`` even though the "Preprocess related" group
        # had just been created; they are moved onto the group so that
        # --help displays them under the intended section.
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )
        group.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        group.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        group.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="THe probability for applying RIR convolution.",
        )
        group.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            # Fixed help text: it was a copy-paste of --noise_scp's help.
            help="The file path of the cmvn statistics file.",
        )
        group.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        group.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability applying Noise adding.",
        )
        group.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the mini-batch collator for this task."""
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Return the per-utterance preprocessor, or None when disabled."""
        assert check_argument_types()
        if args.use_preprocessor:
            # NOTE(kamo): ``getattr`` with a default keeps backward
            # compatibility with configs saved before these options existed.
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=getattr(args, "split_with_space", False),
                seg_dict_file=getattr(args, "seg_dict_file", None),
                rir_scp=getattr(args, "rir_scp", None),
                rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
                noise_scp=getattr(args, "noise_scp", None),
                noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
                noise_db_range=getattr(args, "noise_db_range", "13_15"),
                # BUG FIX: this was gated on hasattr(args, "rir_scp") —
                # the wrong attribute — which could raise AttributeError
                # for configs that have rir_scp but no
                # speech_volume_normalize.
                speech_volume_normalize=getattr(
                    args, "speech_volume_normalize", None
                ),
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys every example must provide."""
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return optional data keys (none for this task)."""
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        """Instantiate the full SA-ASR model from parsed arguments.

        Raises:
            RuntimeError: if ``token_list`` is neither a path nor a list,
                or if ``--max_spk_num`` was not given.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                # wav_frontend additionally needs the CMVN statistics file.
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoders: one for ASR content, one for speaker characteristics.
        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        asr_encoder_output_size = asr_encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=asr_encoder_output_size, **args.postencoder_conf
            )
            asr_encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=asr_encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=asr_encoder_output_size, **args.ctc_conf
        )

        # The SA-ASR model needs a fixed speaker capacity; fail with a clear
        # message instead of the opaque TypeError int(None) would raise.
        if args.max_spk_num is None:
            raise RuntimeError("--max_spk_num is required to build the SA-ASR model")
        max_spk_num = int(args.max_spk_num)

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            # Older configs may lack the "model" attribute; fall back to "asr".
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            max_spk_num=max_spk_num,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            asr_encoder=asr_encoder,
            spk_encoder=spk_encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index b607e1d..014a79f 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -106,18 +106,17 @@
if num in abbr_begin:
if time_stamp is not None:
begin = time_stamp[ts_nums[num]][0]
- abbr_word = words[num].upper()
+ word_lists.append(words[num].upper())
num += 1
while num < words_size:
if num in abbr_end:
- abbr_word += words[num].upper()
+ word_lists.append(words[num].upper())
last_num = num
break
else:
if words[num].encode('utf-8').isalpha():
- abbr_word += words[num].upper()
+ word_lists.append(words[num].upper())
num += 1
- word_lists.append(abbr_word)
if time_stamp is not None:
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
diff --git a/setup.py b/setup.py
index e837637..ea55606 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
"install": [
"setuptools>=38.5.1",
# "configargparse>=1.2.1",
- "typeguard<=2.13.3",
+ "typeguard==2.13.3",
"humanfriendly",
"scipy>=1.4.1",
# "filelock",
@@ -42,7 +42,10 @@
"oss2",
# "kaldi-native-fbank",
# timestamp
- "edit-distance"
+ "edit-distance",
+ # textgrid
+ "textgrid",
+ "protobuf==3.20.0",
],
# train: The modules invoked when training only.
"train": [
--
Gitblit v1.9.1