| New file |
| | |
| | | ID0012W0013 当客户风险承受能力评估依据发生变化时 |
| | | ID0012W0014 杨涛不得不将工厂关掉 |
| New file |
| | |
| | | ID0012W0013 /Users/zhifu/funasr_github/test_local/aishell2_dev_ios/wav/D0012/ID0012W0013.wav |
| | | ID0012W0014 /Users/zhifu/funasr_github/test_local/aishell2_dev_ios/wav/D0012/ID0012W0014.wav |
| | |
| | | |
| | | # network architecture |
| | | model: funasr.cli.models.paraformer:Paraformer |
| | | model: Paraformer |
| | | model_conf: |
| | | ctc_weight: 0.3 |
| | | lsm_weight: 0.1 |
| | |
| | | sampling_ratio: 0.4 |
| | | use_1st_decoder_loss: true |
| | | |
| | | |
| | | # encoder related |
| | | encoder: conformer |
| | | # encoder |
| | | encoder: ConformerEncoder |
| | | encoder_conf: |
| | | output_size: 256 # dimension of attention |
| | | attention_heads: 4 |
| | |
| | | use_cnn_module: true |
| | | cnn_module_kernel: 15 |
| | | |
| | | # decoder related |
| | | decoder: paraformer_decoder_san |
| | | # decoder |
| | | decoder: ParaformerSANDecoder |
| | | decoder_conf: |
| | | attention_heads: 4 |
| | | linear_units: 2048 |
| | |
| | | self_attention_dropout_rate: 0.0 |
| | | src_attention_dropout_rate: 0.0 |
| | | |
| | | # predictor |
| | | predictor: CifPredictor |
| | | predictor_conf: |
| | | idim: 256 |
| | | threshold: 1.0 |
| | | l_order: 1 |
| | | r_order: 1 |
| | | tail_threshold: 0.45 |
| | | |
| | | # frontend related |
| | | frontend: wav_frontend |
| | | frontend: WavFrontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | window: hamming |
| | |
| | | lfr_m: 1 |
| | | lfr_n: 1 |
| | | |
| | | |
| | | train_conf: |
| | | accum_grad: 1 |
| | | grad_clip: 5 |
| | | max_epoch: 150 |
| | | val_scheduler_criterion: |
| | | - valid |
| | | - acc |
| | | best_model_criterion: |
| | | - - valid |
| | | - acc |
| | | - max |
| | | keep_nbest_models: 10 |
| | | log_interval: 50 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.0005 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 30000 |
| | | |
| | | specaug: specaug |
| | | specaug: SpecAug |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | |
| | | - 40 |
| | | num_time_mask: 2 |
| | | |
| | | predictor: cif_predictor |
| | | predictor_conf: |
| | | idim: 256 |
| | | threshold: 1.0 |
| | | l_order: 1 |
| | | r_order: 1 |
| | | tail_threshold: 0.45 |
| | | train_conf: |
| | | accum_grad: 1 |
| | | grad_clip: 5 |
| | | max_epoch: 150 |
| | | keep_nbest_models: 10 |
| | | avg_nbest_model: 5 |
| | | log_interval: 50 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.0005 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 30000 |
| | | |
| | | dataset: AudioDataset |
| | | dataset_conf: |
| | | data_names: speech,text |
| | | data_types: sound,text |
| | | index_ds: IndexDSJsonl |
| | | batch_sampler: RankFullLocalShuffleBatchSampler |
| | | batch_type: example # example or length |
| | | batch_size: 32 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; |
| | | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, |
| | | buffer_size: 1024 |
| | | shuffle: True |
| | | shuffle_conf: |
| | | shuffle_size: 2048 |
| | | sort_size: 500 |
| | | batch_conf: |
| | | batch_type: example |
| | | batch_size: 2 |
| | | num_workers: 8 |
| | | num_workers: 0 |
| | | |
| | | tokenizer: CharTokenizer |
| | | tokenizer_conf: |
| | | unk_symbol: <unk> |
| | | split_with_space: true |
| | | |
| | | |
| | | normalize: null |
| | | ctc_conf: |
| | | dropout_rate: 0.0 |
| | | ctc_type: builtin |
| | | reduce: true |
| | | ignore_nan_grad: true |
| | | normalize: null |
| | | |
| | | |
| New file |
| | |
| | | |
# Single-process debug launcher for funasr/bin/train.py.
# Each path can be overridden through the upper-case environment variable
# named below; the defaults are the original author's local paths.
# Any extra arguments (e.g. further ++key=value Hydra overrides) are forwarded
# to the training script unchanged.
cmd="funasr/bin/train.py"

config_path="${CONFIG_PATH:-/Users/zhifu/funasr_github/test_local/funasr_cli_egs}"
config_name="${CONFIG_NAME:-config.yaml}"
token_list="${TOKEN_LIST:-/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt}"
train_data_set_list="${TRAIN_DATA_SET_LIST:-/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl}"
output_dir="${OUTPUT_DIR:-/nfs/zhifu.gzf/ckpt/funasr2/exp1}"

python "$cmd" \
    --config-path "${config_path}" \
    --config-name "${config_name}" \
    ++token_list="${token_list}" \
    ++train_data_set_list="${train_data_set_list}" \
    ++output_dir="${output_dir}" \
    "$@"
| New file |
| | |
#!/bin/bash

# Copyright 2017 Xingyu Na
# Apache 2.0

# Build Kaldi-style data directories (wav.scp, text) for the AISHELL
# train/dev/test splits.
#   <audio-path>   searched recursively for *.wav; a file belongs to the split
#                  whose "wav/train", "wav/dev" or "wav/test" appears in its path
#   <text-path>    directory that contains aishell_transcript_v0.8.txt
#   <output-path>  results go to <output-path>/data/{train,dev,test}/{wav.scp,text};
#                  intermediate files stay under <output-path>/data/local/
# Utterance ids are the wav file names without ".wav"; only utterances that
# also appear in the transcript are kept.

#. ./path.sh || exit 1;

if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

# Validate the inputs before creating anything under <output-path>.
if [ ! -d "$aishell_audio_dir" ] || [ ! -f "$aishell_text" ]; then
  echo "Error: $0: audio directory '$aishell_audio_dir' or transcript file '$aishell_text' not found"
  exit 1;
fi

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

mkdir -p "$train_dir" "$dev_dir" "$test_dir" "$tmp_dir"

# find wav audio file for train, dev and test resp.
find "$aishell_audio_dir" -iname "*.wav" > "$tmp_dir/wav.flist"
n=$(wc -l < "$tmp_dir/wav.flist")
[ $n -ne 141925 ] && \
  echo Warning: expected 141925 data data files, found $n

grep -i "wav/train" "$tmp_dir/wav.flist" > "$train_dir/wav.flist" || exit 1;
grep -i "wav/dev" "$tmp_dir/wav.flist" > "$dev_dir/wav.flist" || exit 1;
grep -i "wav/test" "$tmp_dir/wav.flist" > "$test_dir/wav.flist" || exit 1;

rm -r "$tmp_dir"

# Transcriptions preparation
for dir in "$train_dir" "$dev_dir" "$test_dir"; do
  echo Preparing $dir transcriptions
  # Utterance id = last path component with ".wav" removed.
  sed -e 's/\.wav//' "$dir/wav.flist" | awk -F '/' '{print $NF}' > "$dir/utt.list"
  paste -d' ' "$dir/utt.list" "$dir/wav.flist" > "$dir/wav.scp_all"
  # Keep only transcribed utterances, then restrict wav.scp to the same ids.
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$aishell_text" > "$dir/transcripts.txt"
  awk '{print $1}' "$dir/transcripts.txt" > "$dir/utt.list"
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$dir/wav.scp_all" | sort -u > "$dir/wav.scp"
  sort -u "$dir/transcripts.txt" > "$dir/text"
done

mkdir -p "$output_dir/data/train" "$output_dir/data/dev" "$output_dir/data/test"

for f in wav.scp text; do
  cp "$train_dir/$f" "$output_dir/data/train/$f" || exit 1;
  cp "$dev_dir/$f" "$output_dir/data/dev/$f" || exit 1;
  cp "$test_dir/$f" "$output_dir/data/test/$f" || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
| New file |
| | |
#!/usr/bin/env bash

# Copyright 2014  Johns Hopkins University (author: Daniel Povey)
#           2017  Xingyu Na
# Apache 2.0

# Download <url-base>/<corpus-part>.tgz into <data-base> (unless an archive of
# one of the expected sizes is already there) and extract it. For data_aishell
# the wav archives it contains are extracted as well. A
# <data-base>/<corpus-part>/.complete marker makes later runs a no-op.

remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  exit 1;
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
# The script cd's into $data further down and keeps using "$data/..." paths
# afterwards, so a relative path (e.g. ../raw_data) must be made absolute.
data=$(cd "$data" && pwd) || exit 1

part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == "$x" ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi

if [ -f "$data/$part/.complete" ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi

# sizes of the archive files in bytes.
sizes="15582913665 1246920"

# Discard a previously downloaded archive whose size matches none of the above
# (e.g. an interrupted download).
if [ -f "$data/$part.tgz" ]; then
  size=$(/bin/ls -l "$data/$part.tgz" | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ "$s" == "$size" ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm "$data/$part.tgz"
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi

if [ ! -f "$data/$part.tgz" ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url. This may take some time, please be patient."

  cd "$data" || exit 1
  if ! wget --no-check-certificate "$full_url"; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi

cd "$data" || exit 1

if ! tar -xvzf "$part.tgz"; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi

if [ "$part" == "data_aishell" ]; then
  cd "$data/$part/wav" || exit 1
  for wav in ./*.tar.gz; do
    [ -e "$wav" ] || continue  # the glob matched nothing
    echo "Extracting wav from $wav"
    if ! tar -zxf "$wav"; then
      echo "$0: error un-tarring archive $data/$part/wav/$wav"
      exit 1;
    fi
    rm "$wav"
  done
fi

# Mark the part as done only after every archive, including the nested wav
# archives, has been extracted successfully.
touch "$data/$part/.complete"

echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"

if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm "$data/$part.tgz"
fi

exit 0;
| New file |
| | |
#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
# Exported so that torchrun (stage 4) and its workers actually see it; a plain
# assignment only reaches child processes if the caller had already exported
# the variable. Keep gpu_num equal to the number of devices listed here.
export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
gpu_inference=true    # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=1

# general configuration
feats_dir="../DATA" # feature output directory
exp_dir="."
lang=zh
token_type=char
stage=0
stop_stage=5

# feature configuration
nj=64

# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Set bash to strict mode; it will exit on:
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"

#inference_config=conf/decode_asr_transformer_noctc_1best.yaml
#inference_asr_model=valid.acc.ave_10best.pb

## you can set gpu num for decoding here
#gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
#ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
#
#if ${gpu_inference}; then
#    inference_nj=$[${ngpu}*${njob}]
#    _ngpu=1
#else
#    inference_nj=$njob
#    _ngpu=0
#fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    local/download_and_untar.sh "${raw_data}" "${data_url}" data_aishell
    local/download_and_untar.sh "${raw_data}" "${data_url}" resource_aishell
fi
| | | |
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
    for x in train dev test; do
        # Delete the spaces inside every transcript (the "<id> " prefix is
        # kept), then re-tokenize it with utils/text2token.py.
        # text.org only serves as a temporary file here.
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text

        # The list is double-quoted (inner quotes escaped) so that ${feats_dir}
        # and ${x} are expanded; inside single quotes they reached the script
        # as literal text.
        python funasr/datasets/audio_datasets/scp2jsonl.py \
            ++scp_file_list="[\"${feats_dir}/data/${x}/wav.scp\", \"${feats_dir}/data/${x}/text\"]" \
            ++data_type_list='["source", "target"]' \
            ++jsonl_file_out="${feats_dir}/data/${x}/audio_datasets.jsonl"
    done
fi
| | | |
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
#    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
    # Take the Hydra config location from ${asr_config} rather than a
    # developer-specific absolute path. The directory is made absolute so it
    # does not depend on how the entry script resolves relative config paths.
    asr_config_dir="$(cd "$(dirname "${asr_config}")" && pwd)"
    python funasr/bin/compute_audio_cmvn.py \
        --config-path "${asr_config_dir}" \
        --config-name "$(basename "${asr_config}")" \
        ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
        ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \
        ++dataset_conf.num_workers=$nj
fi
| | | |
# Character dictionary location. It is set outside the stage guard because the
# training stage passes it to train.py as well.
token_list="${feats_dir}/data/${lang}_token_list/${token_type}/tokens.txt"
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p "$(dirname "${token_list}")/"

    echo "make a dictionary"
    {
        # Special symbols take the first entries, in this order.
        printf '%s\n' "<blank>" "<s>" "</s>"
        # Every distinct non-empty token of the training transcripts, one per line.
        utils/text2token.py -s 1 -n 1 --space "" "${feats_dir}/data/${train_set}/text" \
            | cut -f 2- -d" " \
            | tr " " "\n" \
            | sort \
            | uniq \
            | grep -a -v -e '^\s*$'
        # Unknown-token symbol goes last.
        echo "<unk>"
    } > "${token_list}"
fi
| | | |
# LM Training Stage (placeholder: nothing is trained here yet)
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: LM Training"
fi

# ASR Training Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: ASR Training"

    # Take the Hydra config location from ${asr_config} rather than a
    # developer-specific absolute path.
    asr_config_dir="$(cd "$(dirname "${asr_config}")" && pwd)"
    # NOTE(review): stage 1 is given cmvn.json as its cmvn_file, while this
    # stage reads am.mvn from the same directory. Confirm that
    # compute_audio_cmvn.py writes am.mvn.
    torchrun \
        --nnodes 1 \
        --nproc_per_node ${gpu_num} \
        funasr/bin/train.py \
        --config-path "${asr_config_dir}" \
        --config-name "$(basename "${asr_config}")" \
        ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
        ++cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
        ++token_list="${token_list}" \
        ++output_dir="${exp_dir}/exp/${model_dir}"
fi
| | | |
| | | # |
| | | ## Testing Stage |
| | | #if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then |
| | | # echo "stage 5: Inference" |
| | | # for dset in ${test_sets}; do |
| | | # asr_exp=${exp_dir}/exp/${model_dir} |
| | | # inference_tag="$(basename "${inference_config}" .yaml)" |
| | | # _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}" |
| | | # _logdir="${_dir}/logdir" |
| | | # if [ -d ${_dir} ]; then |
| | | # echo "${_dir} is already exists. if you want to decode again, please delete this dir first." |
| | | # exit 0 |
| | | # fi |
| | | # mkdir -p "${_logdir}" |
| | | # _data="${feats_dir}/data/${dset}" |
| | | # key_file=${_data}/${scp} |
| | | # num_scp_file="$(<${key_file} wc -l)" |
| | | # _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") |
| | | # split_scps= |
| | | # for n in $(seq "${_nj}"); do |
| | | # split_scps+=" ${_logdir}/keys.${n}.scp" |
| | | # done |
| | | # # shellcheck disable=SC2086 |
| | | # utils/split_scp.pl "${key_file}" ${split_scps} |
| | | # _opts= |
| | | # if [ -n "${inference_config}" ]; then |
| | | # _opts+="--config ${inference_config} " |
| | | # fi |
| | | # ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ |
| | | # python -m funasr.bin.asr_inference_launch \ |
| | | # --batch_size 1 \ |
| | | # --ngpu "${_ngpu}" \ |
| | | # --njob ${njob} \ |
| | | # --gpuid_list ${gpuid_list} \ |
| | | # --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ |
| | | # --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \ |
| | | # --key_file "${_logdir}"/keys.JOB.scp \ |
| | | # --asr_train_config "${asr_exp}"/config.yaml \ |
| | | # --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ |
| | | # --output_dir "${_logdir}"/output.JOB \ |
| | | # --mode paraformer \ |
| | | # ${_opts} |
| | | # |
| | | # for f in token token_int score text; do |
| | | # if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then |
| | | # for i in $(seq "${_nj}"); do |
| | | # cat "${_logdir}/output.${i}/1best_recog/${f}" |
| | | # done | sort -k1 >"${_dir}/${f}" |
| | | # fi |
| | | # done |
| | | # python utils/proce_text.py ${_dir}/text ${_dir}/text.proc |
| | | # python utils/proce_text.py ${_data}/text ${_data}/text.proc |
| | | # python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer |
| | | # tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt |
| | | # cat ${_dir}/text.cer.txt |
| | | # done |
| | | #fi |
| | | # |
| | | ## Prepare files for ModelScope fine-tuning and inference |
| | | #if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then |
| | | # echo "stage 6: ModelScope Preparation" |
| | | # cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn |
| | | # vocab_size=$(cat ${token_list} | wc -l) |
| | | # python utils/gen_modelscope_configuration.py \ |
| | | # --am_model_name $inference_asr_model \ |
| | | # --mode paraformer \ |
| | | # --model_name paraformer \ |
| | | # --dataset aishell \ |
| | | # --output_dir $exp_dir/exp/$model_dir \ |
| | | # --vocab_size $vocab_size \ |
| | | # --nat _nat \ |
| | | # --tag $tag |
| | | #fi |
| New file |
| | |
| | | from transformers import AutoTokenizer, AutoModel, pipeline |
| | | import numpy as np |
| | | import sys |
| | | import os |
| | | import torch |
| | | from kaldiio import WriteHelper |
| | | import re |
def split_into_char_tokens(line):
    """Turn one ``<utt_id> <text>`` line into space-separated characters.

    The utterance id is the first whitespace-separated field. All ASCII spaces
    are removed from the rest of the line, and every remaining character becomes
    one token, so the tokenizer is fed one token per character.

    Args:
        line: a single line of the input text file (trailing newline allowed).

    Returns:
        ``(utt_id, tokens, token_num)``: ``tokens`` is the characters joined by
        single spaces and ``token_num`` is the number of characters. Returns
        ``None`` for a blank line or an id without any text.
    """
    fields = line.strip().split(maxsplit=1)
    if len(fields) < 2:
        # Nothing to embed.
        return None
    utt_id, text = fields
    chars = text.replace(" ", "")
    return utt_id, " ".join(chars), len(chars)


def main(argv=None):
    """Write per-character embeddings of a text file to a Kaldi ark/scp pair.

    Positional arguments (``sys.argv[1:]`` by default):
        text_file   input file with one ``<utt_id> <text>`` line per utterance
        out_ark     output Kaldi archive of embedding matrices
        out_scp     output scp index of ``out_ark``
        out_shape   output file with one ``<utt_id> <rows>,<dim>`` line per
                    written utterance
        device      device argument passed to the transformers pipeline
        model_path  model passed to ``AutoModel``/``AutoTokenizer.from_pretrained``

    An utterance is written only when the number of embedding rows (after
    dropping the first and last position) equals its character count;
    otherwise a "size has changed" message is printed and it is skipped.
    """
    args = sys.argv[1:] if argv is None else argv
    text_file, out_ark, out_scp, out_shape = args[0], args[1], args[2], args[3]
    device = int(args[4])
    model_path = args[5]

    model = AutoModel.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)

    # The transcripts contain non-ASCII text; do not rely on the locale encoding.
    with open(text_file, "r", encoding="utf-8") as fin:
        lines = fin.readlines()

    # Context managers guarantee the shape file and the archive are closed even
    # if extraction fails part-way.
    with open(out_shape, "w") as f_shape, \
            WriteHelper("ark,scp:{},{}".format(out_ark, out_scp)) as writer, \
            torch.no_grad():
        for line in lines:
            parsed = split_into_char_tokens(line)
            if parsed is None:
                print("skipping line without text: {!r}".format(line.rstrip("\n")))
                continue
            utt_id, tokens, token_num = parsed

            outputs = np.array(extractor(tokens))
            # Drop the first and last positions, presumably the tokenizer's
            # special begin/end tokens (TODO confirm for the chosen model), so
            # that one row per character remains.
            embeds = outputs[0, 1:-1, :]

            token_num_embeds, dim = embeds.shape
            if token_num == token_num_embeds:
                writer(utt_id, embeds)
                f_shape.write("{} {},{}\n".format(utt_id, token_num_embeds, dim))
            else:
                print("{}, size has changed, {}, {}, {}".format(utt_id, token_num, token_num_embeds, tokens))


if __name__ == "__main__":
    main()
| | | |
| | | |
| New file |
| | |
| | | #!/usr/bin/env perl |
| | | # Copyright 2010-2012 Microsoft Corporation |
| | | # Johns Hopkins University (author: Daniel Povey) |
| | | |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch. With --exclude the filter is inverted and only lines
# whose field is NOT in id_list are printed. Matching lines are printed
# unchanged; the input is read from in.scp, or from stdin when it is omitted.

$exclude = 0;   # set to 1 by --exclude
$field = 1;     # 1-based index of the field to look up; set by -f <n>
$shifted = 0;   # whether the current option-parsing pass consumed an option

# The two options may come in either order; keep scanning until a complete
# pass consumes nothing.
do {
  $shifted=0;
  if ($ARGV[0] eq "--exclude") {
    $exclude = 1;
    shift @ARGV;
    $shifted=1;
  }
  if ($ARGV[0] eq "-f") {
    $field = $ARGV[1];
    shift @ARGV; shift @ARGV;
    $shifted=1
  }
} while ($shifted);

# What remains must be: id_list [in.scp]
if(@ARGV < 1 || @ARGV > 2) {
  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
      "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
      "only the lines that were *not* in id_list.\n" .
      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
      "-f option, add 1 to the argument.\n" .
      "See also: scripts/filter_scp.pl .\n";
}


# Load the id set from the first whitespace-separated field of each line of
# id_list; a line with no fields at all (e.g. a blank line) is fatal.
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
  @A = split;
  @A>=1 || die "Invalid id-list file line $_";
  $seen{$A[0]} = 1;
}

if ($field == 1) { # Treat this as special case, since it is common.
  while(<>) {
    # Capture the first non-whitespace run instead of splitting the whole
    # line; a line without any non-whitespace character is fatal.
    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
    # $1 is what we filter on.
    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
      print $_;
    }
  }
} else {
  while(<>) {
    # General case: the line must have at least $field fields.
    @A = split;
    @A > 0 || die "Invalid scp file line $_";
    @A >= $field || die "Invalid scp file line $_";
    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
      print $_;
    }
  }
}

# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)
| New file |
| | |
#!/usr/bin/env bash

# Restrict a data directory to the utterances present in BOTH wav.scp and
# text, leaving each file sorted and unique by utterance id (first field).
# The original files are copied to <data_dir>/.backup/ first.
# Usage: fix_data.sh <data_dir>

echo "$0 $*"
if [ $# -lt 1 ]; then
  echo "Usage: $0 <data_dir>"
  exit 1;
fi
data_dir=$1

if [ ! -f "${data_dir}/wav.scp" ]; then
  echo "$0: wav.scp is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text" ]; then
  echo "$0: text is not found"
  exit 1;
fi

mkdir -p "${data_dir}/.backup"

awk '{print $1}' "${data_dir}/wav.scp" > "${data_dir}/.backup/wav_id"
awk '{print $1}' "${data_dir}/text" > "${data_dir}/.backup/text_id"

# Intersect the two id lists. Each list is de-duplicated first; otherwise an id
# that is merely repeated inside ONE file would also be reported by `uniq -d`
# and kept, even though the other file has no entry for it.
sort <(sort "${data_dir}/.backup/wav_id" | uniq) <(sort "${data_dir}/.backup/text_id" | uniq) \
  | uniq -d > "${data_dir}/.backup/id"

for f in wav.scp text; do
  cp "${data_dir}/$f" "${data_dir}/.backup/$f" || exit 1;
  # Filter from the backup copy so the original file name can be overwritten.
  utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/.backup/$f" \
    | sort -k1,1 -u > "${data_dir}/$f"
done
| New file |
| | |
#!/usr/bin/env bash

# Restrict a feature data directory to the utterances present in BOTH
# feats.scp and text. feats.scp, text, speech_shape and text_shape are each
# filtered to that id set and left sorted and unique by utterance id.
# The original files are copied to <data_dir>/.backup/ first.
# Usage: fix_data_feat.sh <data_dir>

echo "$0 $*"
if [ $# -lt 1 ]; then
  echo "Usage: $0 <data_dir>"
  exit 1;
fi
data_dir=$1

if [ ! -f "${data_dir}/feats.scp" ]; then
  echo "$0: feats.scp is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text" ]; then
  echo "$0: text is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/speech_shape" ]; then
  echo "$0: feature lengths is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text_shape" ]; then
  echo "$0: text lengths is not found"
  exit 1;
fi

mkdir -p "${data_dir}/.backup"

awk '{print $1}' "${data_dir}/feats.scp" > "${data_dir}/.backup/wav_id"
awk '{print $1}' "${data_dir}/text" > "${data_dir}/.backup/text_id"

# Intersect the two id lists. Each list is de-duplicated first; otherwise an id
# that is merely repeated inside ONE file would also be reported by `uniq -d`
# and kept, even though the other file has no entry for it.
sort <(sort "${data_dir}/.backup/wav_id" | uniq) <(sort "${data_dir}/.backup/text_id" | uniq) \
  | uniq -d > "${data_dir}/.backup/id"

for f in feats.scp text speech_shape text_shape; do
  cp "${data_dir}/$f" "${data_dir}/.backup/$f" || exit 1;
  # Filter from the backup copy so the original file name can be overwritten.
  utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/.backup/$f" \
    | sort -k1,1 -u > "${data_dir}/$f"
done
| | | |
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
| | | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); |
| | | # Arnab Ghoshal, Karel Vesely |
| | | |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
# Only variables already defined by the sourcing script are accepted as
# options. Parsing stops at the first argument that does not start with "--";
# the consumed options are shifted off the positional parameters.


###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###

# Now import all the configs specified by command-line, in left-to-right order.
# (argpos stops one short of $# because "--config" must be followed by a value.)
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config  # source the config file.
  fi
done


###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
               else printf "$help_message\n" 1>&2 ; fi;
               exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
           exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined-- if so it's
      # an invalid option and we die.  Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined.  We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

      # An option given as the last argument has no value. Without this check
      # bash's "shift 2" below fails without shifting anything, "$1" stays the
      # same, and this loop never terminates.
      if [ $# -lt 2 ]; then
        echo "$0: option $1 requires an argument" 1>&2
        exit 1;
      fi

      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
  *) break;
  esac
done


# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;


true; # so this script returns exit code 0.
| New file |
| | |
| | | #!/usr/bin/env perl |
| | | |
| | | # Copyright 2013 Johns Hopkins University (author: Daniel Povey) |
| | | |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
# Randomly permute the lines of the input (one file argument, or stdin).
# The permutation is reproducible: the RNG is seeded with --srand N, or with 0
# when that option is absent.

if ($ARGV[0] eq "--srand") {
  $n = $ARGV[1];
  # Anchored so that values such as "abc1" are rejected instead of accepted.
  $n =~ m/^\d+$/ || die "Bad argument to --srand option: \"$n\"";
  srand($ARGV[1]);
  shift;
  shift;
} else {
  srand(0); # Gives inconsistent behavior if we don't seed.
}

if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we
                                        # don't understand.
  print "Usage: shuffle_list.pl [--srand N] [input file] > output\n";
  print "randomizes the order of lines of input.\n";
  exit(1);
}

# Pair every line with a random key, then order the lines by that key.
my @lines;
while (<>) {
  push @lines, [ (rand(), $_)] ;
}

# Compare the keys numerically. A string comparison (cmp) mis-orders keys that
# Perl stringifies in exponent form, e.g. "5e-05" would sort after "0.9".
@lines = sort { $a->[0] <=> $b->[0] } @lines;
foreach $l (@lines) {
  print $l->[1];
}
| New file |
| | |
| | | #!/usr/bin/env perl |
| | | |
| | | # Copyright 2010-2011 Microsoft Corporation |
| | | |
| | | # See ../../COPYING for clarification regarding multiple authors |
| | | # |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
| | | # This program splits up any kind of .scp or archive-type file. |
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each part.
| | | # With the --utt2spk option it will work on anything that has the |
| | | # utterance-id as the first entry on each line; the utt2spk file is |
| | | # of the form "utterance speaker" (on each line). |
| | | # It splits it into equal size chunks as far as it can. If you use the utt2spk |
| | | # option it will make sure these chunks coincide with speaker boundaries. In |
| | | # this case, if there are more chunks than speakers (and in some other |
| | | # circumstances), some of the resulting chunks will be empty and it will print |
| | | # an error message and exit with nonzero status. |
| | | # You will normally call this like: |
| | | # split_scp.pl scp scp.1 scp.2 scp.3 ... |
| | | # or |
| | | # split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... |
| | | # Note that you can use this script to split the utt2spk file itself, |
| | | # e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... |
| | | |
| | | # You can also call the scripts like: |
| | | # split_scp.pl -j 3 0 scp scp.0 |
| | | # [note: with this option, it assumes zero-based indexing of the split parts, |
| | | # i.e. the second number must be 0 <= n < num-jobs.] |
| | | |
use warnings;

# Option defaults.
$num_jobs = 0;        # 0: write every output file named on the command line
$job_id = 0;          # with -j: the single part to write
$utt2spk_file = "";   # with --utt2spk: keep each speaker's lines together
$one_based = 0;       # with --one-based: job-id counts from 1

# Options may come in any order, so make up to three passes.  Every test
# re-checks @ARGV because an earlier option in the same pass may already have
# consumed the remaining arguments (avoids "uninitialized value" warnings).
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
  if ($ARGV[0] eq "-j") {
    shift @ARGV;
    $num_jobs = shift @ARGV;
    $job_id = shift @ARGV;
    # A missing or non-numeric num-jobs would otherwise evaluate to 0 and
    # silently disable -j, writing nothing useful.
    (defined $job_id && $num_jobs =~ /^\d+$/ && $job_id =~ /^\d+$/) ||
      die "$0: -j expects two non-negative integers: -j num-jobs job-id\n";
  }
  if (@ARGV > 0 && $ARGV[0] =~ /^--utt2spk=(.+)/) {
    $utt2spk_file=$1;
    shift;
  }
  if (@ARGV > 0 && $ARGV[0] eq '--one-based') {
    $one_based = 1;
    shift @ARGV;
  }
}

if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
                       $job_id - $one_based >= $num_jobs)) {
  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
      ($one_based ? " --one-based" : "") . "'\n"
}

# From here on job_id is always zero-based.
$one_based
    and $job_id--;

if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
  die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
 or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
 ... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";
}
| | | |
$error = 0;
$inscp = shift @ARGV;
if ($num_jobs == 0) {
  # No -j: every remaining argument names an output file.
  @OUTPUTS = @ARGV;
} else {
  # With -j all num_jobs parts are still computed, so the part boundaries
  # match a full split; every part except job_id goes to /dev/null.  The
  # kept part goes to the optional output file, or to stdout ("-").
  my $target = @ARGV > 0 ? $ARGV[0] : "-";
  for (my $part = 0; $part < $num_jobs; $part++) {
    push @OUTPUTS, ($part == $job_id ? $target : "/dev/null");
  }
}
| | | |
if ($utt2spk_file ne "") {  # We have the --utt2spk option...
  # Read the utterance -> speaker map.
  open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
  while(<$u_fh>) {
    @A = split;
    @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
    ($u,$s) = @A;
    $utt2spk{$u} = $s;
  }
  close $u_fh;

  # Group the input lines by speaker; @spkrs keeps speakers in order of first
  # appearance so that each output file receives a contiguous run of them.
  open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
  @spkrs = ();
  while(<$i_fh>) {
    @A = split;
    if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
    $u = $A[0];
    $s = $utt2spk{$u};
    defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
    if(!defined $spk_count{$s}) {
      push @spkrs, $s;
      $spk_count{$s} = 0;
      $spk_data{$s} = []; # ref to new empty array.
    }
    $spk_count{$s}++;
    push @{$spk_data{$s}}, $_;
  }
  close $i_fh;

  # Now split as equally as possible ..
  # First allocate spks to files by allocating an approximately
  # equal number of speakers.
  $numspks = @spkrs;  # number of speakers.
  $numscps = @OUTPUTS; # number of output files.
  if ($numspks < $numscps) {
    die "$0: Refusing to split data because number of speakers $numspks " .
        "is less than the number of output .scp files $numscps\n";
  }
  for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
    $scparray[$scpidx] = []; # [] is array reference.
  }
  for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
    $scpidx = int(($spkidx*$numscps) / $numspks);
    $spk = $spkrs[$spkidx];
    push @{$scparray[$scpidx]}, $spk;
    $scpcount[$scpidx] += $spk_count{$spk};
  }

  # Now will try to reassign beginning + ending speakers
  # to different scp's and see if it gets more balanced.
  # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
  # We can show that if considering changing just 2 scp's, we minimize
  # this by minimizing the squared difference in sizes.  This is
  # equivalent to minimizing the absolute difference in sizes.  This
  # shows this method is bound to converge.

  $changed = 1;
  while($changed) {
    $changed = 0;
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
      # First try to reassign ending spk of this scp.
      if($scpidx < $numscps-1) {
        $sz = @{$scparray[$scpidx]};
        if($sz > 0) {
          $spk = $scparray[$scpidx]->[$sz-1];
          $count = $spk_count{$spk};
          $nutt1 = $scpcount[$scpidx];
          $nutt2 = $scpcount[$scpidx+1];
          if( abs( ($nutt2+$count) - ($nutt1-$count))
              < abs($nutt2 - $nutt1))  { # Would decrease
            # size-diff by reassigning spk...
            $scpcount[$scpidx+1] += $count;
            $scpcount[$scpidx] -= $count;
            pop @{$scparray[$scpidx]};
            unshift @{$scparray[$scpidx+1]}, $spk;
            $changed = 1;
          }
        }
      }
      # Then try to reassign the first spk of this scp to the previous one.
      if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
        $spk = $scparray[$scpidx]->[0];
        $count = $spk_count{$spk};
        $nutt1 = $scpcount[$scpidx-1];
        $nutt2 = $scpcount[$scpidx];
        if( abs( ($nutt2-$count) - ($nutt1+$count))
            < abs($nutt2 - $nutt1))  { # Would decrease
          # size-diff by reassigning spk...
          $scpcount[$scpidx-1] += $count;
          $scpcount[$scpidx] -= $count;
          shift @{$scparray[$scpidx]};
          push @{$scparray[$scpidx-1]}, $spk;
          $changed = 1;
        }
      }
    }
  }
  # Now print out the files...
  for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
    $scpfile = $OUTPUTS[$scpidx];
    ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
                     : open($f_fh, '>&', \*STDOUT)) ||
        die "$0: Could not open scp file $scpfile for writing: $!\n";
    $count = 0;
    if(@{$scparray[$scpidx]} == 0) {
      print STDERR "$0: error: split_scp.pl producing empty .scp file " .
                   "$scpfile (too many splits and too few speakers?)\n";
      $error = 1;
    } else {
      foreach $spk ( @{$scparray[$scpidx]} ) {
        print $f_fh @{$spk_data{$spk}};
        $count += $spk_count{$spk};
      }
      $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
    }
    # A failed close (e.g. disk full) means the output may be truncated;
    # check it here just as the non-utt2spk branch does.
    close($f_fh) || die "$0: Error closing scp file $scpfile: $!\n";
  }
} else {
  # This block is the "normal" case where there is no --utt2spk
  # option and we just break into equal size chunks.

  open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";

  $numscps = @OUTPUTS;  # size of array.
  @F = ();
  while(<$i_fh>) {
    push @F, $_;
  }
  close $i_fh;
  $numlines = @F;
  if($numlines == 0) {
    print STDERR "$0: error: empty input scp file $inscp\n";
    $error = 1;
  }
  $linesperscp = int( $numlines / $numscps); # the "whole part"..
  $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
  $remainder = $numlines - ($linesperscp * $numscps);
  ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
  # [just doing int() rounds down].
  # The first $remainder files each get one extra line.
  $n = 0;
  for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
    $scpfile = $OUTPUTS[$scpidx];
    ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
                     : open($o_fh, '>&', \*STDOUT)) ||
        die "$0: Could not open scp file $scpfile for writing: $!\n";
    for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
      print $o_fh $F[$n++];
    }
    close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
  }
  $n == $numlines || die "$n != $numlines [code error]";
}

exit ($error);
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | |
| | | # Copyright 2017 Johns Hopkins University (Shinji Watanabe) |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | |
| | | import argparse |
| | | import codecs |
| | | import re |
| | | import sys |
| | | |
# True when running under a Python 2 interpreter; selects raw stdin/stdout
# versus their .buffer attributes when wrapping them in UTF-8 codecs.
is_python2 = sys.version_info.major == 2
| | | |
| | | |
def exist_or_not(i, match_pos):
    """Find the span in ``match_pos`` that covers index ``i``.

    ``match_pos`` is a sequence of ``[start, end)`` pairs. The first pair with
    ``start <= i < end`` is returned as ``(start, end)``; if none covers ``i``
    the result is ``(None, None)``.
    """
    covering = next((span for span in match_pos if span[0] <= i < span[1]), None)
    if covering is None:
        return None, None
    return covering[0], covering[1]
| | | |
| | | |
def get_parser():
    """Build the command-line parser for converting raw text to tokens."""
    parser = argparse.ArgumentParser(
        description="convert raw text to tokenized text",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--nchar",
        "-n",
        default=1,
        type=int,
        # Implicit concatenation instead of a backslash-newline inside the
        # literal, which embedded the next line's indentation in the text.
        help="number of characters to split, i.e., "
        "aabb -> a a b b with -n 1 and aa bb with -n 2",
    )
    parser.add_argument(
        "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
    )
    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
    parser.add_argument(
        "--non-lang-syms",
        "-l",
        default=None,
        type=str,
        help="list of non-linguistic symbols, e.g., <NOISE> etc.",
    )
    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
    parser.add_argument(
        "--trans_type",
        "-t",
        type=str,
        default="char",
        choices=["char", "phn"],
        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
                        If trans_type is char,
                        read from SI1279.WRD file -> "bricks are an alternative"
                        Else if trans_type is phn,
                        read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
                        sil t er n ih sil t ih v sil" """,
    )
    return parser
| | | |
| | | |
def main():
    """Convert raw text lines into space-separated tokens on stdout.

    Lines come from the positional ``text`` file or, if absent, from stdin.
    The first ``--skip-ncols`` columns are echoed unchanged. The rest of the
    line is split into:

    * ``phn`` mode: space-separated phones;
    * ``char`` mode: single characters, keeping every ``--non-lang-syms``
      entry as one unit.

    The pieces are then grouped ``--nchar`` at a time. Spaces become
    ``--space``; in ``phn`` mode ``sil`` does too. Output is UTF-8.
    """
    parser = get_parser()
    args = parser.parse_args()

    rs = []
    if args.non_lang_syms is not None:
        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
            # Skip blank entries: an empty pattern matches the empty string at
            # every offset, so the search loop below would never advance.
            nls = [x.rstrip() for x in f.readlines() if x.strip()]
        rs = [re.compile(re.escape(x)) for x in nls]

    if args.text:
        f = codecs.open(args.text, encoding="utf-8")
    else:
        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)

    sys.stdout = codecs.getwriter("utf-8")(
        sys.stdout if is_python2 else sys.stdout.buffer
    )
    n = args.nchar
    try:
        line = f.readline()
        while line:
            x = line.split()
            # NOTE: the separator is printed even when --skip-ncols is 0, so such
            # output lines start with a space (kept for compatibility).
            print(" ".join(x[: args.skip_ncols]), end=" ")
            a = " ".join(x[args.skip_ncols :])

            # Collect the [start, end) span of every non-linguistic symbol match.
            match_pos = []
            for r in rs:
                m = r.search(a)
                while m:
                    match_pos.append([m.start(), m.end()])
                    m = r.search(a, m.end())

            if args.trans_type == "phn":
                a = a.split(" ")
            elif match_pos:
                # Keep each matched symbol whole; split the rest into characters.
                chars = []
                i = 0
                while i < len(a):
                    start_pos, end_pos = exist_or_not(i, match_pos)
                    if start_pos is not None:
                        chars.append(a[start_pos:end_pos])
                        i = end_pos
                    else:
                        chars.append(a[i])
                        i += 1
                a = chars

            a = [a[j : j + n] for j in range(0, len(a), n)]
            a_flat = ["".join(z) for z in a]

            a_chars = [z.replace(" ", args.space) for z in a_flat]
            if args.trans_type == "phn":
                a_chars = [z.replace("sil", args.space) for z in a_chars]
            print(" ".join(a_chars))
            line = f.readline()
    finally:
        # Close only a file we opened ourselves; never close stdin.
        if args.text:
            f.close()


if __name__ == "__main__":
    main()
| New file |
| | |
| | | import re |
| | | import argparse |
| | | |
| | | |
def load_dict(seg_file):
    """Load a segmentation lexicon from ``seg_file``.

    Each non-blank line is ``<word> <tok1> <tok2> ...``. The result maps
    ``<word>`` to its tokens joined by single spaces, or to ``""`` when the
    line has no tokens. A later duplicate word overwrites an earlier one.
    """
    seg_dict = {}
    # Explicit UTF-8: the lexicon holds Chinese text and must not depend on
    # the process locale.
    with open(seg_file, 'r', encoding='utf-8') as infile:
        for line in infile:
            s = line.split()
            if not s:
                continue  # blank line: s[0] would raise IndexError
            seg_dict[s[0]] = " ".join(s[1:])
    return seg_dict
| | | |
| | | |
def forward_segment(text, dic):
    """Segment ``text`` by greedy forward maximum matching against ``dic``.

    At each position the longest multi-character prefix present in ``dic``
    is taken. If none is found, a single character is emitted and scanning
    moves on by one.
    """
    words = []
    start, size = 0, len(text)
    while start < size:
        best = text[start]  # single-character fallback
        for stop in range(start + 2, size + 1):
            candidate = text[start:stop]
            if candidate in dic:
                best = candidate  # stops ascend, so every hit is longer
        words.append(best)
        start += len(best)
    return words
| | | |
| | | |
def tokenize(txt,
             seg_dict):
    """Map segmented words to their lexicon tokens.

    A word is kept only when its first character is a CJK ideograph
    (U+4E00-U+9FA5), an ASCII letter or a digit; everything else (e.g.
    punctuation) is dropped. Kept words missing from ``seg_dict`` become
    ``<unk>``. Returns the tokens joined by spaces, with outer whitespace
    stripped.
    """
    keep = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    pieces = []
    for word in txt:
        if not keep.match(word):
            continue
        pieces.append(seg_dict[word] if word in seg_dict else "<unk>")
    return " ".join(pieces).strip()
| | | |
| | | |
def get_parser():
    """Build the CLI parser for tokenizing one text shard.

    Every option is required. Their defaults are suppressed because a
    default can never apply to a required option, and
    ``ArgumentDefaultsHelpFormatter`` would otherwise advertise one in
    ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="text tokenize",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--text-file",
        "-t",
        default=argparse.SUPPRESS,
        required=True,
        type=str,
        help="input text",
    )
    parser.add_argument(
        "--seg-file",
        "-s",
        default=argparse.SUPPRESS,
        required=True,
        type=str,
        help="seg file",
    )
    parser.add_argument(
        "--txt-index",
        "-i",
        default=argparse.SUPPRESS,
        required=True,
        type=int,
        help="txt index",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        default=argparse.SUPPRESS,
        required=True,
        type=str,
        help="output dir",
    )
    return parser
| | | |
| | | |
def main():
    """Tokenize one shard of a text file and record token counts.

    Reads ``--text-file`` (lines of ``<utt-id> <text...>``). For each line it
    concatenates and lower-cases the text and segments it with
    ``forward_segment`` against the ``--seg-file`` lexicon. It then maps the
    words to tokens with ``tokenize`` and writes two files:

    * ``<output-dir>/text.<txt-index>.txt``: ``<utt-id> <tokens>``
    * ``<output-dir>/len.<txt-index>``: ``<utt-id> <num-tokens>``
    """
    parser = get_parser()
    args = parser.parse_args()

    seg_dict = load_dict(args.seg_file)
    txt_path = "{}/text.{}.txt".format(args.output_dir, args.txt_index)
    shape_path = "{}/len.{}".format(args.output_dir, args.txt_index)
    # Context managers guarantee both outputs are flushed and closed; explicit
    # UTF-8 keeps the Chinese text independent of the process locale.
    with open(args.text_file, 'r', encoding='utf-8') as infile, \
            open(txt_path, 'w', encoding='utf-8') as txt_writer, \
            open(shape_path, 'w', encoding='utf-8') as shape_writer:
        for line in infile:
            s = line.strip().split()
            if not s:
                continue  # blank line: nothing to tokenize (s[0] would raise)
            text_id = s[0]
            # The words are concatenated before segmentation, so lexicon
            # entries can match across the original word boundaries.
            text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
            text = tokenize(text_list, seg_dict)
            lens = len(text.split())
            txt_writer.write(text_id + " " + text + '\n')
            shape_writer.write(text_id + " " + str(lens) + '\n')


if __name__ == '__main__':
    main()
| | | |
| New file |
| | |
#!/usr/bin/env bash

# Tokenize the sharded text under <text-dir>/txt/text.JOB.txt in parallel
# with utils/text_tokenize.py, then concatenate the per-job results into
# <output-dir>/text and <output-dir>/text_shape.

# Begin configuration section.
nj=32
cmd=utils/run.pl

echo "$0 $@"

. utils/parse_options.sh || exit 1;

# Without this check, missing arguments leave output_dir empty and the script
# would create and write into /txt.
if [ $# -ne 4 ]; then
  echo "Usage: $0 [--nj <nj>] [--cmd <cmd>] <text-dir> <seg-file> <log-dir> <output-dir>" 1>&2
  exit 1;
fi

# tokenize configuration
text_dir=$1
seg_file=$2
logdir=$3
output_dir=$4

txt_dir=${output_dir}/txt
mkdir -p "${txt_dir}" "${logdir}" || exit 1;

# $cmd is left unquoted on purpose: it may carry options (e.g. "queue.pl -q x").
$cmd JOB=1:$nj "${logdir}/text_tokenize.JOB.log" \
  python utils/text_tokenize.py -t "${text_dir}/txt/text.JOB.txt" \
    -s "${seg_file}" -i JOB -o "${txt_dir}" \
  || exit 1;

# concatenate the text files together.
for n in $(seq $nj); do
  cat "${txt_dir}/text.$n.txt" || exit 1
done > "${output_dir}/text" || exit 1

for n in $(seq $nj); do
  cat "${txt_dir}/len.$n" || exit 1
done > "${output_dir}/text_shape" || exit 1

echo "$0: Succeeded text tokenize"
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # coding=utf-8 |
| | | |
| | | # Authors: |
| | | # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) |
| | | # 2019.9 Jiayu DU |
| | | # |
| | | # requirements: |
| | | # - python 3.X |
| | | # notes: python 2.X WILL fail or produce misleading results |
| | | |
| | | import sys, os, argparse, codecs, string, re |
| | | |
| | | # ================================================================================ # |
| | | # basic constant |
| | | # ================================================================================ # |
CHINESE_DIGIS = u'零一二三四五六七八九'
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'

ZERO_ALT = u'〇'
ONE_ALT = u'幺'
TWO_ALTS = [u'两', u'兩']

POSITIVE = [u'正', u'正']
NEGATIVE = [u'负', u'負']
POINT = [u'点', u'點']
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']

FILLER_CHARS = ['呃', '啊']
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
               '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
               '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
               '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'

# Chinese numbering system types (see create_system)
NUMBERING_TYPES = ['low', 'mid', 'high']

CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
                 '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
                  '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
                  '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
                  '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
                  '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
                  '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'

# Punctuation sets are based on the Zhon project (https://github.com/tsroten/zhon.git).
# They are written with \u escapes on purpose. Most of them are fullwidth or
# halfwidth compatibility forms that Unicode (NFKC) normalization by editors
# and tooling folds to ASCII. That folding previously turned this literal into
# '"#$%&'()...', which is a syntax error.
CHINESE_PUNC_STOP = (
    '\uff01'  # fullwidth exclamation mark
    '\uff1f'  # fullwidth question mark
    '\uff61'  # halfwidth ideographic full stop
    '\u3002'  # ideographic full stop
)
CHINESE_PUNC_NON_STOP = (
    # fullwidth ASCII variants: quotation mark ... fullwidth white parentheses
    '\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d'
    '\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f'
    '\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60'
    # halfwidth corner brackets and ideographic comma
    '\uff62\uff63\uff64'
    # ideographic comma, ditto mark
    '\u3001\u3003'
    # CJK brackets; U+300A (double angle, left) was missing although its
    # closing partner U+300B was present
    '\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011'
    '\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b'
    # wave dash, double prime quotation marks, wavy dash, special CJK marks
    '\u301c\u301d\u301e\u301f\u3030\u303e\u303f'
    # en/em dash, quotation marks, ellipsis, hyphenation point, wavy low line
    '\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f'
)
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
| | | |
| | | # ================================================================================ # |
| | | # basic class |
| | | # ================================================================================ # |
class ChineseChar(object):
    """
    A Chinese character that has a simplified and a traditional form,
    e.g. simplified = '负', traditional = '負'.
    Conversion can render either form.
    """

    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional

    def __str__(self):
        # __str__ must return a str: fall back to '' rather than None when
        # neither form is set, otherwise str()/print() raise TypeError.
        return self.simplified or self.traditional or ''

    def __repr__(self):
        return self.__str__()
| | | |
| | | |
class ChineseNumberUnit(ChineseChar):
    """
    A Chinese number unit such as 十, 百, 万 or 亿, standing for 10 ** power.
    Besides its simplified/traditional forms it records "big" (financial)
    forms in ``big_s`` / ``big_t``.
    """

    def __init__(self, power, simplified, traditional, big_s, big_t):
        super(ChineseNumberUnit, self).__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t

    def __str__(self):
        return '10^{}'.format(self.power)

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        """Build the unit at position ``index`` of its unit table.

        ``value`` is a (simplified, traditional) pair. Small units
        (十/百/千/万) get power ``index + 1``. Larger units (亿 and up) get a
        power that depends on ``numbering_type``:

        * 'low': ``index + 8``
        * 'mid': ``(index + 2) * 4``
        * 'high': ``2 ** (index + 3)``

        Raises ValueError for any other numbering type.
        """
        if small_unit:
            # For small units the traditional form doubles as both big forms.
            return ChineseNumberUnit(power=index + 1, simplified=value[0], traditional=value[1],
                                     big_s=value[1], big_t=value[1])

        if numbering_type == NUMBERING_TYPES[0]:
            power = index + 8
        elif numbering_type == NUMBERING_TYPES[1]:
            power = (index + 2) * 4
        elif numbering_type == NUMBERING_TYPES[2]:
            power = pow(2, index + 3)
        else:
            raise ValueError(
                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))

        return ChineseNumberUnit(power=power, simplified=value[0], traditional=value[1],
                                 big_s=value[0], big_t=value[1])
| | | |
| | | |
class ChineseNumberDigit(ChineseChar):
    """
    A Chinese numeral digit with integer ``value``, its big (financial)
    forms ``big_s`` / ``big_t``, and optional colloquial alternatives
    ``alt_s`` / ``alt_t``.
    """

    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
        super(ChineseNumberDigit, self).__init__(simplified, traditional)
        self.value = value
        self.big_s, self.big_t = big_s, big_t
        self.alt_s, self.alt_t = alt_s, alt_t

    def __str__(self):
        return str(self.value)

    @classmethod
    def create(cls, i, v):
        """Build digit ``i`` from ``v`` = (simplified, traditional, big_s, big_t)."""
        simplified, traditional, big_s, big_t = v[0], v[1], v[2], v[3]
        return ChineseNumberDigit(i, simplified, traditional, big_s, big_t)
| | | |
| | | |
class ChineseMath(ChineseChar):
    """
    A Chinese math-symbol character (create_system uses it for 正, 负 and
    点). It is paired with an ASCII ``symbol`` and an optional ``expression``
    callable. Its big forms are simply its simplified/traditional forms.
    """

    def __init__(self, simplified, traditional, symbol, expression=None):
        super(ChineseMath, self).__init__(simplified, traditional)
        self.symbol, self.expression = symbol, expression
        # Math symbols have no separate financial form; reuse the plain ones.
        self.big_s, self.big_t = simplified, traditional
| | | |
| | | |
# Short aliases for the character classes, used by the functions below
# (create_system, chn2num, num2chn).
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
| | | |
| | | |
class NumberSystem(object):
    """
    Container for a Chinese numbering system.

    A plain attribute bag: create_system assigns ``units``, ``digits`` and
    ``math`` after construction.
    """
| | | |
| | | |
class MathSymbol(object):
    """
    Math symbols of a Chinese numbering system, each given in simplified and
    traditional form, e.g.
        positive = ['正', '正']
        negative = ['负', '負']
        point = ['点', '點']
    Iterating yields the instance attributes in assignment order.
    """

    def __init__(self, positive, negative, point):
        self.positive = positive
        self.negative = negative
        self.point = point

    def __iter__(self):
        yield from self.__dict__.values()
| | | |
| | | |
| | | # class OtherSymbol(object): |
| | | # """ |
| | | # 其他符号 |
| | | # """ |
| | | # |
| | | # def __init__(self, sil): |
| | | # self.sil = sil |
| | | # |
| | | # def __iter__(self): |
| | | # for v in self.__dict__.values(): |
| | | # yield v |
| | | |
| | | |
| | | # ================================================================================ # |
| | | # basic utils |
| | | # ================================================================================ # |
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """
    Build the Chinese numbering system for ``numbering_type`` (default 'mid').

    NUMBERING_TYPES = ['low', 'mid', 'high']:
        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.

    The returned NumberSystem has:

    * ``units``: the small units 十/百/千/万 first, then 亿 and larger;
    * ``digits``: 0-9, with colloquial alternatives for 0, 1 and 2;
    * ``math``: the positive/negative/point symbols.
    """
    # Units of '亿' and above: their powers depend on numbering_type.
    larger_units = [
        CNU.create(i, pair, numbering_type, False)
        for i, pair in enumerate(zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                                     LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL))
    ]
    # Units '十, 百, 千, 万' (10^1 .. 10^4).
    smaller_units = [
        CNU.create(i, pair, small_unit=True)
        for i, pair in enumerate(zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                                     SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL))
    ]
    # Digits 0-9; the simplified and traditional plain forms are identical.
    digits = [
        CND.create(i, quad)
        for i, quad in enumerate(zip(CHINESE_DIGIS, CHINESE_DIGIS,
                                     BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL))
    ]
    # Colloquial alternatives: 〇 for zero, 幺 for one, 两/兩 for two.
    for idx, alt_s, alt_t in ((0, ZERO_ALT, ZERO_ALT),
                              (1, ONE_ALT, ONE_ALT),
                              (2, TWO_ALTS[0], TWO_ALTS[1])):
        digits[idx].alt_s, digits[idx].alt_t = alt_s, alt_t

    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = MathSymbol(
        CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x),
        CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x),
        CM(POINT[0], POINT[1], '.', lambda x, y: float(str(x) + '.' + str(y))),
    )
    return system
| | | |
| | | |
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """Convert a Chinese numeral string to a string of Arabic digits.

    e.g. '一百八' -> '180', '三点一四' -> '3.14'.

    The integer part is parsed with units (十/百/千/万/亿 ...) whose powers
    depend on ``numbering_type`` (see create_system). The part after the
    point character (点/點) is read digit by digit.

    NOTE(review): characters the numbering system does not recognize map to
    None (get_symbol has no fallback return); in the decimal part this
    raises AttributeError on ``d.value``. A string with more than one point
    raises ValueError when unpacking the split. Sign characters (正/负) are
    recognized but never applied, since compute_value only handles digits
    and units. Confirm callers only pass plain numerals.
    """

    def get_symbol(char, system):
        # Look the character up among units, then digits (incl. big and
        # colloquial forms), then math symbols; returns None if nothing matches.
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m

    def string2symbols(chinese_string, system):
        # Split at the point character (simplified form checked first) into
        # integer and decimal parts, then map every character to its symbol.
        int_string, dec_string = chinese_string, ''
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], \
               [get_symbol(c, system) for c in dec_string]

    def correct_symbols(integer_symbols, system):
        """
        Normalize the integer symbols before evaluation, e.g.
        一百八 to 一百八十 (the implied trailing unit is made explicit),
        一亿一千三百万 to 一亿 一千万 三百万 (a unit directly following another
        unit is folded into the smaller earlier units).
        """

        # A leading power-1 unit (十) gets an explicit 一 in front of it.
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols

        # A trailing digit right after a unit implies the next lower unit.
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
                integer_symbols.append(
                    CNU(integer_symbols[-2].power - 1, None, None, None, None))

        result = []
        unit_count = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1

            if unit_count == 1:
                result.append(current_unit)
            elif unit_count > 1:
                # Consecutive units: add the current unit's power to every
                # earlier unit in result whose power is smaller.
                for i in range(len(result)):
                    if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
                        result[-i - 1] = CNU(result[-i - 1].power +
                                             current_unit.power, None, None, None, None)
        return result

    def compute_value(integer_symbols):
        """
        Compute the value.
        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
        e.g. '两千万' = 2000 * 10000 not 2000 + 10000
        """
        value = [0]
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    value[:-1] = list(map(lambda v: v *
                                          pow(10, s.power), value[:-1]))
                    last_power = s.power
                value.append(0)
        return sum(value)

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    dec_str = ''.join([str(d.value) for d in dec_part])
    if dec_part:
        return '{0}.{1}'.format(int_str, dec_str)
    else:
        return int_str
| | | |
| | | |
def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
            traditional=False, alt_zero=False, alt_one=False, alt_two=True,
            use_zeros=True, use_units=True):
    """Convert an Arabic-numeral string (at most one '.') to Chinese text.

    Args:
        number_string: digit string, optionally with one decimal point.
        numbering_type: numbering system passed to ``create_system``.
        big: use the ``big_s``/``big_t`` digit forms; then the alternative
            form of 2 produced by ``alt_two`` is never used.
        traditional: pick traditional (``*_t``/``traditional``) forms
            instead of simplified ones.
        alt_zero: replace every zero digit in the output with
            ``system.digits[0].alt_s``.
        alt_one: replace every one digit in the output with
            ``system.digits[1].alt_s``.
        alt_two: use the alternative form of 2 (``alt_s``/``alt_t``) when it
            precedes a unit whose power is not 1 and does not follow a
            power-1 unit.
        use_zeros: NOTE(review): never forwarded to ``get_value`` below, so
            it currently has no effect.
        use_units: when False (or the integer part has a single digit) the
            integer part is read digit by digit, without units.

    Returns:
        The Chinese text.

    Raises:
        ValueError: if ``number_string`` has more than one '.', or contains a
            non-digit character (from ``int(c)``).
    """

    def get_value(value_string, use_zeros=True):
        # Recursively turn a digit string into a list of digit/unit symbols.
        # ``system`` is the enclosing variable assigned below, before any call.

        striped_string = value_string.lstrip('0')

        # record nothing if all zeros
        if not striped_string:
            return []

        # record one digits (prefixed by a zero digit if leading zeros were dropped)
        elif len(striped_string) == 1:
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            else:
                return [system.digits[int(striped_string)]]

        # recursively record multiple digits: split at the largest unit whose
        # power is smaller than the number of significant digits
        else:
            result_unit = next(u for u in reversed(
                system.units) if u.power < len(striped_string))
            result_string = value_string[:-result_unit.power]
            return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])

    system = create_system(numbering_type)

    int_dec = number_string.split('.')
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string))

    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
    else:
        result_symbols = [system.digits[int(c)] for c in int_string]
    # The fractional part is always read digit by digit.
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
                    system.digits[2].big_s, system.digits[2].big_t)
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i +
                                             1] if i < len(result_symbols) - 1 else None
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
                        result_symbols[i] = liang

    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = 'big_'
        if traditional:
            attr_name += 't'
        else:
            attr_name += 's'
    else:
        if traditional:
            attr_name = 'traditional'
        else:
            attr_name = 'simplified'

    result = ''.join([getattr(s, attr_name) for s in result_symbols])

    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))

    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s)

    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s)

    # An empty integer part (e.g. ".5") would start with the point symbol:
    # prefix it with the zero digit.
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result

    # ^10, 11, .., 19: drop the leading "one" digit before the first small unit
    # (presumably 十), e.g. 一十五 -> 十五.
    if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
        result = result[1:]

    return result
| | | |
| | | |
| | | # ================================================================================ # |
| | | # different types of rewriters |
| | | # ================================================================================ # |
class Cardinal:
    """Converter between a cardinal number string and its Chinese reading."""

    def __init__(self, cardinal=None, chntext=None):
        self.chntext = chntext
        self.cardinal = cardinal

    def chntext2cardinal(self):
        """Return the numeral string that ``chn2num`` produces for ``self.chntext``."""
        text = self.chntext
        return chn2num(text)

    def cardinal2chntext(self):
        """Return the Chinese reading of ``self.cardinal`` (default ``num2chn`` options)."""
        number = self.cardinal
        return num2chn(number)
| | | |
class Digit:
    """Reads a numeric string one digit at a time (IDs, years, codes)."""

    def __init__(self, digit=None, chntext=None):
        self.chntext = chntext
        self.digit = digit

    def digit2chntext(self):
        """Return ``self.digit`` spelled digit by digit, with no units and no alternative 2."""
        return num2chn(self.digit, use_units=False, alt_two=False)
| | | |
| | | |
class TelePhone:
    """Telephone number -> Chinese text converter."""

    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
        self.telephone = telephone
        self.raw_chntext = raw_chntext
        self.chntext = chntext

    def telephone2chntext(self, fixed=False):
        """Spell ``self.telephone`` digit by digit and return the result.

        Fixed-line numbers (``fixed=True``) are split on '-' and the spoken
        parts joined with '<SIL>' in ``raw_chntext``. Otherwise '+' is stripped
        from both ends, the number is split on whitespace, and the parts are
        joined with '<SP>'. ``chntext`` is ``raw_chntext`` with the separator
        removed.
        """
        if fixed:
            separator = '<SIL>'
            pieces = self.telephone.split('-')
        else:
            separator = '<SP>'
            pieces = self.telephone.strip('+').split()
        spoken = [num2chn(piece, alt_two=False, use_units=False) for piece in pieces]
        self.raw_chntext = separator.join(spoken)
        self.chntext = self.raw_chntext.replace(separator, '')
        return self.chntext
| | | |
| | | |
class Fraction:
    """Converter between "a/b" fractions and their Chinese reading "b分之a"."""

    def __init__(self, fraction=None, chntext=None):
        self.fraction = fraction
        self.chntext = chntext

    def chntext2fraction(self):
        """Turn "<denominator>分之<numerator>" into "numerator/denominator"."""
        parts = self.chntext.split('分之')
        denominator, numerator = parts
        return '/'.join([chn2num(numerator), chn2num(denominator)])

    def fraction2chntext(self):
        """Turn "numerator/denominator" into "<denominator>分之<numerator>"."""
        parts = self.fraction.split('/')
        numerator, denominator = parts
        return '分之'.join([num2chn(denominator), num2chn(numerator)])
| | | |
| | | |
class Date:
    """Date -> Chinese text converter.

    The year is read digit by digit; the month and day are read as cardinals.
    """

    def __init__(self, date=None, chntext=None):
        self.date = date
        self.chntext = chntext

    def date2chntext(self):
        """Convert ``self.date`` to Chinese text, store it in ``self.chntext`` and return it.

        The text before '年' (if any) is read with ``Digit``. The text before
        '月' (if any) is read with ``Cardinal``. The remaining day part is read
        with ``Cardinal`` except for its last character (e.g. '日' or '号'),
        which is kept as-is.
        """
        date = self.date
        try:
            year, other = date.strip().split('年', 1)
            year = Digit(digit=year).digit2chntext() + '年'
        except ValueError:
            other = date
            year = ''
        if other:
            try:
                month, day = other.strip().split('月', 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
            except ValueError:
                # No month: what follows the year is the day. (Previously the
                # whole ``date`` was used here, re-reading the year as part of
                # the day.)
                day = other
                month = ''
            if day:
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ''
            day = ''
        self.chntext = year + month + day
        return self.chntext
| | | |
| | | |
class Money:
    """Money expression -> Chinese text converter."""

    def __init__(self, money=None, chntext=None):
        self.money = money
        self.chntext = chntext

    def money2chntext(self):
        """Replace each number in ``self.money`` with its Chinese cardinal reading.

        The numbers are rewritten left to right in a single ``re.sub`` pass.
        A global ``str.replace`` per match would also rewrite a short number
        inside a longer one (e.g. the '5' in '5块15'). The result is stored in
        ``self.chntext`` and returned.
        """
        pattern = re.compile(r'(\d+(\.\d+)?)')
        self.chntext = pattern.sub(
            lambda m: Cardinal(cardinal=m.group(1)).cardinal2chntext(),
            self.money)
        return self.chntext
| | | |
| | | |
class Percentage:
    """Converter between "N%" and its Chinese reading "百分之N"."""

    def __init__(self, percentage=None, chntext=None):
        self.percentage = percentage
        self.chntext = chntext

    def chntext2percentage(self):
        """Convert "百分之<number>" text to "<number>%".

        Only a leading "百分之" is removed. ``str.strip('百分之')`` removes
        those characters from both ends, so 百分之一百 became 一.
        """
        prefix = '百分之'
        text = self.chntext.strip()
        if text.startswith(prefix):
            text = text[len(prefix):]
        return chn2num(text) + '%'

    def percentage2chntext(self):
        """Convert "<number>%" to "百分之<Chinese number>"."""
        return '百分之' + num2chn(self.percentage.strip().strip('%'))
| | | |
| | | |
def remove_erhua(text, er_whitelist):
    """Drop the erhua '儿' from ``text`` unless it belongs to a whitelisted word.

    Example: 他女儿在那边儿 -> 他女儿在那边

    Args:
        text: input string.
        er_whitelist: regex of words whose '儿' must be kept.

    Returns:
        ``text`` with every non-whitelisted '儿' removed.
    """
    whitelist_re = re.compile(er_whitelist)
    kept = []
    er_pos = text.find('儿')
    while er_pos != -1:
        protected = whitelist_re.search(text)
        if protected is not None and protected.start() <= er_pos:
            # The first '儿' sits inside (or after the start of) a whitelisted
            # word: keep everything up to the end of that word.
            cut = protected.end()
            kept.append(text[:cut])
            text = text[cut:]
        else:
            # Plain erhua: keep the prefix and skip the '儿' itself.
            kept.append(text[:er_pos])
            text = text[er_pos + 1:]
        er_pos = text.find('儿')
    return ''.join(kept) + text
| | | |
| | | # ================================================================================ # |
| | | # NSW Normalizer |
| | | # ================================================================================ # |
class NSWNormalizer:
    """Rewrite non-standard words (dates, money, phone numbers, fractions,
    percentages, digit strings, plain numbers) in Chinese text as Chinese
    characters.

    NOTE(review): every rewrite uses ``text.replace(matched, ..., 1)``, which
    changes the FIRST occurrence of the matched substring. That need not be
    the occurrence the regex actually matched.
    """

    def __init__(self, raw_text):
        # Sentinel characters, so that patterns requiring a non-digit (\D)
        # before/after a number also match at the very start/end of the text.
        self.raw_text = '^' + raw_text + '$'
        self.norm_text = ''

    def _particular(self):
        # Restore '2' between Latin letters (e.g. "O二O" -> "O2O", "B二C" -> "B2C"),
        # undoing the number rewriting for such tokens.
        text = self.norm_text
        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
        self.norm_text = text
        return self.norm_text

    def normalize(self):
        """Run all rewrites in a fixed order and return the normalized text (sentinels removed)."""
        text = self.raw_text

        # Dates. Every part of the group is optional, so matcher[0] may be ''
        # (replace('', '', 1) is then a no-op).
        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)

        # Money amounts
        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)

        # Landline / mobile phone numbers
        # Mobile
        # http://www.jihaoba.com/news/show/13680
        # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
        # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
        # China Telecom: 133, 153, 189, 180, 181, 177
        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
        # Landline (optional area code followed by 7-8 digits)
        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)

        # Fractions
        pattern = re.compile(r"(\d+/\d+)")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)

        # Percentages.
        # NOTE(review): as shown this replaces '%' with itself; presumably it
        # was meant to map the full-width '％' to '%' -- confirm the source bytes.
        text = text.replace('%', '%')
        pattern = re.compile(r"(\d+(\.\d+)?%)")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)

        # Number followed by a quantifier: only the number (group 1) is rewritten.
        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)

        # Long digit strings (4-32 digits, e.g. IDs) are read digit by digit.
        pattern = re.compile(r"(\d{4,32})")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)

        # Any remaining plain numbers
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(text)
        if matchers:
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)

        self.norm_text = text
        self._particular()

        # NOTE(review): lstrip/rstrip also remove any '^'/'$' that were
        # genuinely at the ends of the input, not just the sentinels.
        return self.norm_text.lstrip('^').rstrip('$')
| | | |
| | | |
def nsw_test_case(raw_text):
    """Print ``raw_text``, its normalized form, and a separating blank line."""
    print('I:' + raw_text)
    normalized = NSWNormalizer(raw_text).normalize()
    print('O:' + normalized)
    print()
| | | |
| | | |
def nsw_test():
    """Print the normalization of a fixed set of example sentences."""
    examples = (
        '固话:0595-23865596或23880880。',
        '固话:0595-23865596或23880880。',
        '手机:+86 19859213959或15659451527。',
        '分数:32477/76391。',
        '百分数:80.03%。',
        '编号:31520181154418。',
        '纯数:2983.07克或12345.60米。',
        '日期:1999年2月20日或09年3月15号。',
        '金钱:12块5,34.5元,20.1万',
        '特殊:O2O或B2C。',
        '3456万吨',
        '2938个',
        '938',
        '今天吃了115个小笼包231个馒头',
        '有62%的概率',
    )
    for example in examples:
        nsw_test_case(example)
| | | |
| | | |
if __name__ == '__main__':
    # nsw_test()

    def _str2bool(value):
        """Parse a command-line boolean ('true'/'false', '1'/'0', 'yes'/'no', ...).

        ``type=bool`` cannot be used here: ``bool('False')`` is True, so
        ``--remove_fillers False`` used to leave the option enabled.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))

    p = argparse.ArgumentParser()
    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
    p.add_argument('ofile', help='output filename')
    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
    p.add_argument('--remove_fillers', type=_str2bool, default=True, help='remove filler chars such as "呃, 啊"')
    p.add_argument('--remove_erhua', type=_str2bool, default=True, help='remove erhua chars such as "这儿"')
    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
    args = p.parse_args()

    # Conflicting case options: fail once, before the output file is opened (and truncated).
    if args.to_upper and args.to_lower:
        sys.stderr.write('text norm: to_upper OR to_lower?')
        sys.exit(1)

    # Built once: every Chinese and ASCII punctuation character is mapped to a space.
    old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
    punc_table = str.maketrans(old_chars, ' ' * len(old_chars))

    n = 0
    with codecs.open(args.ifile, 'r', 'utf8') as ifile, \
            codecs.open(args.ofile, 'w+', 'utf8') as ofile:
        for line in ifile:
            key = ''
            if args.has_key:
                cols = line.split(maxsplit=1)
                # A blank line has no key; write an empty record instead of raising IndexError.
                key = cols[0] if cols else ''
                text = cols[1].strip() if len(cols) == 2 else ''
            else:
                text = line.strip()

            # cases
            if args.to_upper:
                text = text.upper()
            if args.to_lower:
                text = text.lower()

            # Filler chars removal
            if args.remove_fillers:
                for ch in FILLER_CHARS:
                    text = text.replace(ch, '')

            if args.remove_erhua:
                text = remove_erhua(text, ER_WHITELIST)

            # NSW (Non-Standard-Word) normalization
            text = NSWNormalizer(text).normalize()

            # Punctuation removal
            text = text.translate(punc_table)

            if args.has_key:
                ofile.write(key + '\t' + text + '\n')
            else:
                ofile.write(text + '\n')

            n += 1
            if args.log_interval > 0 and n % args.log_interval == 0:
                sys.stderr.write("text norm: {} lines done.\n".format(n))

    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
| | |
| | | #local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch |
| | | #git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path} |
| | | |
| | | |
| | | ## generate jsonl from wav.scp and text.txt |
| | | #python funasr/datasets/audio_datasets/scp2jsonl.py \ |
| | | #++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \ |
| | | #++data_type_list='["source", "target"]' \ |
| | | #++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl |
| | | # torchrun \ |
| | | # --nnodes 1 \ |
| | | # --nproc_per_node 1 \ |
| | | python funasr/bin/train.py \ |
| | | +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \ |
| | | +model_revision="v2.0.4" \ |
| New file |
| | |
| | | import os |
| | | import json |
| | | import numpy as np |
| | | import torch |
| | | import hydra |
| | | import logging |
| | | from omegaconf import DictConfig, OmegaConf |
| | | |
| | | from funasr.register import tables |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | |
| | | |
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: fetch the model config from the hub if needed, then run ``main``."""
    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    assert "model" in kwargs
    needs_download = "model_conf" not in kwargs
    if needs_download:
        hub = kwargs.get("model_hub", "ms")
        logging.info("download models from model hub: {}".format(hub))
        is_training = kwargs.get("is_training", True)
        kwargs = download_model(is_training=is_training, **kwargs)

    main(**kwargs)
| | | |
| | | |
def main(**kwargs):
    """Accumulate feature statistics over the training data and write CMVN files.

    Builds the frontend and dataset described by ``kwargs``, then iterates
    over the first ``scale`` fraction of the training batches, one utterance
    per batch. It writes two files:

    * ``cmvn_file`` (default ``cmvn.json``): JSON with the per-dimension sums
      of features and of squared features, plus the total frame count.
    * ``am.mvn`` in the same directory as ``cmvn_file``: a Kaldi-nnet-style
      text file holding the negated mean (``<AddShift>``) and the inverse
      standard deviation (``<Rescale>``).

    Raises:
        ValueError: if no feature frames were read.
    """
    print(kwargs)
    tables.print()
    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))

    # dataloader: one utterance per batch, so each batch is a single utterance
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_train = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        dataset_conf = kwargs.get("dataset_conf")
        dataset_conf["batch_type"] = "example"
        dataset_conf["batch_size"] = 1
        batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)

    dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                   collate_fn=dataset_train.collator,
                                                   batch_sampler=batch_sampler_train,
                                                   num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
                                                   pin_memory=True)

    # Only the first ``scale`` fraction of the batches is used.
    iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))

    total_frames = 0
    mean_stats = None
    var_stats = None
    for batch_idx, batch in enumerate(dataloader_train):
        if batch_idx >= iter_stop:
            break

        fbank = batch["speech"].numpy()[0, :, :]
        # Sum over frames (axis 0) for every batch, including the first one.
        # Previously the first batch stored the raw 2-D fbank, giving stats of
        # the wrong shape. float64 accumulators avoid float32 drift on large corpora.
        frame_sum = np.sum(fbank, axis=0, dtype=np.float64)
        frame_sq_sum = np.sum(np.square(fbank), axis=0, dtype=np.float64)
        if mean_stats is None:
            mean_stats = frame_sum
            var_stats = frame_sq_sum
        else:
            mean_stats += frame_sum
            var_stats += frame_sq_sum
        total_frames += fbank.shape[0]

    if mean_stats is None or total_frames == 0:
        raise ValueError("no feature frames were read from train_data_set_list; cannot compute CMVN")

    cmvn_info = {
        'mean_stats': mean_stats.tolist(),
        'var_stats': var_stats.tolist(),
        'total_frames': total_frames,
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    with open(cmvn_file, 'w') as fout:
        fout.write(json.dumps(cmvn_info))

    mean = -1.0 * mean_stats / total_frames
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)  # inverse standard deviation
    dims = mean.shape[0]
    # os.path.join: with the default "cmvn.json" the directory is "", and the
    # old string concatenation wrote to "/am.mvn" at the filesystem root.
    am_mvn = os.path.join(os.path.dirname(cmvn_file), "am.mvn")
    with open(am_mvn, 'w') as fout:
        fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
        fout.write("<LearnRateCoef> 0 " + _kaldi_vector(mean) + '\n')
        fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
        fout.write("<LearnRateCoef> 0 " + _kaldi_vector(var) + '\n')
        fout.write("</Nnet>" + '\n')


def _kaldi_vector(values):
    """Format a 1-D NumPy array as a Kaldi text vector ``[ v0 v1 ... ]``."""
    # tolist() yields Python floats, so NumPy 2's "np.float64(...)" reprs never leak in.
    return "[ " + " ".join(str(v) for v in values.tolist()) + " ]"
| | | |
| | | |
| | | |
if __name__ == "__main__":
    # Hydra reads --config-path/--config-name and ++key=value overrides from the
    # command line; the string below is an example invocation.
    main_hydra()
    """
    python funasr/bin/compute_status.py \
    --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
    --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
    ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
    ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
    ++dataset_conf.num_workers=32
    """
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import os |
| | | import sys |
| | | import torch |
| | |
| | | |
| | | # dataset |
| | | dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset")) |
| | | dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) |
| | | dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, |
| | | **kwargs.get("dataset_conf")) |
| | | dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf")) |
| | | dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf")) |
| | | |
| | | # dataloader |
| | | batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") |
| | |
| | | **kwargs): |
| | | super().__init__() |
| | | index_ds_class = tables.index_ds_classes.get(index_ds) |
| | | self.index_ds = index_ds_class(path) |
| | | self.index_ds = index_ds_class(path, **kwargs) |
| | | preprocessor_speech = kwargs.get("preprocessor_speech", None) |
| | | if preprocessor_speech: |
| | | preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech) |
| | |
| | | target = item["target"] |
| | | if self.preprocessor_text: |
| | | target = self.preprocessor_text(target) |
| | | ids = self.tokenizer.encode(target) |
| | | if self.tokenizer: |
| | | ids = self.tokenizer.encode(target) |
| | | text = torch.tensor(ids, dtype=torch.int64) |
| | | else: |
| | | ids = target |
| | | text = ids |
| | | ids_lengths = len(ids) |
| | | text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32) |
| | | text_lengths = torch.tensor([ids_lengths], dtype=torch.int32) |
| | | |
| | | return {"speech": speech[0, :, :], |
| | | "speech_lengths": speech_lengths, |
| | |
| | | outputs[key].append(sample[key]) |
| | | |
| | | for key, data_list in outputs.items(): |
| | | if data_list[0].dtype == torch.int64: |
| | | |
| | | pad_value = self.int_pad_value |
| | | else: |
| | | pad_value = self.float_pad_value |
| | | outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value) |
| | | if isinstance(data_list[0], torch.Tensor): |
| | | if data_list[0].dtype == torch.int64: |
| | | |
| | | pad_value = self.int_pad_value |
| | | else: |
| | | pad_value = self.float_pad_value |
| | | |
| | | outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value) |
| | | return outputs |
| | | |
| | |
| | | import os |
| | | import json |
| | | import torch |
| | | import logging |
| | | import concurrent.futures |
| | | import librosa |
| | | import torch.distributed as dist |
| | | |
| | | from funasr.register import tables |
| | |
| | | @tables.register("index_ds_classes", "IndexDSJsonlRankFull") |
| | | class IndexDSJsonlRankFull(torch.utils.data.Dataset): |
| | | |
| | | def __init__(self, path): |
| | | def __init__(self, path: str, **kwargs): |
| | | super().__init__() |
| | | |
| | | if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans |
| | | from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list |
| | | jsonl_outdir = os.path.dirname(path[0]) |
| | | jsonl_name = "datalist_train.jsonl" if kwargs.get("is_training", True) else "datalist_val.jsonl" |
| | | jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name) |
| | | if not os.path.exists(jsonl_file_out): |
| | | print(f"datalist is: {path}, generate jsonl from it") |
| | | gen_jsonl_from_wav_text_list(path, jsonl_file_out=jsonl_file_out, **kwargs) |
| | | path = jsonl_file_out |
| | | |
| | | contents = [] |
| | | with open(path, encoding='utf-8') as fin: |
| | | for line in fin: |
| New file |
| | |
| | | import os |
| | | import json |
| | | import torch |
| | | import logging |
| | | import hydra |
| | | from omegaconf import DictConfig, OmegaConf |
| | | import concurrent.futures |
| | | import librosa |
| | | import torch.distributed as dist |
| | | |
| | | |
| | | |
def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs):
    """Merge Kaldi-style ``<key> <value>`` list files (e.g. wav.scp, text) into one JSONL file.

    Each file in ``path`` is paired positionally with a name in
    ``data_type_list``. Its lines are parsed by ``parse_context_length`` in a
    thread pool. For each key of the first data type, one JSON object (UTF-8,
    ``ensure_ascii=False``) is written to ``jsonl_file_out``. Keys missing from
    any other data type are skipped with a warning.

    Only rank 0 does the work. When more than one distributed rank is running,
    all ranks meet at a barrier afterwards.

    Args:
        path: sequence of input list files.
        data_type_list: field names matching ``path``, e.g. ("source", "target").
        jsonl_file_out: output JSONL path.
        **kwargs: accepted for interface compatibility; unused.
    """
    # Explicit check instead of a bare ``except:`` around get_rank().
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1

    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r", encoding="utf-8") as f:
                data_file_lists = f.readlines()

            # Split into at most ``cpu_cores`` contiguous chunks that cover every
            # line. The old arithmetic submitted a single 1-line chunk whenever
            # a file had fewer lines than CPU cores, dropping the rest.
            chunk_size = max(1, -(-len(data_file_lists) // cpu_cores))  # ceil division
            chunks = [data_file_lists[i:i + chunk_size]
                      for i in range(0, len(data_file_lists), chunk_size)]
            with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
                futures = [executor.submit(parse_context_length, chunk, data_type) for chunk in chunks]
                for future in concurrent.futures.as_completed(futures):
                    json_dict[data_type].update(future.result())

        with open(jsonl_file_out, "w", encoding="utf-8") as f:
            for key in json_dict[data_type_list[0]]:
                missing = [dt for dt in data_type_list if key not in json_dict[dt]]
                if missing:
                    logging.warning("key %s has no entry for %s; skipped", key, missing)
                    continue
                jsonl_line = {"key": key}
                for data_type in data_type_list:
                    jsonl_line.update(json_dict[data_type][key])
                f.write(json.dumps(jsonl_line, ensure_ascii=False) + "\n")

    if world_size > 1:
        dist.barrier()
| | | |
| | | |
def parse_context_length(data_list: list, data_type: str):
    """Parse ``<key> <value>`` lines into ``{key: {data_type: value, "<data_type>_len": length}}``.

    If ``value`` is an existing file, it is loaded with librosa at 16 kHz and
    its length is the number of whole 10 ms frames. Otherwise the length is
    ``len(value)`` (characters). A line with only a key yields an empty value
    of length 0. Blank lines are skipped.
    """
    res = {}
    for line in data_list:
        fields = line.strip().split(maxsplit=1)
        if not fields:
            continue
        key = fields[0]
        value = fields[1].strip() if len(fields) > 1 else ""
        if value and os.path.exists(value):
            waveform, _ = librosa.load(value, sr=16000)
            sample_num = len(waveform)
            # 10 ms frames at 16 kHz, in exact integer arithmetic. The old
            # ``sample_num//16000*...`` truncated to whole seconds first
            # (a 1.9 s file gave 100 frames instead of 190).
            context_len = sample_num * 100 // 16000
        else:
            context_len = len(value)
        res[key] = {data_type: value, f"{data_type}_len": context_len}
    return res
| | | |
| | | |
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Command-line entry point for ``gen_jsonl_from_wav_text_list``.

    Example::

        python funasr/datasets/audio_datasets/scp2jsonl.py
            ++scp_file_list='["/path/to/wav.scp", "/path/to/text.txt"]'
            ++data_type_list='["source", "target"]'
            ++jsonl_file_out=/path/to/audio_datasets.jsonl

    Raises:
        ValueError: if ``scp_file_list`` or ``jsonl_file_out`` is not given.
    """
    kwargs = OmegaConf.to_container(cfg, resolve=True)

    scp_file_list = kwargs.get("scp_file_list")
    jsonl_file_out = kwargs.get("jsonl_file_out")
    # These used to default to one developer's absolute paths, which on any
    # other machine only failed later with a confusing FileNotFoundError.
    if not scp_file_list or not jsonl_file_out:
        raise ValueError("both ++scp_file_list and ++jsonl_file_out must be provided")
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
| | | |
| | | |
if __name__ == "__main__":
    # Configuration comes from Hydra ``++key=value`` command-line overrides.
    main_hydra()
| | | |
| | | |
| New file |
| | |
| | | |
def download_dataset():
    # Placeholder: not implemented; does nothing and returns None.
    pass
| | | |
def download_dataset_from_ms(**kwargs):
    """Load a dataset from ModelScope with ``MsDataset.load`` and return it.

    Keyword Args:
        dataset_name: ModelScope dataset id (default 'speech_asr/speech_asr_aishell1_trainsets').
        subset_name: subset name (default 'default').
        split: split to load (default 'train').
        data_dump_dir: passed as ``cache_dir`` (default None).

    Returns:
        Whatever ``MsDataset.load`` returns. It was previously discarded, and
        the function returned None.
    """
    from modelscope.msdatasets import MsDataset
    dataset_name = kwargs.get("dataset_name", 'speech_asr/speech_asr_aishell1_trainsets')
    subset_name = kwargs.get("subset_name", 'default')
    split = kwargs.get("split", 'train')
    data_dump_dir = kwargs.get("data_dump_dir", None)
    ds = MsDataset.load(dataset_name=dataset_name, subset_name=subset_name, split=split, cache_dir=data_dump_dir)
    return ds
| | |
| | | registry = getattr(self, register_tables_key) |
| | | registry_key = key if key is not None else target_class.__name__ |
| | | |
| | | assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format( |
| | | registry_key, target_class, register_tables_key) |
| | | # assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format( |
| | | # registry_key, target_class, register_tables_key) |
| | | |
| | | registry[registry_key] = target_class |
| | | |
| | |
| | | my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext |
| | | with my_context(): |
| | | time2 = time.perf_counter() |
| | | print("before, GPU, memory: {:.1} MB, " |
| | | "{:.1} MB, " |
| | | "{:.1} MB, " |
| | | "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024, |
| | | torch.cuda.max_memory_allocated()/1024/1024/1024, |
| | | torch.cuda.memory_reserved()/1024/1024/1024, |
| | | torch.cuda.max_memory_reserved()/1024/1024/1024, |
| | | )) |
| | | # print("before, GPU, memory: {:.3f} GB, " |
| | | # "{:.3f} GB, " |
| | | # "{:.3f} GB, " |
| | | # "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024, |
| | | # torch.cuda.max_memory_allocated()/1024/1024/1024, |
| | | # torch.cuda.memory_reserved()/1024/1024/1024, |
| | | # torch.cuda.max_memory_reserved()/1024/1024/1024, |
| | | # )) |
| | | |
| | | retval = self.model(**batch) |
| | | torch.cuda.empty_cache() |
| | | print("after, GPU, memory: {:.1} MB, " |
| | | "{:.1} MB, " |
| | | "{:.1} MB, " |
| | | "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024, |
| | | torch.cuda.max_memory_allocated()/1024/1024/1024, |
| | | torch.cuda.memory_reserved()/1024/1024/1024, |
| | | torch.cuda.max_memory_reserved()/1024/1024/1024, |
| | | )) |
| | | # print("after, GPU, memory: {:.3f} GB, " |
| | | # "{:.3f} GB, " |
| | | # "{:.3f} GB, " |
| | | # "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024, |
| | | # torch.cuda.max_memory_allocated()/1024/1024/1024, |
| | | # torch.cuda.memory_reserved()/1024/1024/1024, |
| | | # torch.cuda.max_memory_reserved()/1024/1024/1024, |
| | | # )) |
| | | time3 = time.perf_counter() |
| | | speed_stats["forward_time"] = f"{time3 - time2:0.3f}" |
| | | loss, stats, weight = retval |
| | |
| | | |
| | | pbar.update(1) |
| | | if self.local_rank == 0: |
| | | gpu_info = "GPU, memory: {:.3f} GB, " \ |
| | | "{:.3f} GB, "\ |
| | | "{:.3f} GB, "\ |
| | | "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024, |
| | | torch.cuda.max_memory_allocated()/1024/1024/1024, |
| | | torch.cuda.memory_reserved()/1024/1024/1024, |
| | | torch.cuda.max_memory_reserved()/1024/1024/1024, |
| | | ) |
| | | description = ( |
| | | f"Train epoch: {epoch}/{self.max_epoch}, " |
| | | f"step {batch_idx}/{len(self.dataloader_train)}, " |
| | | f"{speed_stats}, " |
| | | f"(loss: {loss.detach().cpu().item():.3f}), " |
| | | f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}" |
| | | f"{gpu_info}" |
| | | ) |
| | | pbar.set_description(description) |
| | | if self.writer: |