游雁
2024-02-19 1448e021accfdb03a381651cb5a8be6d1a6e8adf
aishell example
7个文件已修改
21个文件已添加
1个文件已删除
2542 ■■■■■ 已修改文件
data/list/text.txt 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
data/list/wav.scp 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml 98 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/finetune.sh 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/local/aishell_data_prep.sh 66 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/local/download_and_untar.sh 105 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/run.sh 203 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/extract_embeds.py 47 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/filter_scp.pl 87 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/fix_data.sh 35 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/fix_data_feat.sh 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/parse_options.sh 97 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/shuffle_list.pl 44 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/split_scp.pl 246 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/text2token.py 135 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/text_tokenize.py 106 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/text_tokenize.sh 35 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/utils/textnorm_zh.py 834 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/run.sh 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/finetune.sh 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/compute_audio_cmvn.py 123 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/datasets.py 25 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/index_ds.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/scp2jsonl.py 94 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/download/download_dataset_from_hub.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/model_hf/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/register.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 41 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
data/list/text.txt
New file
@@ -0,0 +1,2 @@
ID0012W0013 当客户风险承受能力评估依据发生变化时
ID0012W0014 杨涛不得不将工厂关掉
data/list/wav.scp
New file
@@ -0,0 +1,2 @@
ID0012W0013 /Users/zhifu/funasr_github/test_local/aishell2_dev_ios/wav/D0012/ID0012W0013.wav
ID0012W0014 /Users/zhifu/funasr_github/test_local/aishell2_dev_ios/wav/D0012/ID0012W0014.wav
examples/aishell/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
@@ -1,6 +1,6 @@
# network architecture
model: funasr.cli.models.paraformer:Paraformer
model: Paraformer
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1
@@ -9,9 +9,8 @@
    sampling_ratio: 0.4
    use_1st_decoder_loss: true
# encoder related
encoder: conformer
# encoder
encoder: ConformerEncoder
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
@@ -29,8 +28,8 @@
    use_cnn_module: true
    cnn_module_kernel: 15
# decoder related
decoder: paraformer_decoder_san
# decoder
decoder: ParaformerSANDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
@@ -40,8 +39,17 @@
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# predictor
predictor: CifPredictor
predictor_conf:
    idim: 256
    threshold: 1.0
    l_order: 1
    r_order: 1
    tail_threshold: 0.45
# frontend related
frontend: wav_frontend
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
@@ -51,29 +59,7 @@
    lfr_m: 1
    lfr_n: 1
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  val_scheduler_criterion:
      - valid
      - acc
  best_model_criterion:
  -   - valid
      - acc
      - max
  keep_nbest_models: 10
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
specaug: specaug
specaug: SpecAug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
@@ -89,25 +75,43 @@
    - 40
    num_time_mask: 2
predictor: cif_predictor
predictor_conf:
    idim: 256
    threshold: 1.0
    l_order: 1
    r_order: 1
    tail_threshold: 0.45
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  keep_nbest_models: 10
  avg_nbest_model: 5
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
    data_names: speech,text
    data_types: sound,text
    index_ds: IndexDSJsonl
    batch_sampler: RankFullLocalShuffleBatchSampler
    batch_type: example # example or length
    batch_size: 32 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
    buffer_size: 1024
    shuffle: True
    shuffle_conf:
        shuffle_size: 2048
        sort_size: 500
    batch_conf:
        batch_type: example
        batch_size: 2
    num_workers: 8
    num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  split_with_space: true
normalize: null
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
examples/aishell/finetune.sh
New file
@@ -0,0 +1,9 @@
cmd="funasr/bin/train.py"
python $cmd \
--config-path "/Users/zhifu/funasr_github/test_local/funasr_cli_egs" \
--config-name "config.yaml" \
++token_list="/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
++output_dir="/nfs/zhifu.gzf/ckpt/funasr2/exp1"
examples/aishell/paraformer/local/aishell_data_prep.sh
New file
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2017 Xingyu Na
# Apache 2.0
#. ./path.sh || exit 1;
if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi
aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3
train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp
mkdir -p $train_dir
mkdir -p $dev_dir
mkdir -p $test_dir
mkdir -p $tmp_dir
# data directory check
if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
  echo "Error: $0 requires two directory arguments"
  exit 1;
fi
# find wav audio file for train, dev and test resp.
find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
n=`cat $tmp_dir/wav.flist | wc -l`
[ $n -ne 141925 ] && \
  echo Warning: expected 141925 data data files, found $n
grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
rm -r $tmp_dir
# Transcriptions preparation
for dir in $train_dir $dev_dir $test_dir; do
  echo Preparing $dir transcriptions
  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
  sort -u $dir/transcripts.txt > $dir/text
done
mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
for f in wav.scp text; do
  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done
echo "$0: AISHELL data preparation succeeded"
exit 0;
examples/aishell/paraformer/local/download_and_untar.sh
New file
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
#             2017  Xingyu Na
# Apache 2.0
remove_archive=false
if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi
if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
fi
data=$1
url=$2
part=$3
if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi
if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi
if [ -f $data/$part/.complete ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi
# sizes of the archive files in bytes.
sizes="15582913665 1246920"
if [ -f $data/$part.tgz ]; then
  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm $data/$part.tgz
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi
if [ ! -f $data/$part.tgz ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
  cd $data || exit 1
  if ! wget --no-check-certificate $full_url; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi
cd $data || exit 1
if ! tar -xvzf $part.tgz; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi
touch $data/$part/.complete
if [ $part == "data_aishell" ]; then
  cd $data/$part/wav || exit 1
  for wav in ./*.tar.gz; do
    echo "Extracting wav from $wav"
    tar -zxf $wav && rm $wav
  done
fi
echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm $data/$part.tgz
fi
exit 0;
examples/aishell/paraformer/run.sh
New file
@@ -0,0 +1,203 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=1
# general configuration
feats_dir="../DATA" #feature output dictionary
exp_dir="."
lang=zh
token_type=char
stage=0
stop_stage=5
# feature configuration
nj=64
# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33
# exp tag
tag="exp1"
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=train
valid_set=dev
test_sets="dev test"
asr_config=conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
#inference_config=conf/decode_asr_transformer_noctc_1best.yaml
#inference_asr_model=valid.acc.ave_10best.pb
## you can set gpu num for decoding here
#gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
#ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
#
#if ${gpu_inference}; then
#    inference_nj=$[${ngpu}*${njob}]
#    _ngpu=1
#else
#    inference_nj=$njob
#    _ngpu=0
#fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
    for x in train dev test; do
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
        python funasr/datasets/audio_datasets/scp2jsonl.py \
        ++scp_file_list='["${feats_dir}/data/${x}/wav.scp", "${feats_dir}/data/${x}/text"]' \
        ++data_type_list='["source", "target"]' \
        ++jsonl_file_out=${feats_dir}/data/${x}/audio_datasets.jsonl
    done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
#    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
    python funasr/bin/compute_audio_cmvn.py \
    --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
    --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
    ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \
    ++dataset_conf.num_workers=$nj
fi
token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/
    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    echo "<unk>" >> ${token_list}
fi
# LM Training Stage
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: LM Training"
fi
# ASR Training Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "stage 4: ASR Training"
torchrun \
--nnodes 1 \
--nproc_per_node ${gpu_num} \
funasr/bin/train.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
++cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
++token_list="${token_list}" \
++output_dir="${exp_dir}/exp/${model_dir}"
fi
#
## Testing Stage
#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
#    echo "stage 5: Inference"
#    for dset in ${test_sets}; do
#        asr_exp=${exp_dir}/exp/${model_dir}
#        inference_tag="$(basename "${inference_config}" .yaml)"
#        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
#        _logdir="${_dir}/logdir"
#        if [ -d ${_dir} ]; then
#            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
#            exit 0
#        fi
#        mkdir -p "${_logdir}"
#        _data="${feats_dir}/data/${dset}"
#        key_file=${_data}/${scp}
#        num_scp_file="$(<${key_file} wc -l)"
#        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
#        split_scps=
#        for n in $(seq "${_nj}"); do
#            split_scps+=" ${_logdir}/keys.${n}.scp"
#        done
#        # shellcheck disable=SC2086
#        utils/split_scp.pl "${key_file}" ${split_scps}
#        _opts=
#        if [ -n "${inference_config}" ]; then
#            _opts+="--config ${inference_config} "
#        fi
#        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
#            python -m funasr.bin.asr_inference_launch \
#                --batch_size 1 \
#                --ngpu "${_ngpu}" \
#                --njob ${njob} \
#                --gpuid_list ${gpuid_list} \
#                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
#                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
#                --key_file "${_logdir}"/keys.JOB.scp \
#                --asr_train_config "${asr_exp}"/config.yaml \
#                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
#                --output_dir "${_logdir}"/output.JOB \
#                --mode paraformer \
#                ${_opts}
#
#        for f in token token_int score text; do
#            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
#                for i in $(seq "${_nj}"); do
#                    cat "${_logdir}/output.${i}/1best_recog/${f}"
#                done | sort -k1 >"${_dir}/${f}"
#            fi
#        done
#        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
#        python utils/proce_text.py ${_data}/text ${_data}/text.proc
#        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
#        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
#        cat ${_dir}/text.cer.txt
#    done
#fi
#
## Prepare files for ModelScope fine-tuning and inference
#if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
#    echo "stage 6: ModelScope Preparation"
#    cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
#    vocab_size=$(cat ${token_list} | wc -l)
#    python utils/gen_modelscope_configuration.py \
#        --am_model_name $inference_asr_model \
#        --mode paraformer \
#        --model_name paraformer \
#        --dataset aishell \
#        --output_dir $exp_dir/exp/$model_dir \
#        --vocab_size $vocab_size \
#        --nat _nat \
#        --tag $tag
#fi
examples/aishell/paraformer/utils/extract_embeds.py
New file
@@ -0,0 +1,47 @@
from transformers import AutoTokenizer, AutoModel, pipeline
import numpy as np
import sys
import os
import torch
from kaldiio import WriteHelper
import re
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
out_shape = sys.argv[4]
device = int(sys.argv[5])
model_path = sys.argv[6]
model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
with open(text_file_json, 'r') as f:
    js = f.readlines()
f_shape = open(out_shape, "w")
with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
    with torch.no_grad():
        for idx, line in enumerate(js):
            id, tokens = line.strip().split(" ", 1)
            tokens = re.sub(" ", "", tokens.strip())
            tokens = ' '.join([j for j in tokens])
            token_num = len(tokens.split(" "))
            outputs = extractor(tokens)
            outputs = np.array(outputs)
            embeds = outputs[0, 1:-1, :]
            token_num_embeds, dim = embeds.shape
            if token_num == token_num_embeds:
                writer(id, embeds)
                shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
                f_shape.write(shape_line)
            else:
                print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
f_shape.close()
examples/aishell/paraformer/utils/filter_scp.pl
New file
@@ -0,0 +1,87 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
#                     Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
do {
  $shifted=0;
  if ($ARGV[0] eq "--exclude") {
    $exclude = 1;
    shift @ARGV;
    $shifted=1;
  }
  if ($ARGV[0] eq "-f") {
    $field = $ARGV[1];
    shift @ARGV; shift @ARGV;
    $shifted=1
  }
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
      "Note: only the first field of each line in id_list matters.  With --exclude, prints\n" .
      "only the lines that were *not* in id_list.\n" .
      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
      "-f option, add 1 to the argument.\n" .
      "See also: scripts/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
  @A = split;
  @A>=1 || die "Invalid id-list file line $_";
  $seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
  while(<>) {
    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
    # $1 is what we filter on.
    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
      print $_;
    }
  }
} else {
  while(<>) {
    @A = split;
    @A > 0 || die "Invalid scp file line $_";
    @A >= $field || die "Invalid scp file line $_";
    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
      print $_;
    }
  }
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)
examples/aishell/paraformer/utils/fix_data.sh
New file
@@ -0,0 +1,35 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/wav.scp ]; then
  echo "$0: wav.scp is not found"
  exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
  echo "$0: text is not found"
  exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp
cp ${data_dir}/text ${data_dir}/.backup/text
mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
rm ${data_dir}/wav.scp.bak
rm ${data_dir}/text.bak
examples/aishell/paraformer/utils/fix_data_feat.sh
New file
@@ -0,0 +1,52 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/feats.scp ]; then
  echo "$0: feats.scp is not found"
  exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
  echo "$0: text is not found"
  exit 1;
fi
if [ ! -f ${data_dir}/speech_shape ]; then
  echo "$0: feature lengths is not found"
  exit 1;
fi
if [ ! -f ${data_dir}/text_shape ]; then
  echo "$0: text lengths is not found"
  exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp
cp ${data_dir}/text ${data_dir}/.backup/text
cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape
cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape
mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape
rm ${data_dir}/feats.scp.bak
rm ${data_dir}/text.bak
rm ${data_dir}/speech_shape.bak
rm ${data_dir}/text_shape.bak
examples/aishell/paraformer/utils/parse_options.sh
New file
@@ -0,0 +1,97 @@
#!/usr/bin/env bash
# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
#                 Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config  # source the config file.
  fi
done
###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
      else printf "$help_message\n" 1>&2 ; fi;
      exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
      exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefned-- if so it's
      # an invalid option and we die.  Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined.  We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi
      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";
      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
  *) break;
  esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
examples/aishell/paraformer/utils/shuffle_list.pl
New file
@@ -0,0 +1,44 @@
#!/usr/bin/env perl
# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
if ($ARGV[0] eq "--srand") {
  $n = $ARGV[1];
  $n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\"";
  srand($ARGV[1]);
  shift;
  shift;
} else {
  srand(0); # Gives inconsistent behavior if we don't seed.
}
if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we
  # don't understand.
  print "Usage: shuffle_list.pl [--srand N] [input file]  > output\n";
  print "randomizes the order of lines of input.\n";
  exit(1);
}
@lines;
while (<>) {
  push @lines, [ (rand(), $_)] ;
}
@lines = sort { $a->[0] cmp $b->[0] } @lines;
foreach $l (@lines) {
    print $l->[1];
}
examples/aishell/paraformer/utils/split_scp.pl
New file
@@ -0,0 +1,246 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text  file and
# will split it up with an approximately equal number of lines in
# each but.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can.  If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries.  In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
    if ($ARGV[0] eq "-j") {
        shift @ARGV;
        $num_jobs = shift @ARGV;
        $job_id = shift @ARGV;
    }
    if ($ARGV[0] =~ /--utt2spk=(.+)/) {
        $utt2spk_file=$1;
        shift;
    }
    if ($ARGV[0] eq '--one-based') {
        $one_based = 1;
        shift @ARGV;
    }
}
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
                       $job_id - $one_based >= $num_jobs)) {
  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
      ($one_based ? " --one-based" : "") . "'\n"
}
$one_based
    and $job_id--;
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
    die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
   or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
 ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
}
$error = 0;
$inscp = shift @ARGV;
if ($num_jobs == 0) { # without -j option
    @OUTPUTS = @ARGV;
} else {
    for ($j = 0; $j < $num_jobs; $j++) {
        if ($j == $job_id) {
            if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
            else { push @OUTPUTS, "-"; }
        } else {
            push @OUTPUTS, "/dev/null";
        }
    }
}
if ($utt2spk_file ne "") {  # We have the --utt2spk option...
    open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
    while(<$u_fh>) {
        @A = split;
        @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
        ($u,$s) = @A;
        $utt2spk{$u} = $s;
    }
    close $u_fh;
    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
    @spkrs = ();
    while(<$i_fh>) {
        @A = split;
        if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
        $u = $A[0];
        $s = $utt2spk{$u};
        defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
        if(!defined $spk_count{$s}) {
            push @spkrs, $s;
            $spk_count{$s} = 0;
            $spk_data{$s} = [];  # ref to new empty array.
        }
        $spk_count{$s}++;
        push @{$spk_data{$s}}, $_;
    }
    # Now split as equally as possible ..
    # First allocate spks to files by allocating an approximately
    # equal number of speakers.
    $numspks = @spkrs;  # number of speakers.
    $numscps = @OUTPUTS; # number of output files.
    if ($numspks < $numscps) {
      die "$0: Refusing to split data because number of speakers $numspks " .
          "is less than the number of output .scp files $numscps\n";
    }
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
        $scparray[$scpidx] = []; # [] is array reference.
    }
    for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
        $scpidx = int(($spkidx*$numscps) / $numspks);
        $spk = $spkrs[$spkidx];
        push @{$scparray[$scpidx]}, $spk;
        $scpcount[$scpidx] += $spk_count{$spk};
    }
    # Now will try to reassign beginning + ending speakers
    # to different scp's and see if it gets more balanced.
    # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
    # We can show that if considering changing just 2 scp's, we minimize
    # this by minimizing the squared difference in sizes.  This is
    # equivalent to minimizing the absolute difference in sizes.  This
    # shows this method is bound to converge.
    $changed = 1;
    while($changed) {
        $changed = 0;
        for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
            # First try to reassign ending spk of this scp.
            if($scpidx < $numscps-1) {
                $sz = @{$scparray[$scpidx]};
                if($sz > 0) {
                    $spk = $scparray[$scpidx]->[$sz-1];
                    $count = $spk_count{$spk};
                    $nutt1 = $scpcount[$scpidx];
                    $nutt2 = $scpcount[$scpidx+1];
                    if( abs( ($nutt2+$count) - ($nutt1-$count))
                        < abs($nutt2 - $nutt1))  { # Would decrease
                        # size-diff by reassigning spk...
                        $scpcount[$scpidx+1] += $count;
                        $scpcount[$scpidx] -= $count;
                        pop @{$scparray[$scpidx]};
                        unshift @{$scparray[$scpidx+1]}, $spk;
                        $changed = 1;
                    }
                }
            }
            if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
                $spk = $scparray[$scpidx]->[0];
                $count = $spk_count{$spk};
                $nutt1 = $scpcount[$scpidx-1];
                $nutt2 = $scpcount[$scpidx];
                if( abs( ($nutt2-$count) - ($nutt1+$count))
                    < abs($nutt2 - $nutt1))  { # Would decrease
                    # size-diff by reassigning spk...
                    $scpcount[$scpidx-1] += $count;
                    $scpcount[$scpidx] -= $count;
                    shift @{$scparray[$scpidx]};
                    push @{$scparray[$scpidx-1]}, $spk;
                    $changed = 1;
                }
            }
        }
    }
    # Now print out the files...
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
        $scpfile = $OUTPUTS[$scpidx];
        ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
                         : open($f_fh, '>&', \*STDOUT)) ||
            die "$0: Could not open scp file $scpfile for writing: $!\n";
        $count = 0;
        if(@{$scparray[$scpidx]} == 0) {
            print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
                         "$scpfile (too many splits and too few speakers?)\n";
            $error = 1;
        } else {
            foreach $spk ( @{$scparray[$scpidx]} ) {
                print $f_fh @{$spk_data{$spk}};
                $count += $spk_count{$spk};
            }
            $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
        }
        close($f_fh);
    }
} else {
   # This block is the "normal" case where there is no --utt2spk
   # option and we just break into equal size chunks.
    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
    $numscps = @OUTPUTS;  # size of array.
    @F = ();
    while(<$i_fh>) {
        push @F, $_;
    }
    $numlines = @F;
    if($numlines == 0) {
        print STDERR "$0: error: empty input scp file $inscp\n";
        $error = 1;
    }
    $linesperscp = int( $numlines / $numscps); # the "whole part"..
    $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
    $remainder = $numlines - ($linesperscp * $numscps);
    ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
    # [just doing int() rounds down].
    $n = 0;
    for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
        $scpfile = $OUTPUTS[$scpidx];
        ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
                         : open($o_fh, '>&', \*STDOUT)) ||
            die "$0: Could not open scp file $scpfile for writing: $!\n";
        for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
            print $o_fh $F[$n++];
        }
        close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
    }
    $n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);
examples/aishell/paraformer/utils/text2token.py
New file
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
    start_pos = None
    end_pos = None
    for pos in match_pos:
        if pos[0] <= i < pos[1]:
            start_pos = pos[0]
            end_pos = pos[1]
            break
    return start_pos, end_pos
def get_parser():
    parser = argparse.ArgumentParser(
        description="convert raw text to tokenized text",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--nchar",
        "-n",
        default=1,
        type=int,
        help="number of characters to split, i.e., \
                        aabb -> a a b b with -n 1 and aa bb with -n 2",
    )
    parser.add_argument(
        "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
    )
    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
    parser.add_argument(
        "--non-lang-syms",
        "-l",
        default=None,
        type=str,
        help="list of non-linguistic symobles, e.g., <NOISE> etc.",
    )
    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
    parser.add_argument(
        "--trans_type",
        "-t",
        type=str,
        default="char",
        choices=["char", "phn"],
        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
                        If trans_type is char,
                        read from SI1279.WRD file -> "bricks are an alternative"
                        Else if trans_type is phn,
                        read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
                        sil t er n ih sil t ih v sil" """,
    )
    return parser
def main():
    parser = get_parser()
    args = parser.parse_args()
    rs = []
    if args.non_lang_syms is not None:
        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
            nls = [x.rstrip() for x in f.readlines()]
            rs = [re.compile(re.escape(x)) for x in nls]
    if args.text:
        f = codecs.open(args.text, encoding="utf-8")
    else:
        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
    sys.stdout = codecs.getwriter("utf-8")(
        sys.stdout if is_python2 else sys.stdout.buffer
    )
    line = f.readline()
    n = args.nchar
    while line:
        x = line.split()
        print(" ".join(x[: args.skip_ncols]), end=" ")
        a = " ".join(x[args.skip_ncols :])
        # get all matched positions
        match_pos = []
        for r in rs:
            i = 0
            while i >= 0:
                m = r.search(a, i)
                if m:
                    match_pos.append([m.start(), m.end()])
                    i = m.end()
                else:
                    break
        if args.trans_type == "phn":
            a = a.split(" ")
        else:
            if len(match_pos) > 0:
                chars = []
                i = 0
                while i < len(a):
                    start_pos, end_pos = exist_or_not(i, match_pos)
                    if start_pos is not None:
                        chars.append(a[start_pos:end_pos])
                        i = end_pos
                    else:
                        chars.append(a[i])
                        i += 1
                a = chars
            a = [a[j : j + n] for j in range(0, len(a), n)]
        a_flat = []
        for z in a:
            a_flat.append("".join(z))
        a_chars = [z.replace(" ", args.space) for z in a_flat]
        if args.trans_type == "phn":
            a_chars = [z.replace("sil", args.space) for z in a_chars]
        print(" ".join(a_chars))
        line = f.readline()
if __name__ == "__main__":
    main()
examples/aishell/paraformer/utils/text_tokenize.py
New file
@@ -0,0 +1,106 @@
import re
import argparse
def load_dict(seg_file):
    seg_dict = {}
    with open(seg_file, 'r') as infile:
        for line in infile:
            s = line.strip().split()
            key = s[0]
            value = s[1:]
            seg_dict[key] = " ".join(value)
    return seg_dict
def forward_segment(text, dic):
    word_list = []
    i = 0
    while i < len(text):
        longest_word = text[i]
        for j in range(i + 1, len(text) + 1):
            word = text[i:j]
            if word in dic:
                if len(word) > len(longest_word):
                    longest_word = word
        word_list.append(longest_word)
        i += len(longest_word)
    return word_list
def tokenize(txt,
             seg_dict):
    out_txt = ""
    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    for word in txt:
        if pattern.match(word):
            if word in seg_dict:
                out_txt += seg_dict[word] + " "
            else:
                out_txt += "<unk>" + " "
        else:
            continue
    return out_txt.strip()
def get_parser():
    parser = argparse.ArgumentParser(
        description="text tokenize",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--text-file",
        "-t",
        default=False,
        required=True,
        type=str,
        help="input text",
    )
    parser.add_argument(
        "--seg-file",
        "-s",
        default=False,
        required=True,
        type=str,
        help="seg file",
    )
    parser.add_argument(
        "--txt-index",
        "-i",
        default=1,
        required=True,
        type=int,
        help="txt index",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        default=False,
        required=True,
        type=str,
        help="output dir",
    )
    return parser
def main():
    parser = get_parser()
    args = parser.parse_args()
    txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
    shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
    seg_dict = load_dict(args.seg_file)
    with open(args.text_file, 'r') as infile:
        for line in infile:
            s = line.strip().split()
            text_id = s[0]
            text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
            text = tokenize(text_list, seg_dict)
            lens = len(text.strip().split())
            txt_writer.write(text_id + " " + text + '\n')
            shape_writer.write(text_id + " " + str(lens) + '\n')
if __name__ == '__main__':
    main()
examples/aishell/paraformer/utils/text_tokenize.sh
New file
@@ -0,0 +1,35 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
# tokenize configuration
text_dir=$1
seg_file=$2
logdir=$3
output_dir=$4
txt_dir=${output_dir}/txt; mkdir -p ${output_dir}/txt
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/text_tokenize.JOB.log \
  python utils/text_tokenize.py -t ${text_dir}/txt/text.JOB.txt \
      -s ${seg_file} -i JOB -o ${txt_dir} \
      || exit 1;
# concatenate the text files together.
for n in $(seq $nj); do
  cat ${txt_dir}/text.$n.txt || exit 1
done > ${output_dir}/text || exit 1
for n in $(seq $nj); do
  cat ${txt_dir}/len.$n || exit 1
done > ${output_dir}/text_shape || exit 1
echo "$0: Succeeded text tokenize"
examples/aishell/paraformer/utils/textnorm_zh.py
New file
@@ -0,0 +1,834 @@
#!/usr/bin/env python3
# coding=utf-8
# Authors:
#   2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
#   2019.9 Jiayu DU
#
# requirements:
#   - python 3.X
# notes: python 2.X WILL fail or produce misleading results
import sys, os, argparse, codecs, string, re
# ================================================================================ #
#                                    basic constant
# ================================================================================ #
CHINESE_DIGIS = u'零一二三四五六七八九'
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
ZERO_ALT = u'〇'
ONE_ALT = u'幺'
TWO_ALTS = [u'两', u'兩']
POSITIVE = [u'正', u'正']
NEGATIVE = [u'负', u'負']
POINT = [u'点', u'點']
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
FILLER_CHARS = ['呃', '啊']
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
             '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
             '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
             '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
# 中文数字系统类型
NUMBERING_TYPES = ['low', 'mid', 'high']
CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
                 '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
                  '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
                  '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
                  '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
                  '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
                  '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
CHINESE_PUNC_STOP = '!?。。'
CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
# ================================================================================ #
#                                    basic class
# ================================================================================ #
class ChineseChar(object):
    """
    中文字符
    每个字符对应简体和繁体,
    e.g. 简体 = '负', 繁体 = '負'
    转换时可转换为简体或繁体
    """
    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional
        #self.__repr__ = self.__str__
    def __str__(self):
        return self.simplified or self.traditional or None
    def __repr__(self):
        return self.__str__()
class ChineseNumberUnit(ChineseChar):
    """
    中文数字/数位字符
    每个字符除繁简体外还有一个额外的大写字符
    e.g. '陆' 和 '陸'
    """
    def __init__(self, power, simplified, traditional, big_s, big_t):
        super(ChineseNumberUnit, self).__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t
    def __str__(self):
        return '10^{}'.format(self.power)
    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        if small_unit:
            return ChineseNumberUnit(power=index + 1,
                                     simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
        elif numbering_type == NUMBERING_TYPES[0]:
            return ChineseNumberUnit(power=index + 8,
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
        elif numbering_type == NUMBERING_TYPES[1]:
            return ChineseNumberUnit(power=(index + 2) * 4,
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
        elif numbering_type == NUMBERING_TYPES[2]:
            return ChineseNumberUnit(power=pow(2, index + 3),
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
        else:
            raise ValueError(
                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
class ChineseNumberDigit(ChineseChar):
    """
    中文数字字符
    """
    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
        super(ChineseNumberDigit, self).__init__(simplified, traditional)
        self.value = value
        self.big_s = big_s
        self.big_t = big_t
        self.alt_s = alt_s
        self.alt_t = alt_t
    def __str__(self):
        return str(self.value)
    @classmethod
    def create(cls, i, v):
        return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
class ChineseMath(ChineseChar):
    """
    中文数位字符
    """
    def __init__(self, simplified, traditional, symbol, expression=None):
        super(ChineseMath, self).__init__(simplified, traditional)
        self.symbol = symbol
        self.expression = expression
        self.big_s = simplified
        self.big_t = traditional
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
class NumberSystem(object):
    """
    中文数字系统
    """
    pass
class MathSymbol(object):
    """
    用于中文数字系统的数学符号 (繁/简体), e.g.
    positive = ['正', '正']
    negative = ['负', '負']
    point = ['点', '點']
    """
    def __init__(self, positive, negative, point):
        self.positive = positive
        self.negative = negative
        self.point = point
    def __iter__(self):
        for v in self.__dict__.values():
            yield v
# class OtherSymbol(object):
#     """
#     其他符号
#     """
#
#     def __init__(self, sil):
#         self.sil = sil
#
#     def __iter__(self):
#         for v in self.__dict__.values():
#             yield v
# ================================================================================ #
#                                    basic utils
# ================================================================================ #
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """
    根据数字系统类型返回创建相应的数字系统,默认为 mid
    NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
    返回对应的数字系统
    """
    # chinese number units of '亿' and larger
    all_larger_units = zip(
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
    larger_units = [CNU.create(i, v, numbering_type, False)
                    for i, v in enumerate(all_larger_units)]
    # chinese number units of '十, 百, 千, 万'
    all_smaller_units = zip(
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
    smaller_units = [CNU.create(i, v, small_unit=True)
                     for i, v in enumerate(all_smaller_units)]
    # digis
    chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
                        BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
    digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
    # symbols
    positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], '.', lambda x,
                  y: float(str(x) + '.' + str(y)))
    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
    # system.symbols = OtherSymbol(sil_cn)
    return system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    def get_symbol(char, system):
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m
    def string2symbols(chinese_string, system):
        int_string, dec_string = chinese_string, ''
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], \
               [get_symbol(c, system) for c in dec_string]
    def correct_symbols(integer_symbols, system):
        """
        一百八 to 一百八十
        一亿一千三百万 to 一亿 一千万 三百万
        """
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
                integer_symbols.append(
                    CNU(integer_symbols[-2].power - 1, None, None, None, None))
        result = []
        unit_count = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1
            if unit_count == 1:
                result.append(current_unit)
            elif unit_count > 1:
                for i in range(len(result)):
                    if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
                        result[-i - 1] = CNU(result[-i - 1].power +
                                             current_unit.power, None, None, None, None)
        return result
    def compute_value(integer_symbols):
        """
        Compute the value.
        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
        e.g. '两千万' = 2000 * 10000 not 2000 + 10000
        """
        value = [0]
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    value[:-1] = list(map(lambda v: v *
                                                    pow(10, s.power), value[:-1]))
                    last_power = s.power
                value.append(0)
        return sum(value)
    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    dec_str = ''.join([str(d.value) for d in dec_part])
    if dec_part:
        return '{0}.{1}'.format(int_str, dec_str)
    else:
        return int_str
def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
            traditional=False, alt_zero=False, alt_one=False, alt_two=True,
            use_zeros=True, use_units=True):
    def get_value(value_string, use_zeros=True):
        striped_string = value_string.lstrip('0')
        # record nothing if all zeros
        if not striped_string:
            return []
        # record one digits
        elif len(striped_string) == 1:
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            else:
                return [system.digits[int(striped_string)]]
        # recursively record multiple digits
        else:
            result_unit = next(u for u in reversed(
                system.units) if u.power < len(striped_string))
            result_string = value_string[:-result_unit.power]
            return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
    system = create_system(numbering_type)
    int_dec = number_string.split('.')
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string))
    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
    else:
        result_symbols = [system.digits[int(c)] for c in int_string]
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols
    if alt_two:
        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
                    system.digits[2].big_s, system.digits[2].big_t)
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i +
                                             1] if i < len(result_symbols) - 1 else None
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
                        result_symbols[i] = liang
    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = 'big_'
        if traditional:
            attr_name += 't'
        else:
            attr_name += 's'
    else:
        if traditional:
            attr_name = 'traditional'
        else:
            attr_name = 'simplified'
    result = ''.join([getattr(s, attr_name) for s in result_symbols])
    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))
    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s)
    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s)
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result
    # ^10, 11, .., 19
    if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
        result = result[1:]
    return result
# ================================================================================ #
#                          different types of rewriters
# ================================================================================ #
class Cardinal:
    """
    CARDINAL类
    """
    def __init__(self, cardinal=None, chntext=None):
        self.cardinal = cardinal
        self.chntext = chntext
    def chntext2cardinal(self):
        return chn2num(self.chntext)
    def cardinal2chntext(self):
        return num2chn(self.cardinal)
class Digit:
    """
    DIGIT类
    """
    def __init__(self, digit=None, chntext=None):
        self.digit = digit
        self.chntext = chntext
    # def chntext2digit(self):
    #     return chn2num(self.chntext)
    def digit2chntext(self):
        return num2chn(self.digit, alt_two=False, use_units=False)
class TelePhone:
    """
    TELEPHONE类
    """
    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
        self.telephone = telephone
        self.raw_chntext = raw_chntext
        self.chntext = chntext
    # def chntext2telephone(self):
    #     sil_parts = self.raw_chntext.split('<SIL>')
    #     self.telephone = '-'.join([
    #         str(chn2num(p)) for p in sil_parts
    #     ])
    #     return self.telephone
    def telephone2chntext(self, fixed=False):
        if fixed:
            sil_parts = self.telephone.split('-')
            self.raw_chntext = '<SIL>'.join([
                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
            ])
            self.chntext = self.raw_chntext.replace('<SIL>', '')
        else:
            sp_parts = self.telephone.strip('+').split()
            self.raw_chntext = '<SP>'.join([
                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
            ])
            self.chntext = self.raw_chntext.replace('<SP>', '')
        return self.chntext
class Fraction:
    """
    FRACTION类
    """
    def __init__(self, fraction=None, chntext=None):
        self.fraction = fraction
        self.chntext = chntext
    def chntext2fraction(self):
        denominator, numerator = self.chntext.split('分之')
        return chn2num(numerator) + '/' + chn2num(denominator)
    def fraction2chntext(self):
        numerator, denominator = self.fraction.split('/')
        return num2chn(denominator) + '分之' + num2chn(numerator)
class Date:
    """
    DATE类
    """
    def __init__(self, date=None, chntext=None):
        self.date = date
        self.chntext = chntext
    # def chntext2date(self):
    #     chntext = self.chntext
    #     try:
    #         year, other = chntext.strip().split('年', maxsplit=1)
    #         year = Digit(chntext=year).digit2chntext() + '年'
    #     except ValueError:
    #         other = chntext
    #         year = ''
    #     if other:
    #         try:
    #             month, day = other.strip().split('月', maxsplit=1)
    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
    #         except ValueError:
    #             day = chntext
    #             month = ''
    #         if day:
    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
    #     else:
    #         month = ''
    #         day = ''
    #     date = year + month + day
    #     self.date = date
    #     return self.date
    def date2chntext(self):
        date = self.date
        try:
            year, other = date.strip().split('年', 1)
            year = Digit(digit=year).digit2chntext() + '年'
        except ValueError:
            other = date
            year = ''
        if other:
            try:
                month, day = other.strip().split('月', 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
            except ValueError:
                day = date
                month = ''
            if day:
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ''
            day = ''
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
class Money:
    """
    MONEY类
    """
    def __init__(self, money=None, chntext=None):
        self.money = money
        self.chntext = chntext
    # def chntext2money(self):
    #     return self.money
    def money2chntext(self):
        money = self.money
        pattern = re.compile(r'(\d+(\.\d+)?)')
        matchers = pattern.findall(money)
        if matchers:
            for matcher in matchers:
                money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
        self.chntext = money
        return self.chntext
class Percentage:
    """
    PERCENTAGE类
    """
    def __init__(self, percentage=None, chntext=None):
        self.percentage = percentage
        self.chntext = chntext
    def chntext2percentage(self):
        return chn2num(self.chntext.strip().strip('百分之')) + '%'
    def percentage2chntext(self):
        return '百分之' + num2chn(self.percentage.strip().strip('%'))
def remove_erhua(text, er_whitelist):
    """
    去除儿化音词中的儿:
    他女儿在那边儿 -> 他女儿在那边
    """
    er_pattern = re.compile(er_whitelist)
    new_str=''
    while re.search('儿',text):
        a = re.search('儿',text).span()
        remove_er_flag = 0
        if er_pattern.search(text):
            b = er_pattern.search(text).span()
            if b[0] <= a[0]:
                remove_er_flag = 1
        if remove_er_flag == 0 :
            new_str = new_str + text[0:a[0]]
            text = text[a[1]:]
        else:
            new_str = new_str + text[0:b[1]]
            text = text[b[1]:]
    text = new_str + text
    return text
# ================================================================================ #
#                            NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
    def __init__(self, raw_text):
        self.raw_text = '^' + raw_text + '$'
        self.norm_text = ''
    def _particular(self):
        text = self.norm_text
        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
        matchers = pattern.findall(text)
        if matchers:
            # print('particular')
            for matcher in matchers:
                text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
        self.norm_text = text
        return self.norm_text
    def normalize(self):
        text = self.raw_text
        # 规范化日期
        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
        matchers = pattern.findall(text)
        if matchers:
            #print('date')
            for matcher in matchers:
                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
        # 规范化金钱
        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
        matchers = pattern.findall(text)
        if matchers:
            #print('money')
            for matcher in matchers:
                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
        # 规范化固话/手机号码
        # 手机
        # http://www.jihaoba.com/news/show/13680
        # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
        # 联通:130、131、132、156、155、186、185、176
        # 电信:133、153、189、180、181、177
        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
        matchers = pattern.findall(text)
        if matchers:
            #print('telephone')
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
        # 固话
        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
        matchers = pattern.findall(text)
        if matchers:
            # print('fixed telephone')
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
        # 规范化分数
        pattern = re.compile(r"(\d+/\d+)")
        matchers = pattern.findall(text)
        if matchers:
            #print('fraction')
            for matcher in matchers:
                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
        # 规范化百分数
        text = text.replace('%', '%')
        pattern = re.compile(r"(\d+(\.\d+)?%)")
        matchers = pattern.findall(text)
        if matchers:
            #print('percentage')
            for matcher in matchers:
                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
        # 规范化纯数+量词
        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
        matchers = pattern.findall(text)
        if matchers:
            #print('cardinal+quantifier')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
        # 规范化数字编号
        pattern = re.compile(r"(\d{4,32})")
        matchers = pattern.findall(text)
        if matchers:
            #print('digit')
            for matcher in matchers:
                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
        # 规范化纯数
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(text)
        if matchers:
            #print('cardinal')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
        self.norm_text = text
        self._particular()
        return self.norm_text.lstrip('^').rstrip('$')
def nsw_test_case(raw_text):
    print('I:' + raw_text)
    print('O:' + NSWNormalizer(raw_text).normalize())
    print('')
def nsw_test():
    nsw_test_case('固话:0595-23865596或23880880。')
    nsw_test_case('固话:0595-23865596或23880880。')
    nsw_test_case('手机:+86 19859213959或15659451527。')
    nsw_test_case('分数:32477/76391。')
    nsw_test_case('百分数:80.03%。')
    nsw_test_case('编号:31520181154418。')
    nsw_test_case('纯数:2983.07克或12345.60米。')
    nsw_test_case('日期:1999年2月20日或09年3月15号。')
    nsw_test_case('金钱:12块5,34.5元,20.1万')
    nsw_test_case('特殊:O2O或B2C。')
    nsw_test_case('3456万吨')
    nsw_test_case('2938个')
    nsw_test_case('938')
    nsw_test_case('今天吃了115个小笼包231个馒头')
    nsw_test_case('有62%的概率')
if __name__ == '__main__':
    #nsw_test()
    p = argparse.ArgumentParser()
    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
    p.add_argument('ofile', help='output filename')
    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
    p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
    p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
    args = p.parse_args()
    ifile = codecs.open(args.ifile, 'r', 'utf8')
    ofile = codecs.open(args.ofile, 'w+', 'utf8')
    n = 0
    for l in ifile:
        key = ''
        text = ''
        if args.has_key:
            cols = l.split(maxsplit=1)
            key = cols[0]
            if len(cols) == 2:
                text = cols[1].strip()
            else:
                text = ''
        else:
            text = l.strip()
        # cases
        if args.to_upper and args.to_lower:
            sys.stderr.write('text norm: to_upper OR to_lower?')
            exit(1)
        if args.to_upper:
            text = text.upper()
        if args.to_lower:
            text = text.lower()
        # Filler chars removal
        if args.remove_fillers:
            for ch in FILLER_CHARS:
                text = text.replace(ch, '')
        if args.remove_erhua:
            text = remove_erhua(text, ER_WHITELIST)
        # NSW(Non-Standard-Word) normalization
        text = NSWNormalizer(text).normalize()
        # Punctuations removal
        old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
        new_chars = ' ' * len(old_chars)
        del_chars = ''
        text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
        #
        if args.has_key:
            ofile.write(key + '\t' + text + '\n')
        else:
            ofile.write(text + '\n')
        n += 1
        if n % args.log_interval == 0:
            sys.stderr.write("text norm: {} lines done.\n".format(n))
    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
    ifile.close()
    ofile.close()
examples/aishell/run.sh
File was deleted
examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -5,7 +5,14 @@
#local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
#git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
## generate jsonl from wav.scp and text.txt
#python funasr/datasets/audio_datasets/scp2jsonl.py \
#++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
#++data_type_list='["source", "target"]' \
#++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
# torchrun \
# --nnodes 1 \
# --nproc_per_node 1 \
python funasr/bin/train.py \
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+model_revision="v2.0.4" \
funasr/bin/compute_audio_cmvn.py
New file
@@ -0,0 +1,123 @@
import os
import json
import numpy as np
import torch
import hydra
import logging
from omegaconf import DictConfig, OmegaConf
from funasr.register import tables
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()
    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    main(**kwargs)
def main(**kwargs):
    print(kwargs)
    # set random seed
    tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    tokenizer = kwargs.get("tokenizer", None)
    # build frontend if frontend is none None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_train = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        dataset_conf = kwargs.get("dataset_conf")
        dataset_conf["batch_type"] = "example"
        dataset_conf["batch_size"] = 1
        batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
    dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                collate_fn=dataset_train.collator,
                                                batch_sampler=batch_sampler_train,
                                                num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
                                                pin_memory=True)
    iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
    total_frames = 0
    for batch_idx, batch in enumerate(dataloader_train):
        if batch_idx >= iter_stop:
            break
        fbank = batch["speech"].numpy()[0, :, :]
        if total_frames == 0:
            mean_stats = fbank
            var_stats = np.square(fbank)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
        total_frames += fbank.shape[0]
    cmvn_info = {
        'mean_stats': list(mean_stats.tolist()),
        'var_stats': list(var_stats.tolist()),
        'total_frames': total_frames
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    with open(cmvn_file, 'w') as fout:
        fout.write(json.dumps(cmvn_info))
    mean = -1.0 * mean_stats / total_frames
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
    dims = mean.shape[0]
    am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
    with open(am_mvn, 'w') as fout:
        fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
        mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
        fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
        fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
        var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
        fout.write("<LearnRateCoef> 0 " + var_str + '\n')
        fout.write("</Nnet>" + '\n')
if __name__ == "__main__":
    main_hydra()
    """
    python funasr/bin/compute_status.py \
    --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
    --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
    ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
    ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
    ++dataset_conf.num_workers=32
    """
funasr/bin/train.py
@@ -1,3 +1,6 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
import os
import sys
import torch
@@ -144,9 +147,8 @@
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
                               **kwargs.get("dataset_conf"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
funasr/datasets/audio_datasets/datasets.py
@@ -19,7 +19,7 @@
                  **kwargs):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
@@ -63,9 +63,14 @@
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        if self.tokenizer:
            ids = self.tokenizer.encode(target)
            text = torch.tensor(ids, dtype=torch.int64)
        else:
            ids = target
            text = ids
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
        return {"speech": speech[0, :, :],
                "speech_lengths": speech_lengths,
@@ -83,11 +88,13 @@
                outputs[key].append(sample[key])
        for key, data_list in outputs.items():
            if data_list[0].dtype == torch.int64:
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs
funasr/datasets/audio_datasets/index_ds.py
@@ -1,6 +1,9 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from funasr.register import tables
@@ -71,9 +74,19 @@
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):
    
    def __init__(self, path):
    def __init__(self, path: str, **kwargs):
        super().__init__()
        
        if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans
            from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
            jsonl_outdir = os.path.dirname(path[0])
            jsonl_name = "datalist_train.jsonl" if kwargs.get("is_training", True) else "datalist_val.jsonl"
            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
            if not os.path.exists(jsonl_file_out):
                print(f"datalist is: {path}, generate jsonl from it")
                gen_jsonl_from_wav_text_list(path, jsonl_file_out=jsonl_file_out, **kwargs)
            path = jsonl_file_out
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
funasr/datasets/audio_datasets/scp2jsonl.py
New file
@@ -0,0 +1,94 @@
import os
import json
import torch
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist
def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs):
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except:
        rank = 0
        world_size = 1
    cpu_cores = os.cpu_count() or 1
    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()
                lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
                    futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
                    for future in concurrent.futures.as_completed(futures):
                        json_dict[data_type].update(future.result())
            # print(json_dict)
        with open(jsonl_file_out, "w") as f:
            for key in json_dict[data_type_list[0]].keys():
                jsonl_line = {"key": key}
                for data_file in data_type_list:
                    jsonl_line.update(json_dict[data_file][key])
                jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
                f.write(jsonl_line+"\n")
                f.flush()
    else:
        pass
    if world_size > 1:
        dist.barrier()
def parse_context_length(data_list: list, data_type: str):
    res = {}
    for i, line in enumerate(data_list):
        key, line = line.strip().split(maxsplit=1)
        line = line.strip()
        if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
            sample_num = len(waveform)
            context_len = int(sample_num//16000*1000/10)
        else:
            context_len = len(line)
        res[key] = {data_type: line, f"{data_type}_len": context_len}
    return res
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """
    python funasr/datasets/audio_datasets/scp2jsonl.py \
    ++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
    ++data_type_list='["source", "target"]' \
    ++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
    """
    kwargs = OmegaConf.to_container(cfg, resolve=True)
    scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
    gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
if __name__ == "__main__":
    main_hydra()
funasr/download/download_dataset_from_hub.py
New file
@@ -0,0 +1,11 @@
def download_dataset():
    pass
def download_dataset_from_ms(**kwargs):
    from modelscope.msdatasets import MsDataset
    dataset_name = kwargs.get("dataset_name", 'speech_asr/speech_asr_aishell1_trainsets')
    subset_name = kwargs.get("subset_name", 'default')
    split = kwargs.get("split", 'train')
    data_dump_dir = kwargs.get("data_dump_dir", None)
    ds = MsDataset.load(dataset_name=dataset_name, subset_name=subset_name, split=split, cache_dir=data_dump_dir)
funasr/models/model_hf/__init__.py
funasr/register.py
@@ -52,8 +52,8 @@
            registry = getattr(self, register_tables_key)
            registry_key = key if key is not None else target_class.__name__
            
            assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
                registry_key, target_class, register_tables_key)
            # assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
            #     registry_key, target_class, register_tables_key)
            
            registry[registry_key] = target_class
            
funasr/train_utils/trainer.py
@@ -204,25 +204,25 @@
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                print("before, GPU, memory: {:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                                     torch.cuda.max_memory_allocated()/1024/1024/1024,
                                     torch.cuda.memory_reserved()/1024/1024/1024,
                                     torch.cuda.max_memory_reserved()/1024/1024/1024,
                                     ))
                # print("before, GPU, memory: {:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                #                      torch.cuda.max_memory_allocated()/1024/1024/1024,
                #                      torch.cuda.memory_reserved()/1024/1024/1024,
                #                      torch.cuda.max_memory_reserved()/1024/1024/1024,
                #                      ))
                retval = self.model(**batch)
                torch.cuda.empty_cache()
                print("after, GPU, memory: {:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB, "
                      "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                                     torch.cuda.max_memory_allocated()/1024/1024/1024,
                                     torch.cuda.memory_reserved()/1024/1024/1024,
                                     torch.cuda.max_memory_reserved()/1024/1024/1024,
                                     ))
                # print("after, GPU, memory: {:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                #                      torch.cuda.max_memory_allocated()/1024/1024/1024,
                #                      torch.cuda.memory_reserved()/1024/1024/1024,
                #                      torch.cuda.max_memory_reserved()/1024/1024/1024,
                #                      ))
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
@@ -275,12 +275,21 @@
            pbar.update(1)
            if self.local_rank == 0:
                gpu_info = "GPU, memory: {:.3f} GB, " \
                           "{:.3f} GB, "\
                           "{:.3f} GB, "\
                           "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                                             torch.cuda.max_memory_allocated()/1024/1024/1024,
                                             torch.cuda.memory_reserved()/1024/1024/1024,
                                             torch.cuda.max_memory_reserved()/1024/1024/1024,
                                             )
                description = (
                    f"Train epoch: {epoch}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                    f"{gpu_info}"
                )
                pbar.set_description(description)
                if self.writer: